+  if [[ " ${skip_examples[*]} " =~ [[:space:]]${filename}[[:space:]] ]]; then
+ +```bash +$ scaler_worker_manager orb_aws_ec2 tcp://127.0.0.1:2345 --image-id ami-0528819f94f4f5fa5 +``` + +This will start an ORB AWS EC2 worker adapter that connects to the Scaler scheduler at `tcp://127.0.0.1:2345`. The scheduler can then request new workers from this adapter, which will be launched as EC2 instances. + +The ORB AWS EC2 worker manager can also be included in a `scaler` all-in-one TOML config: + +```toml +[scheduler] +scheduler_address = "tcp://127.0.0.1:2345" + +[[worker_manager]] +type = "orb_aws_ec2" +scheduler_address = "tcp://127.0.0.1:2345" +image_id = "ami-0528819f94f4f5fa5" +instance_type = "t3.medium" +aws_region = "us-east-1" +``` + +### Configuration + +The ORB AWS EC2 adapter requires `orb-py` and `boto3` to be installed. You can install them with: + +```bash +$ pip install "opengris-scaler[orb_aws_ec2]" +``` + +For more details on configuring ORB AWS EC2, including AWS credentials and instance templates, please refer to the [ORB AWS EC2 Worker Adapter documentation](https://finos.github.io/opengris-scaler/tutorials/worker_manager_adapter/orb_aws_ec2.html). + ## Worker Manager usage > **Note**: This feature is experimental and may change in future releases. diff --git a/docs/source/tutorials/commands.rst b/docs/source/tutorials/commands.rst index f8017a334..7f3e1915e 100644 --- a/docs/source/tutorials/commands.rst +++ b/docs/source/tutorials/commands.rst @@ -14,7 +14,7 @@ After installing ``opengris-scaler``, the following CLI commands are available f * - :ref:`scaler_scheduler ` - Start only the scheduler process (and auto-start object storage when needed). * - :ref:`scaler_worker_manager ` - - Start one worker manager using a subcommand (``baremetal_native``, ``symphony``, ``aws_raw_ecs``, ``aws_hpc``). + - Start one worker manager using a subcommand (``baremetal_native``, ``symphony``, ``aws_raw_ecs``, ``aws_hpc``, ``orb_aws_ec2``). * - :ref:`scaler_object_storage_server ` - Start only the object storage server. 
* - :ref:`scaler_top ` @@ -53,6 +53,8 @@ All commands support ``--config``/``-c``. In practice, most deployments use TOML - ``[[worker_manager]]`` + ``type = "aws_raw_ecs"`` * - ``scaler_worker_manager aws_hpc`` - ``[[worker_manager]]`` + ``type = "aws_hpc"`` + * - ``scaler_worker_manager orb_aws_ec2`` + - ``[[worker_manager]]`` + ``type = "orb_aws_ec2"`` .. _cmd-scaler: @@ -352,6 +354,7 @@ Available subcommands: - ``symphony`` - ``aws_raw_ecs`` - ``aws_hpc`` +- ``orb_aws_ec2`` When ``--config``/``-c`` is supplied, ``scaler_worker_manager`` reads the ``[[worker_manager]]`` array from the TOML file and picks the entry whose ``type`` field matches the subcommand. @@ -753,6 +756,79 @@ AWS Batch worker manager. - ``60`` - Timeout for each submitted job. +Subcommand: ``orb_aws_ec2`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +ORB (Open Resource Broker) worker manager — dynamically provisions workers on AWS EC2 instances. + +.. code-block:: bash + + $ scaler_worker_manager orb_aws_ec2 [options] + +.. tabs:: + + .. group-tab:: command line + + .. code-block:: bash + + $ scaler_worker_manager orb_aws_ec2 tcp://127.0.0.1:6378 \ + --object-storage-address tcp://127.0.0.1:6379 \ + --image-id ami-0528819f94f4f5fa5 \ + --instance-type t3.medium \ + --aws-region us-east-1 + + .. group-tab:: config.toml + + .. code-block:: toml + + [[worker_manager]] + type = "orb_aws_ec2" + scheduler_address = "tcp://127.0.0.1:6378" + object_storage_address = "tcp://127.0.0.1:6379" + image_id = "ami-0528819f94f4f5fa5" + instance_type = "t3.medium" + aws_region = "us-east-1" + + Run command: + + .. code-block:: bash + + $ scaler config.toml + +.. list-table:: + :header-rows: 1 + + * - Argument + - Required + - Default + - Description + * - ``--image-id`` + - Yes + - - + - AMI ID for the worker EC2 instances. + * - ``--instance-type`` + - No + - ``t2.micro`` + - EC2 instance type. + * - ``--aws-region`` + - No + - ``us-east-1`` + - AWS region. + * - ``--key-name`` + - No + - ``None`` + - AWS key pair name. 
A temporary key pair is created if omitted. + * - ``--subnet-id`` + - No + - ``None`` + - AWS subnet ID. Defaults to the default subnet in the default VPC. + * - ``--security-group-ids`` + - No + - ``[]`` + - Comma-separated AWS security group IDs. A temporary group is created if omitted. + +For full details, see :doc:`worker_managers/orb_aws_ec2`. + .. _cmd-scaler-object-storage-server: diff --git a/docs/source/tutorials/compatibility/ray.rst b/docs/source/tutorials/compatibility/ray.rst index d80641b24..83eb94e06 100644 --- a/docs/source/tutorials/compatibility/ray.rst +++ b/docs/source/tutorials/compatibility/ray.rst @@ -6,7 +6,7 @@ Ray Scaler is a lightweight distributed computation engine similar to Ray. Scaler supports many of the same concepts as Ray including remote functions (known as tasks in Scaler), futures, cluster object storage, labels (known as capabilities in Scaler), and it comes with comparable monitoring tools. -Unlike Ray, Scaler supports both local clusters and also easily integrates with multiple cloud providers out of the box, including AWS EC2 and IBM Symphony, +Unlike Ray, Scaler supports both local clusters and also easily integrates with multiple cloud providers out of the box, including ORB (AWS EC2) and IBM Symphony, with more providers planned for the future. You can view our `roadmap on GitHub `_ for details on upcoming cloud integrations. diff --git a/docs/source/tutorials/worker_managers/index.rst b/docs/source/tutorials/worker_managers/index.rst index a6b25f411..80185c292 100644 --- a/docs/source/tutorials/worker_managers/index.rst +++ b/docs/source/tutorials/worker_managers/index.rst @@ -54,6 +54,10 @@ Worker Managers Overview - Offloads tasks to IBM Spectrum Symphony via the SOAM API. - Concurrency-limited - IBM Symphony + * - :doc:`ORB AWS EC2 ` + - Dynamically provisions workers on AWS EC2 instances using the ORB system. 
+ - Dynamic (scheduler-driven) + - AWS EC2 Although worker managers target different infrastructures, many configuration options are shared. See :doc:`Common Worker Manager Parameters ` for these shared settings. @@ -72,4 +76,5 @@ The :ref:`scaler ` command boots the full stack from a single TOML c aws_hpc_batch aws_raw_ecs symphony + orb_aws_ec2 common_parameters diff --git a/docs/source/tutorials/worker_managers/orb_aws_ec2.rst b/docs/source/tutorials/worker_managers/orb_aws_ec2.rst new file mode 100644 index 000000000..50ea464e1 --- /dev/null +++ b/docs/source/tutorials/worker_managers/orb_aws_ec2.rst @@ -0,0 +1,145 @@ +ORB AWS EC2 Worker Adapter +========================== + +The ORB AWS EC2 worker adapter allows Scaler to dynamically provision workers on AWS EC2 instances using the ORB (Open Resource Broker) system. This is particularly useful for scaling workloads that require significant compute resources or specialized hardware available in the cloud. + +This tutorial describes the steps required to get up and running with the ORB AWS EC2 adapter. + +Requirements +------------ + +Before using the ORB AWS EC2 worker adapter, ensure the following requirements are met on the machine that will run the adapter: + +1. **orb-py and boto3**: The ``orb-py`` and ``boto3`` packages must be installed. These can be installed using the ``orb_aws_ec2`` optional dependency of Scaler: + + .. code-block:: bash + + pip install "opengris-scaler[orb_aws_ec2]" + +2. **AWS CLI**: The AWS Command Line Interface must be installed and configured with a default profile that has permissions to launch, describe, and terminate EC2 instances. + +3. **Network Connectivity**: The adapter must be able to communicate with AWS APIs and the Scaler scheduler. + +Getting Started +--------------- + +To start the ORB AWS EC2 worker adapter, use the ``scaler_worker_manager orb_aws_ec2`` subcommand: + +.. 
code-block:: bash + + scaler_worker_manager orb_aws_ec2 tcp://:8516 \ + --object-storage-address tcp://:8517 \ + --image-id ami-0528819f94f4f5fa5 \ + --instance-type t3.medium \ + --aws-region us-east-1 \ + --logging-level INFO \ + --task-timeout-seconds 60 + +Equivalent configuration using a TOML file with ``scaler``: + +.. code-block:: toml + + # stack.toml + + [scheduler] + scheduler_address = "tcp://:8516" + + [[worker_manager]] + type = "orb_aws_ec2" + scheduler_address = "tcp://:8516" + object_storage_address = "tcp://:8517" + image_id = "ami-0528819f94f4f5fa5" + instance_type = "t3.medium" + aws_region = "us-east-1" + logging_level = "INFO" + task_timeout_seconds = 60 + +.. code-block:: bash + + scaler stack.toml + +* ``tcp://:8516`` is the address workers will use to connect to the scheduler. +* ``tcp://:8517`` is the address workers will use to connect to the object storage server. +* New workers will be launched using the specified AMI and instance type. + +Networking Configuration +------------------------ + +Workers launched by the ORB AWS EC2 adapter are EC2 instances and require an externally-reachable IP address for the scheduler. + +* **Internal Communication**: If the machine running the scheduler is another EC2 instance in the same VPC, you can use EC2 private IP addresses. +* **Public Internet**: If communicating over the public internet, it is highly recommended to set up robust security rules and/or a VPN to protect the cluster. + +Publicly Available AMIs +----------------------- + +We regularly publish publicly available Amazon Machine Images (AMIs) with Python and ``opengris-scaler`` pre-installed. + +.. 
+ORB AWS EC2 Template Configuration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ +Cleanup +------- + +The ORB AWS EC2 worker adapter is designed to be self-cleaning, but it is important to be aware of the resources it manages: + +* **Key Pairs**: If a ``--key-name`` is not provided, the adapter creates a temporary AWS key pair. +* **Security Groups**: If ``--security-group-ids`` are not provided, the adapter creates a temporary security group to allow communication. +* **Launch Templates**: ORB may additionally create EC2 Launch Templates as part of the machine provisioning process. + +The adapter attempts to delete these temporary resources and terminate all launched EC2 instances when it shuts down gracefully. However, in the event of an ungraceful crash or network failure, some resources may persist in your AWS account. + +.. tip:: + It is recommended to periodically check your AWS console for any orphaned resources (instances, security groups, key pairs, or launch templates) and clean them up manually if necessary to avoid unexpected costs. + +.. warning:: + **Subnet and Security Groups**: Currently, specifying ``--subnet-id`` or ``--security-group-ids`` via configuration might not have the intended effect as the adapter is designed to auto-discover or create these resources. Specifically, the adapter may still attempt to use default subnets or create its own temporary security groups regardless of these parameters. 
diff --git a/pyproject.toml b/pyproject.toml index 7f0ec9ea8..6fe2a7c66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,10 +50,15 @@ graphblas = [ aws = [ "boto3", ] +orb_aws_ec2 = [ + "orb-py~=1.5.1; python_version >= '3.10'", + "boto3; python_version >= '3.10'", +] all = [ "opengris-scaler[aws]", "opengris-scaler[graphblas]", "opengris-scaler[gui]", + "opengris-scaler[orb_aws_ec2]", "opengris-scaler[uvloop]", ] diff --git a/src/scaler/config/defaults.py b/src/scaler/config/defaults.py index f84c076be..d3ecf150f 100644 --- a/src/scaler/config/defaults.py +++ b/src/scaler/config/defaults.py @@ -56,7 +56,7 @@ # WORKER SPECIFIC OPTIONS # number of workers, echo worker use 1 process -DEFAULT_MAX_TASK_CONCURRENCY = os.cpu_count() - 1 +DEFAULT_MAX_TASK_CONCURRENCY = os.cpu_count() # number of seconds that worker agent send heartbeat to scheduler DEFAULT_HEARTBEAT_INTERVAL_SECONDS = 2 diff --git a/src/scaler/config/section/orb_aws_ec2_worker_adapter.py b/src/scaler/config/section/orb_aws_ec2_worker_adapter.py new file mode 100644 index 000000000..5fbcd9630 --- /dev/null +++ b/src/scaler/config/section/orb_aws_ec2_worker_adapter.py @@ -0,0 +1,37 @@ +import dataclasses +from typing import ClassVar, List, Optional + +from scaler.config.common.logging import LoggingConfig +from scaler.config.common.worker import WorkerConfig +from scaler.config.common.worker_manager import WorkerManagerConfig +from scaler.config.config_class import ConfigClass + + +@dataclasses.dataclass +class ORBAWSEC2WorkerAdapterConfig(ConfigClass): + """Configuration for the ORB AWS EC2 worker adapter.""" + + _tag: ClassVar[str] = "orb_aws_ec2" + + worker_manager_config: WorkerManagerConfig + + # ORB AWS EC2 Template configuration + image_id: str = dataclasses.field(metadata=dict(help="AMI ID for the worker instances", required=True)) + key_name: Optional[str] = dataclasses.field( + default=None, metadata=dict(help="AWS key pair name for the instances (optional)") + ) + subnet_id: Optional[str] 
= dataclasses.field( + default=None, metadata=dict(help="AWS subnet ID where the instances will be launched (optional)") + ) + + worker_config: WorkerConfig = dataclasses.field(default_factory=WorkerConfig) + logging_config: LoggingConfig = dataclasses.field(default_factory=LoggingConfig) + + instance_type: str = dataclasses.field(default="t2.micro", metadata=dict(help="EC2 instance type")) + aws_region: Optional[str] = dataclasses.field(default="us-east-1", metadata=dict(help="AWS region")) + security_group_ids: List[str] = dataclasses.field( + default_factory=list, + metadata=dict( + type=lambda s: [x for x in s.split(",") if x], help="Comma-separated list of AWS security group IDs" + ), + ) diff --git a/src/scaler/config/section/worker_manager_union.py b/src/scaler/config/section/worker_manager_union.py index 8bd3565cd..b98248e82 100644 --- a/src/scaler/config/section/worker_manager_union.py +++ b/src/scaler/config/section/worker_manager_union.py @@ -3,8 +3,13 @@ from scaler.config.section.aws_hpc_worker_manager import AWSBatchWorkerManagerConfig from scaler.config.section.ecs_worker_manager import ECSWorkerManagerConfig from scaler.config.section.native_worker_manager import NativeWorkerManagerConfig +from scaler.config.section.orb_aws_ec2_worker_adapter import ORBAWSEC2WorkerAdapterConfig from scaler.config.section.symphony_worker_manager import SymphonyWorkerManagerConfig WorkerManagerUnion = Union[ - NativeWorkerManagerConfig, SymphonyWorkerManagerConfig, ECSWorkerManagerConfig, AWSBatchWorkerManagerConfig + NativeWorkerManagerConfig, + SymphonyWorkerManagerConfig, + ECSWorkerManagerConfig, + AWSBatchWorkerManagerConfig, + ORBAWSEC2WorkerAdapterConfig, ] diff --git a/src/scaler/entry_points/scaler.py b/src/scaler/entry_points/scaler.py index 20702d8c5..729a53e73 100644 --- a/src/scaler/entry_points/scaler.py +++ b/src/scaler/entry_points/scaler.py @@ -9,6 +9,7 @@ from scaler.config.section.ecs_worker_manager import ECSWorkerManagerConfig from 
scaler.config.section.native_worker_manager import NativeWorkerManagerConfig from scaler.config.section.object_storage_server import ObjectStorageServerConfig +from scaler.config.section.orb_aws_ec2_worker_adapter import ORBAWSEC2WorkerAdapterConfig from scaler.config.section.scheduler import SchedulerConfig from scaler.config.section.symphony_worker_manager import SymphonyWorkerManagerConfig from scaler.config.section.webgui import WebGUIConfig @@ -59,6 +60,10 @@ def _run_worker_manager(config: WorkerManagerUnion) -> None: from scaler.worker_manager_adapter.aws_hpc.worker_manager import AWSHPCWorkerManager AWSHPCWorkerManager(config).run() + elif isinstance(config, ORBAWSEC2WorkerAdapterConfig): + from scaler.worker_manager_adapter.orb_aws_ec2.worker_manager import ORBAWSEC2WorkerAdapter + + ORBAWSEC2WorkerAdapter(config).run() def _run_gui(config: WebGUIConfig) -> None: diff --git a/src/scaler/entry_points/worker_manager.py b/src/scaler/entry_points/worker_manager.py index 0977a6e6c..413f31599 100644 --- a/src/scaler/entry_points/worker_manager.py +++ b/src/scaler/entry_points/worker_manager.py @@ -7,12 +7,17 @@ from scaler.config.section.aws_hpc_worker_manager import AWSBatchWorkerManagerConfig from scaler.config.section.ecs_worker_manager import ECSWorkerManagerConfig from scaler.config.section.native_worker_manager import NativeWorkerManagerConfig +from scaler.config.section.orb_aws_ec2_worker_adapter import ORBAWSEC2WorkerAdapterConfig from scaler.config.section.symphony_worker_manager import SymphonyWorkerManagerConfig from scaler.utility.event_loop import register_event_loop from scaler.utility.logging.utility import setup_logger _AnyWorkerManagerConfig = Union[ - NativeWorkerManagerConfig, SymphonyWorkerManagerConfig, ECSWorkerManagerConfig, AWSBatchWorkerManagerConfig + NativeWorkerManagerConfig, + SymphonyWorkerManagerConfig, + ECSWorkerManagerConfig, + AWSBatchWorkerManagerConfig, + ORBAWSEC2WorkerAdapterConfig, ] _TYPE_MAP: Dict[str, Type[ConfigClass]] 
= { @@ -20,6 +25,7 @@ SymphonyWorkerManagerConfig._tag: SymphonyWorkerManagerConfig, ECSWorkerManagerConfig._tag: ECSWorkerManagerConfig, AWSBatchWorkerManagerConfig._tag: AWSBatchWorkerManagerConfig, + ORBAWSEC2WorkerAdapterConfig._tag: ORBAWSEC2WorkerAdapterConfig, } @@ -82,6 +88,10 @@ def main() -> None: from scaler.worker_manager_adapter.aws_hpc.worker_manager import AWSHPCWorkerManager AWSHPCWorkerManager(wm_config).run() + elif isinstance(wm_config, ORBAWSEC2WorkerAdapterConfig): + from scaler.worker_manager_adapter.orb_aws_ec2.worker_manager import ORBAWSEC2WorkerAdapter + + ORBAWSEC2WorkerAdapter(wm_config).run() if __name__ == "__main__": diff --git a/src/scaler/scheduler/controllers/policies/simple_policy/scaling/capability_scaling.py b/src/scaler/scheduler/controllers/policies/simple_policy/scaling/capability_scaling.py index a8278a84a..dbbb8e5d0 100644 --- a/src/scaler/scheduler/controllers/policies/simple_policy/scaling/capability_scaling.py +++ b/src/scaler/scheduler/controllers/policies/simple_policy/scaling/capability_scaling.py @@ -216,7 +216,8 @@ def _create_start_command( worker_manager_heartbeat: WorkerManagerHeartbeat, ) -> Optional[WorkerManagerCommand]: """Create a start workers command if capacity allows.""" - if len(managed_worker_ids) >= worker_manager_heartbeat.max_task_concurrency: + max_concurrency = worker_manager_heartbeat.max_task_concurrency + if max_concurrency != -1 and len(managed_worker_ids) >= max_concurrency: return None logging.info(f"Requesting worker with capabilities: {capability_dict!r}") diff --git a/src/scaler/scheduler/controllers/policies/simple_policy/scaling/vanilla.py b/src/scaler/scheduler/controllers/policies/simple_policy/scaling/vanilla.py index 295a4781a..63ca395dd 100644 --- a/src/scaler/scheduler/controllers/policies/simple_policy/scaling/vanilla.py +++ b/src/scaler/scheduler/controllers/policies/simple_policy/scaling/vanilla.py @@ -49,7 +49,8 @@ def get_status(self, managed_workers: Dict[bytes, 
List[WorkerID]]) -> ScalingMan def _create_start_commands( self, managed_worker_ids: List[WorkerID], worker_manager_heartbeat: WorkerManagerHeartbeat ) -> List[WorkerManagerCommand]: - if len(managed_worker_ids) >= worker_manager_heartbeat.max_task_concurrency: + max_concurrency = worker_manager_heartbeat.max_task_concurrency + if max_concurrency != -1 and len(managed_worker_ids) >= max_concurrency: return [] return [WorkerManagerCommand.new_msg(worker_ids=[], command=WorkerManagerCommandType.StartWorkers)] diff --git a/src/scaler/scheduler/controllers/worker_manager_controller.py b/src/scaler/scheduler/controllers/worker_manager_controller.py index 48785ddf0..3c985787a 100644 --- a/src/scaler/scheduler/controllers/worker_manager_controller.py +++ b/src/scaler/scheduler/controllers/worker_manager_controller.py @@ -42,6 +42,12 @@ def __init__(self, config_controller: VanillaConfigController, policy_controller # Reverse map: worker_manager_id -> source (for duplicate detection) self._manager_id_to_source: Dict[bytes, bytes] = {} + # Sources that have reported TooManyWorkers: maps source -> worker count at the time + # TooManyWorkers was received. Suppress new StartWorkers until the scheduler's view of + # managed workers grows beyond that baseline, meaning at least one booting instance has + # sent its first heartbeat and the ORB adapter slot is no longer occupied by a pending boot. + self._at_capacity_baseline: Dict[bytes, int] = {} + def register(self, binder: AsyncBinder, task_controller: TaskController, worker_controller: WorkerController): self._binder = binder self._task_controller = task_controller @@ -74,6 +80,22 @@ async def on_heartbeat(self, source: bytes, heartbeat: WorkerManagerHeartbeat): # Build cross-manager snapshots from all known managers worker_manager_snapshots = self._build_manager_snapshots() + # Wait for the previous command to complete before sending another. + # Worker managers can take a long time to fulfill commands (e.g. 
ORB polls for instance IDs), + # so sending a new command before the response arrives causes duplicate work and errors. + if source in self._pending_commands: + return + + # If this manager previously reported TooManyWorkers, suppress new StartWorkers requests + # until the scheduler's worker count grows beyond the baseline recorded at that time. + # This handles the visibility gap where the ORB adapter has created instances that have + # not yet sent their first heartbeat to the scheduler. + if source in self._at_capacity_baseline: + if len(managed_worker_ids) > self._at_capacity_baseline[source]: + del self._at_capacity_baseline[source] + else: + return + commands = self._policy_controller.get_scaling_commands( information_snapshot, heartbeat, managed_worker_ids, managed_worker_capabilities, worker_manager_snapshots ) @@ -93,6 +115,12 @@ async def on_command_response(self, source: bytes, response: WorkerManagerComman self._manager_capabilities[source] = dict(response.capabilities) else: logging.warning(f"StartWorkers failed: {response.status.name}") + if response.status == WorkerManagerCommandResponse.Status.TooManyWorkers: + manager_entry = self._manager_alive_since.get(source) + if manager_entry is not None: + _, hb = manager_entry + baseline = len(self._worker_controller.get_workers_by_manager_id(hb.worker_manager_id)) + self._at_capacity_baseline[source] = baseline elif response.command == WorkerManagerCommandType.ShutdownWorkers: if response.status != WorkerManagerCommandResponse.Status.Success: @@ -181,3 +209,4 @@ async def _disconnect_manager(self, source: bytes): self._manager_alive_since.pop(source) self._pending_commands.pop(source, None) self._manager_capabilities.pop(source, None) + self._at_capacity_baseline.pop(source, None) diff --git a/src/scaler/utility/dict_utils.py b/src/scaler/utility/dict_utils.py new file mode 100644 index 000000000..97a0b8429 --- /dev/null +++ b/src/scaler/utility/dict_utils.py @@ -0,0 +1,38 @@ +import re +from typing 
import Any + + +def to_camel_case(snake_str: str) -> str: + components = snake_str.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +def to_snake_case(camel_str: str) -> str: + pattern = re.compile(r"(? Any: + if isinstance(d, dict): + new_d = {} + for k, v in d.items(): + new_key = to_camel_case(k) if isinstance(k, str) else k + new_d[new_key] = camelcase_dict(v) + return new_d + elif isinstance(d, list): + return [camelcase_dict(i) for i in d] + else: + return d + + +def snakecase_dict(d: Any) -> Any: + if isinstance(d, dict): + new_d = {} + for k, v in d.items(): + new_key = to_snake_case(k) if isinstance(k, str) else k + new_d[new_key] = snakecase_dict(v) + return new_d + elif isinstance(d, list): + return [snakecase_dict(i) for i in d] + else: + return d diff --git a/src/scaler/worker/worker.py b/src/scaler/worker/worker.py index e420c1b17..204145254 100644 --- a/src/scaler/worker/worker.py +++ b/src/scaler/worker/worker.py @@ -65,6 +65,7 @@ def __init__( logging_paths: Tuple[str, ...], logging_level: str, worker_manager_id: bytes, + deterministic_worker_ids: bool = False, ): multiprocessing.Process.__init__(self, name="Agent") @@ -77,7 +78,10 @@ def __init__( self._io_threads = io_threads self._task_queue_size = task_queue_size - self._ident = WorkerID.generate_worker_id(name) # _identity is internal to multiprocessing.Process + if deterministic_worker_ids: + self._ident = WorkerID(name.encode()) + else: + self._ident = WorkerID.generate_worker_id(name) self._address_path_internal = os.path.join(tempfile.gettempdir(), f"scaler_worker_{uuid.uuid4().hex}") self._address_internal = ZMQConfig(ZMQType.ipc, host=self._address_path_internal) diff --git a/src/scaler/worker_manager_adapter/orb_aws_ec2/__init__.py b/src/scaler/worker_manager_adapter/orb_aws_ec2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/scaler/worker_manager_adapter/orb_aws_ec2/ami/build.sh 
b/src/scaler/worker_manager_adapter/orb_aws_ec2/ami/build.sh new file mode 100755 index 000000000..dc211f116 --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb_aws_ec2/ami/build.sh @@ -0,0 +1,22 @@ +#!/bin/bash +set -e +set -x + +# This script builds the AMI for opengris-scaler using Packer +# It reads the version from the version.txt file and passes it as a variable + +# Get the directory where the script is located +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +VERSION_FILE="$SCRIPT_DIR/../../../version.txt" + +if [ ! -f "$VERSION_FILE" ]; then + echo "Error: Version file not found at $VERSION_FILE" + exit 1 +fi + +VERSION=$(cat "$VERSION_FILE" | tr -d '[:space:]') + +echo "Building AMI for version: $VERSION" + +cd "$SCRIPT_DIR" +packer build -var "version=$VERSION" opengris-scaler.pkr.hcl diff --git a/src/scaler/worker_manager_adapter/orb_aws_ec2/ami/opengris-scaler.pkr.hcl b/src/scaler/worker_manager_adapter/orb_aws_ec2/ami/opengris-scaler.pkr.hcl new file mode 100644 index 000000000..8611a939f --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb_aws_ec2/ami/opengris-scaler.pkr.hcl @@ -0,0 +1,69 @@ +packer { + required_plugins { + amazon = { + version = "~> 1" + source = "github.com/hashicorp/amazon" + } + } +} + +variable "aws_region" { + type = string + default = "us-east-1" +} + +variable "version" { + type = string +} + +variable "ami_regions" { + type = list(string) + default = [] + description = "A list of regions to copy the AMI to." +} + +variable "ami_groups" { + type = list(string) + default = ["all"] + description = "A list of groups to share the AMI with. Set to ['all'] to make public." 
+} + +variable "python_version" { + type = string + default = "3.13" +} + +source "amazon-ebs" "opengris-scaler" { + ami_name = "opengris-scaler-${var.version}-py${var.python_version}" + instance_type = "t2.small" + region = var.aws_region + ami_regions = var.ami_regions + ami_groups = var.ami_groups + source_ami_filter { + filters = { + name = "al2023-ami-2023.*-kernel-*-x86_64" + root-device-type = "ebs" + virtualization-type = "hvm" + } + most_recent = true + owners = ["amazon"] + } + ssh_username = "ec2-user" +} + +build { + name = "opengris-scaler-build" + sources = ["source.amazon-ebs.opengris-scaler"] + + provisioner "shell" { + inline = [ + "sudo dnf update -y", + "sudo dnf install -y python${var.python_version} python${var.python_version}-pip", + "sudo python${var.python_version} -m venv /opt/opengris-scaler", + "sudo /opt/opengris-scaler/bin/python -m pip install --upgrade pip", + "sudo /opt/opengris-scaler/bin/pip install opengris-scaler==${var.version}", + "sudo ln -sf /opt/opengris-scaler/bin/scaler_* /usr/local/bin/", + "sudo ln -sf /opt/opengris-scaler/bin/python /usr/local/bin/opengris-python" + ] + } +} diff --git a/src/scaler/worker_manager_adapter/orb_aws_ec2/exception.py b/src/scaler/worker_manager_adapter/orb_aws_ec2/exception.py new file mode 100644 index 000000000..4605e1441 --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb_aws_ec2/exception.py @@ -0,0 +1,9 @@ +from typing import Any + + +class ORBAWSEC2Exception(Exception): + """Exception raised for errors in ORB AWS EC2 operations.""" + + def __init__(self, data: Any): + self.data = data + super().__init__(f"ORB AWS EC2 Exception: {data}") diff --git a/src/scaler/worker_manager_adapter/orb_aws_ec2/worker_manager.py b/src/scaler/worker_manager_adapter/orb_aws_ec2/worker_manager.py new file mode 100644 index 000000000..081a713ea --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb_aws_ec2/worker_manager.py @@ -0,0 +1,424 @@ +import asyncio +import logging +import os +import 
signal +import uuid +from typing import Any, Dict, List, Optional, Tuple + +import boto3 +import zmq + +from scaler.config.section.orb_aws_ec2_worker_adapter import ORBAWSEC2WorkerAdapterConfig +from scaler.io import ymq +from scaler.io.mixins import AsyncConnector +from scaler.io.utility import create_async_connector, create_async_simple_context +from scaler.protocol.python.message import ( + Message, + WorkerManagerCommand, + WorkerManagerCommandResponse, + WorkerManagerCommandType, + WorkerManagerHeartbeat, + WorkerManagerHeartbeatEcho, +) +from scaler.utility.event_loop import create_async_loop_routine, register_event_loop, run_task_forever +from scaler.utility.identifiers import WorkerID +from scaler.utility.logging.utility import setup_logger +from scaler.worker_manager_adapter.common import format_capabilities + +Status = WorkerManagerCommandResponse.Status +logger = logging.getLogger(__name__) + + +# Polling configuration for ORB AWS EC2 machine requests +ORB_AWS_EC2_POLLING_INTERVAL_SECONDS = 5 +ORB_AWS_EC2_MAX_POLLING_ATTEMPTS = 60 + + +def get_orb_aws_ec2_worker_name(instance_id: str) -> str: + """ + Returns the deterministic worker name for an ORB AWS EC2 instance. + If instance_id is the bash variable '${INSTANCE_ID}', it returns a bash-compatible string. 
+ """ + if instance_id == "${INSTANCE_ID}": + return "Worker|ORB|${INSTANCE_ID}|${INSTANCE_ID//i-/}" + tag = instance_id.replace("i-", "") + return f"Worker|ORB|{instance_id}|{tag}" + + +class ORBAWSEC2WorkerAdapter: + _config: ORBAWSEC2WorkerAdapterConfig + _sdk: Optional[Any] + _workers: Dict[WorkerID, str] + _template_id: str + _created_security_group_id: Optional[str] + _created_key_name: Optional[str] + _ec2: Optional[Any] + + def __init__(self, config: ORBAWSEC2WorkerAdapterConfig): + self._config = config + self._address = config.worker_manager_config.scheduler_address + self._heartbeat_interval_seconds = config.worker_config.heartbeat_interval_seconds + self._capabilities = config.worker_config.per_worker_capabilities.capabilities + self._max_task_concurrency = config.worker_manager_config.max_task_concurrency + + self._event_loop = config.worker_config.event_loop + self._logging_paths = config.logging_config.paths + self._logging_level = config.logging_config.level + self._logging_config_file = config.logging_config.config_file + + self._sdk: Optional[Any] = None + self._ec2: Optional[Any] = None + self._context = None + self._connector_external: Optional[AsyncConnector] = None + self._created_security_group_id: Optional[str] = None + self._created_key_name: Optional[str] = None + self._cleaned_up = False + self._workers: Dict[WorkerID, str] = {} + self._ident: bytes = b"worker_manager_orb_aws_ec2|uninitialized" + self._subnet_id: Optional[str] = None + + def _build_app_config(self) -> dict: + region = self._config.aws_region or "us-east-1" + return { + "provider": { + "selection_policy": "FIRST_AVAILABLE", + "providers": [ + {"name": "aws-default", "type": "aws", "enabled": True, "priority": 1, "config": {"region": region}} + ], + # ORB skips loading strategy defaults (aws_defaults.json) when config_dict is + # provided, so provider_defaults must be included explicitly here. 
Without it, + # get_effective_handlers() returns {} and RunInstances is not in supported_apis. + "provider_defaults": { + "aws": { + "handlers": { + "RunInstances": { + "handler_class": "RunInstancesHandler", + "supports_spot": False, + "supports_ondemand": True, + } + } + } + }, + }, + "storage": {"type": "json"}, + } + + async def __setup(self) -> None: + """Set up AWS resources and the ORB template after the SDK is initialised.""" + region = self._config.aws_region or "us-east-1" + self._ec2 = boto3.client("ec2", region_name=region) + self._subnet_id = self._config.subnet_id or self._discover_default_subnet() + self._template_id = os.urandom(8).hex() + + security_group_ids = self._config.security_group_ids + if not security_group_ids: + self._create_security_group() + security_group_ids = [self._created_security_group_id] + + key_name = self._config.key_name + if not key_name: + self._create_key_pair() + key_name = self._created_key_name + + user_data = self._create_user_data() + + create_result = await self._sdk.create_template( + template_id=self._template_id, + name=f"opengris-orb-{self._template_id}", + image_id=self._config.image_id, + provider_api="RunInstances", + instance_type=self._config.instance_type, + max_instances=self._config.worker_manager_config.max_task_concurrency, + provider_name="aws-default", + machine_types={self._config.instance_type: 1}, + subnet_ids=[self._subnet_id], + security_group_ids=security_group_ids, + key_name=key_name, + user_data=user_data, + ) + logger.info(f"create_template result: {create_result}") + + validate_result = await self._sdk.validate_template(template_id=self._template_id) + logger.info(f"validate_template result: {validate_result}") + + self._context = create_async_simple_context() + self._name = "worker_manager_orb_aws_ec2" + self._ident = f"{self._name}|{uuid.uuid4().bytes.hex()}".encode() + + self._connector_external = create_async_connector( + self._context, + name=self._name, + socket_type=zmq.DEALER, + 
address=self._address, + bind_or_connect="connect", + callback=self.__on_receive_external, + identity=self._ident, + ) + + async def __terminate_all_workers(self) -> None: + """Return all active instances to ORB before the SDK context exits.""" + if not self._workers or self._sdk is None: + return + instance_ids = list(self._workers.values()) + logger.info(f"Terminating {len(instance_ids)} worker group(s)...") + try: + await self._sdk.create_return_request(machine_ids=instance_ids) + logger.info(f"Successfully requested termination of instances: {instance_ids}") + except Exception as e: + logger.warning(f"Failed to terminate instances during cleanup: {e}") + self._workers.clear() + + async def __on_receive_external(self, message: Message): + if isinstance(message, WorkerManagerCommand): + await self._handle_command(message) + elif isinstance(message, WorkerManagerHeartbeatEcho): + pass + else: + logging.warning(f"Received unknown message type: {type(message)}") + + async def _handle_command(self, command: WorkerManagerCommand): + cmd_type = command.command + response_status = Status.Success + worker_ids: List[bytes] = [] + capabilities: Dict[str, int] = {} + + if cmd_type == WorkerManagerCommandType.StartWorkers: + worker_ids, response_status = await self.start_worker() + if response_status == Status.Success: + capabilities = self._capabilities + elif cmd_type == WorkerManagerCommandType.ShutdownWorkers: + worker_ids, response_status = await self.shutdown_workers(list(command.worker_ids)) + else: + raise ValueError("Unknown Command") + + assert self._connector_external is not None + await self._connector_external.send( + WorkerManagerCommandResponse.new_msg( + command=cmd_type, status=response_status, worker_ids=worker_ids, capabilities=capabilities + ) + ) + + async def __send_heartbeat(self): + assert self._connector_external is not None + await self._connector_external.send( + WorkerManagerHeartbeat.new_msg( + max_task_concurrency=self._max_task_concurrency, + 
capabilities=self._capabilities, + worker_manager_id=self._ident, + ) + ) + + def run(self) -> None: + self._loop = asyncio.new_event_loop() + run_task_forever(self._loop, self._run(), cleanup_callback=self._cleanup) + + def __destroy(self): + print(f"Worker adapter {self._ident!r} received signal, shutting down") + self._task.cancel() + + def __register_signal(self): + self._loop.add_signal_handler(signal.SIGINT, self.__destroy) + self._loop.add_signal_handler(signal.SIGTERM, self.__destroy) + + async def _run(self) -> None: + from orb import ORBClient as orb + + register_event_loop(self._event_loop) + setup_logger(self._logging_paths, self._logging_config_file, self._logging_level) + + async with orb(app_config=self._build_app_config()) as sdk: + self._sdk = sdk + await self.__setup() + self._task = self._loop.create_task(self.__get_loops()) + self.__register_signal() + try: + await self._task + except asyncio.CancelledError: + pass + finally: + await self.__terminate_all_workers() + + self._sdk = None + + async def __get_loops(self): + assert self._connector_external is not None + loops = [ + create_async_loop_routine(self._connector_external.routine, 0), + create_async_loop_routine(self.__send_heartbeat, self._heartbeat_interval_seconds), + ] + + try: + await asyncio.gather(*loops) + except asyncio.CancelledError: + pass + except ymq.YMQException as e: + if e.code == ymq.ErrorCode.ConnectorSocketClosedByRemoteEnd: + pass + else: + logging.exception(f"{self._ident!r}: failed with unhandled exception:\n{e}") + + def _create_user_data(self) -> str: + worker_config = self._config.worker_config + adapter_config = self._config.worker_manager_config + + # NOTE: --max-task-concurrency is not passed; scaler_worker_manager defaults to cpu_count - 1 workers, + # where cpu_count is determined by the machine type configured by the user. 
+ script = f"""#!/bin/bash +INSTANCE_ID=$(ec2-metadata --instance-id --quiet) +nohup /usr/local/bin/scaler_worker_manager baremetal_native {adapter_config.scheduler_address.to_address()} \ + --mode fixed \ + --worker-type ORB \ + --worker-manager-id "${{INSTANCE_ID}}" \ + --per-worker-task-queue-size {worker_config.per_worker_task_queue_size} \ + --heartbeat-interval-seconds {worker_config.heartbeat_interval_seconds} \ + --task-timeout-seconds {worker_config.task_timeout_seconds} \ + --garbage-collect-interval-seconds {worker_config.garbage_collect_interval_seconds} \ + --death-timeout-seconds {worker_config.death_timeout_seconds} \ + --trim-memory-threshold-bytes {worker_config.trim_memory_threshold_bytes} \ + --event-loop {self._config.worker_config.event_loop} \ + --io-threads {self._config.worker_config.io_threads}""" + + if worker_config.hard_processor_suspend: + script += " \ + --hard-processor-suspend" + + if adapter_config.object_storage_address: + script += f" \ + --object-storage-address {adapter_config.object_storage_address.to_string()}" + + capabilities = worker_config.per_worker_capabilities.capabilities + if capabilities: + cap_str = format_capabilities(capabilities) + if cap_str.strip(): + script += f" \ + --per-worker-capabilities {cap_str}" + + script += " > /var/log/opengris-scaler.log 2>&1 &\n" + + return script + + def _discover_default_subnet(self) -> str: + vpcs = self._ec2.describe_vpcs(Filters=[{"Name": "isDefault", "Values": ["true"]}]) + if not vpcs["Vpcs"]: + raise RuntimeError("No default VPC found, and no subnet_id provided.") + default_vpc_id = vpcs["Vpcs"][0]["VpcId"] + + subnets = self._ec2.describe_subnets(Filters=[{"Name": "vpc-id", "Values": [default_vpc_id]}]) + if not subnets["Subnets"]: + raise RuntimeError(f"No subnets found in default VPC {default_vpc_id}.") + + subnet_id = subnets["Subnets"][0]["SubnetId"] + logger.info(f"Auto-discovered subnet_id: {subnet_id}") + return subnet_id + + def _create_security_group(self): + # 
Get VPC ID from Subnet + subnet_response = self._ec2.describe_subnets(SubnetIds=[self._subnet_id]) + vpc_id = subnet_response["Subnets"][0]["VpcId"] + + # Create Security Group (outbound-only — workers connect out to scheduler via ZMQ) + group_name = f"opengris-orb-sg-{self._template_id}" + sg_response = self._ec2.create_security_group( + Description="Temporary security group created for OpenGRIS ORB worker adapter", + GroupName=group_name, + VpcId=vpc_id, + ) + self._created_security_group_id = sg_response["GroupId"] + logger.info(f"Created security group with ID: {self._created_security_group_id}") + + def _create_key_pair(self): + key_name = f"opengris-orb-key-{self._template_id}" + self._ec2.create_key_pair(KeyName=key_name) + self._created_key_name = key_name + logger.info(f"Created key pair: {key_name}") + + def _cleanup(self): + if self._cleaned_up: + return + self._cleaned_up = True + + if self._connector_external is not None: + self._connector_external.destroy() + + logger.info("Starting cleanup of AWS resources...") + + if self._created_security_group_id is not None: + try: + logger.info(f"Deleting AWS security group: {self._created_security_group_id}") + self._ec2.delete_security_group(GroupId=self._created_security_group_id) + except Exception as e: + logger.warning(f"Failed to delete security group {self._created_security_group_id}: {e}") + + if self._created_key_name is not None: + try: + logger.info(f"Deleting AWS key pair: {self._created_key_name}") + self._ec2.delete_key_pair(KeyName=self._created_key_name) + except Exception as e: + logger.warning(f"Failed to delete key pair {self._created_key_name}: {e}") + + logger.info("Cleanup completed.") + + def __del__(self): + self._cleanup() + + async def start_worker(self) -> Tuple[List[bytes], Status]: + if len(self._workers) >= self._max_task_concurrency != -1: + return [], Status.TooManyWorkers + + response = await self._sdk.create_request(template_id=self._template_id, count=1) + request_id = ( + 
response.get("created_request_id") or response.get("request_id") or response.get("id") + if isinstance(response, dict) + else None + ) + + if not request_id: + logger.error(f"ORB machine request failed to return a request ID. Response: {response}") + return [], Status.UnknownAction + + logger.info(f"ORB machine request {request_id} submitted, waiting for instance IDs...") + + timeout = float(ORB_AWS_EC2_MAX_POLLING_ATTEMPTS * ORB_AWS_EC2_POLLING_INTERVAL_SECONDS) + try: + final = await self._sdk.wait_for_request( + request_id, timeout=timeout, poll_interval=float(ORB_AWS_EC2_POLLING_INTERVAL_SECONDS) + ) + except TimeoutError: + logger.error(f"ORB machine request {request_id} timed out after {timeout:.0f}s.") + return [], Status.UnknownAction + + machines = final.get("machines", []) if isinstance(final, dict) else [] + instance_id = next( + (m.get("machine_id") or m.get("id") for m in machines if m.get("machine_id") or m.get("id")), None + ) + + if not instance_id: + status = final.get("status", "") if isinstance(final, dict) else "" + logger.error(f"ORB request {request_id} completed with status '{status}' but no instance ID found.") + return [], Status.UnknownAction + + logger.info(f"ORB request {request_id} fulfilled with instance ID: {instance_id}") + worker_id = WorkerID(get_orb_aws_ec2_worker_name(instance_id).encode()) + self._workers[worker_id] = instance_id + return [bytes(worker_id)], Status.Success + + async def shutdown_workers(self, worker_ids: List[bytes]) -> Tuple[List[bytes], Status]: + if not worker_ids: + return [], Status.WorkerNotFound + + instance_ids = [] + affected_worker_ids = [] + for wid_bytes in worker_ids: + worker_id = WorkerID(wid_bytes) + if worker_id not in self._workers: + logger.warning(f"Worker with ID {wid_bytes!r} does not exist.") + return [], Status.WorkerNotFound + instance_ids.append(self._workers[worker_id]) + affected_worker_ids.append(wid_bytes) + + await self._sdk.create_return_request(machine_ids=instance_ids) + + for 
wid_bytes in affected_worker_ids: + del self._workers[WorkerID(wid_bytes)] + + return affected_worker_ids, Status.Success diff --git a/tests/entry_points/test_all.py b/tests/entry_points/test_all.py index 1a6c016d5..33b6935e4 100644 --- a/tests/entry_points/test_all.py +++ b/tests/entry_points/test_all.py @@ -231,6 +231,54 @@ def test_mode_value_based_string(self) -> None: config = self._parse(self._native_base(mode="fixed")) self.assertEqual(config.worker_managers[0].mode, NativeWorkerManagerMode.FIXED) + def test_orb_aws_ec2_worker_manager_parsed_from_toml(self) -> None: + from scaler.config.section.orb_aws_ec2_worker_adapter import ORBAWSEC2WorkerAdapterConfig + + toml = { + "worker_manager": { + "type": "orb_aws_ec2", + "scheduler_address": "tcp://127.0.0.1:6378", + "image_id": "ami-0528819f94f4f5fa5", + } + } + config = self._parse(toml) + self.assertEqual(len(config.worker_managers), 1) + self.assertIsInstance(config.worker_managers[0], ORBAWSEC2WorkerAdapterConfig) + + def test_orb_aws_ec2_fields_from_toml(self) -> None: + toml = { + "worker_manager": { + "type": "orb_aws_ec2", + "scheduler_address": "tcp://127.0.0.1:6378", + "image_id": "ami-0528819f94f4f5fa5", + "instance_type": "t3.medium", + "aws_region": "eu-west-1", + } + } + config = self._parse(toml) + self.assertEqual(config.worker_managers[0].image_id, "ami-0528819f94f4f5fa5") + self.assertEqual(config.worker_managers[0].instance_type, "t3.medium") + self.assertEqual(config.worker_managers[0].aws_region, "eu-west-1") + + def test_mixed_native_and_orb_aws_ec2_worker_managers(self) -> None: + from scaler.config.section.native_worker_manager import NativeWorkerManagerConfig + from scaler.config.section.orb_aws_ec2_worker_adapter import ORBAWSEC2WorkerAdapterConfig + + toml = { + "worker_manager": [ + {"type": "baremetal_native", "scheduler_address": "tcp://127.0.0.1:6378", "worker_manager_id": "wm-1"}, + { + "type": "orb_aws_ec2", + "scheduler_address": "tcp://127.0.0.1:6378", + "image_id": 
"ami-0528819f94f4f5fa5", + }, + ] + } + config = self._parse(toml) + self.assertEqual(len(config.worker_managers), 2) + self.assertIsInstance(config.worker_managers[0], NativeWorkerManagerConfig) + self.assertIsInstance(config.worker_managers[1], ORBAWSEC2WorkerAdapterConfig) + class TestRunWorkerManager(unittest.TestCase): """Tests that _run_worker_manager calls register_event_loop and setup_logger from the per-manager config.""" @@ -275,3 +323,45 @@ def test_setup_logger_called_with_logging_config(self) -> None: _run_worker_manager(config) mock_log.assert_called_once_with(config.logging_config.paths, config.logging_config.config_file, "WARNING") + + def _make_orb_aws_ec2_config(self, event_loop="builtin", logging_level="INFO"): + from scaler.config.common.logging import LoggingConfig + from scaler.config.common.worker import WorkerConfig + from scaler.config.common.worker_manager import WorkerManagerConfig + from scaler.config.section.orb_aws_ec2_worker_adapter import ORBAWSEC2WorkerAdapterConfig + from scaler.config.types.zmq import ZMQConfig + + wmc = WorkerManagerConfig(scheduler_address=ZMQConfig.from_string("tcp://localhost:6378")) + return ORBAWSEC2WorkerAdapterConfig( + worker_manager_config=wmc, + image_id="ami-0528819f94f4f5fa5", + worker_config=WorkerConfig(event_loop=event_loop), + logging_config=LoggingConfig(level=logging_level), + ) + + def test_orb_aws_ec2_run_worker_manager_dispatches_correctly(self) -> None: + from scaler.entry_points.scaler import _run_worker_manager + + config = self._make_orb_aws_ec2_config() + + with patch("scaler.entry_points.scaler.setup_logger"), patch( + "scaler.entry_points.scaler.register_event_loop" + ), patch("scaler.worker_manager_adapter.orb_aws_ec2.worker_manager.ORBAWSEC2WorkerAdapter") as mock_orb: + mock_orb.return_value.run.return_value = None + _run_worker_manager(config) + + mock_orb.assert_called_once_with(config) + mock_orb.return_value.run.assert_called_once() + + def 
test_orb_aws_ec2_register_event_loop_called(self) -> None: + from scaler.entry_points.scaler import _run_worker_manager + + config = self._make_orb_aws_ec2_config(event_loop="builtin") + + with patch("scaler.entry_points.scaler.register_event_loop") as mock_reg, patch( + "scaler.entry_points.scaler.setup_logger" + ), patch("scaler.worker_manager_adapter.orb_aws_ec2.worker_manager.ORBAWSEC2WorkerAdapter") as mock_orb: + mock_orb.return_value.run.return_value = None + _run_worker_manager(config) + + mock_reg.assert_called_once_with("builtin") diff --git a/tests/entry_points/test_worker_manager.py b/tests/entry_points/test_worker_manager.py index f341770a0..caec24b73 100644 --- a/tests/entry_points/test_worker_manager.py +++ b/tests/entry_points/test_worker_manager.py @@ -251,6 +251,50 @@ def test_per_manager_config_defaults(self) -> None: self.assertEqual(config.worker_config.event_loop, WorkerConfig().event_loop) +_ORB_AWS_EC2_BASE_ARGV = ["tcp://127.0.0.1:6378", "--image-id", "ami-0528819f94f4f5fa5"] + + +class TestORBAWSEC2WorkerManagerSubcommand(unittest.TestCase): + """Tests that ORBAWSEC2WorkerAdapterConfig is correctly parsed via parse_with_section.""" + + def test_orb_aws_ec2_image_id_parsed(self) -> None: + from scaler.config.section.orb_aws_ec2_worker_adapter import ORBAWSEC2WorkerAdapterConfig + + config = ORBAWSEC2WorkerAdapterConfig.parse_with_section( + "scaler_worker_manager", {}, argv=_ORB_AWS_EC2_BASE_ARGV + ) + self.assertIsInstance(config, ORBAWSEC2WorkerAdapterConfig) + self.assertEqual(config.image_id, "ami-0528819f94f4f5fa5") + + def test_orb_aws_ec2_defaults(self) -> None: + from scaler.config.section.orb_aws_ec2_worker_adapter import ORBAWSEC2WorkerAdapterConfig + + config = ORBAWSEC2WorkerAdapterConfig.parse_with_section( + "scaler_worker_manager", {}, argv=_ORB_AWS_EC2_BASE_ARGV + ) + self.assertEqual(config.instance_type, "t2.micro") + self.assertEqual(config.aws_region, "us-east-1") + + def 
test_orb_aws_ec2_instance_type_and_region_from_cli(self) -> None: + from scaler.config.section.orb_aws_ec2_worker_adapter import ORBAWSEC2WorkerAdapterConfig + + config = ORBAWSEC2WorkerAdapterConfig.parse_with_section( + "scaler_worker_manager", + {}, + argv=[*_ORB_AWS_EC2_BASE_ARGV, "--instance-type", "t3.medium", "--aws-region", "eu-west-1"], + ) + self.assertEqual(config.instance_type, "t3.medium") + self.assertEqual(config.aws_region, "eu-west-1") + + def test_orb_aws_ec2_logging_level_from_cli(self) -> None: + from scaler.config.section.orb_aws_ec2_worker_adapter import ORBAWSEC2WorkerAdapterConfig + + config = ORBAWSEC2WorkerAdapterConfig.parse_with_section( + "scaler_worker_manager", {}, argv=[*_ORB_AWS_EC2_BASE_ARGV, "--logging-level", "DEBUG"] + ) + self.assertEqual(config.logging_config.level, "DEBUG") + + class TestWorkerManagerMain(unittest.TestCase): """Tests for the main() entry point dispatch and error handling."""