diff --git a/.github/actions/run-test/action.yml b/.github/actions/run-test/action.yml
index 6cdef7357..cd9348a80 100644
--- a/.github/actions/run-test/action.yml
+++ b/.github/actions/run-test/action.yml
@@ -55,7 +55,13 @@ runs:
       run: |
         uv pip install --system -r examples/applications/requirements_applications.txt
         uv pip install --system -r examples/ray_compat/requirements.txt
+        readarray -t skip_examples < examples/skip_examples.txt
         for example in "./examples"/*.py; do
+          filename=$(basename "$example")
+          if [[ " ${skip_examples[*]} " =~ [[:space:]]${filename}[[:space:]] ]]; then
+            echo "Skipping $example"
+            continue
+          fi
           echo "Running $example"
           python $example
         done
diff --git a/.gitignore b/.gitignore
index 1da15adde..21014dd39 100644
--- a/.gitignore
+++ b/.gitignore
@@ -36,6 +36,9 @@ CMakeFiles/
 src/scaler/protocol/capnp/*.c++
 src/scaler/protocol/capnp/*.h
 
+orb/logs/
+orb/metrics/
+
 # AWS HPC test-generated files
 .scaler_aws_batch_config.json
 .scaler_aws_hpc.env
diff --git a/README.md b/README.md
index 425ad4be7..cfa7ff3fe 100644
--- a/README.md
+++ b/README.md
@@ -118,7 +118,7 @@ This will start a scheduler with 4 workers on port `2345`.
 ### Setting up a computing cluster from the CLI
 
 The object storage server, scheduler and workers can also be started from the command line with
-`scaler_scheduler` and `scaler_cluster`.
+`scaler_scheduler` and `scaler_worker_manager`.
 
 First, start the scheduler, and make it connect to the object storage server:
 
@@ -132,28 +132,22 @@ $ scaler_scheduler "tcp://127.0.0.1:2345"
 ...
 ```
 
-Finally, start a set of workers (a.k.a. a Scaler *cluster*) that connects to the previously started scheduler:
+Finally, start a set of workers that connects to the previously started scheduler:
 
 ```bash
-$ scaler_cluster -n 4 tcp://127.0.0.1:2345
-[INFO]2023-03-19 12:19:19-0400: logging to ('/dev/stdout',)
-[INFO]2023-03-19 12:19:19-0400: ClusterProcess: starting 4 workers, heartbeat_interval_seconds=2, object_retention_seconds=3600
-[INFO]2023-03-19 12:19:19-0400: Worker[0] started
-[INFO]2023-03-19 12:19:19-0400: Worker[1] started
-[INFO]2023-03-19 12:19:19-0400: Worker[2] started
-[INFO]2023-03-19 12:19:19-0400: Worker[3] started
+$ scaler_worker_manager native --mode fixed --max-task-concurrency 4 tcp://127.0.0.1:2345
 ...
 ```
 
-Multiple Scaler clusters can be connected to the same scheduler, providing distributed computation over multiple
+Multiple worker managers can be connected to the same scheduler, providing distributed computation over multiple
 servers.
 
-`-h` lists the available options for the object storage server, scheduler and the cluster executables:
+`-h` lists the available options for the object storage server, scheduler and the worker manager executables:
 
 ```bash
 $ scaler_object_storage_server -h
 $ scaler_scheduler -h
-$ scaler_cluster -h
+$ scaler_worker_manager native --help
 ```
 
 ### Submitting Python tasks using the Scaler client
@@ -243,12 +237,14 @@ The following table maps each Scaler command to its corresponding section name i
 
 | Command                                  | TOML Section Name           |
 |------------------------------------------|-----------------------------|
 | `scaler_scheduler`                       | `[scheduler]`               |
-| `scaler_cluster`                         | `[cluster]`                 |
 | `scaler_object_storage_server`           | `[object_storage_server]`   |
 | `scaler_ui`                              | `[webui]`                   |
 | `scaler_top`                             | `[top]`                     |
-| `scaler_worker_manager_baremetal_native` | `[native_worker_manager]`   |
-| `scaler_worker_manager_symphony`         | `[symphony_worker_manager]` |
+| `scaler_worker_manager native`           | `[native_worker_manager]`   |
+| `scaler_worker_manager symphony`         | `[symphony_worker_manager]` |
+| `scaler_worker_manager ecs`              | `[ecs_worker_manager]`      |
+| `scaler_worker_manager hpc`              | `[aws_hpc_worker_manager]`  |
+| `scaler_worker_manager orb`              | `[orb_worker_adapter]`      |
 
 ### Practical Scenarios & Examples
 
@@ -269,8 +265,9 @@ logging_paths = ["/dev/stdout", "/var/log/scaler/scheduler.log"]
 policy_engine_type = "simple"
 policy_content = "allocate=even_load; scaling=no"
 
-[cluster]
-num_of_workers = 8
+[native_worker_manager]
+mode = "fixed"
+max_task_concurrency = 8
 per_worker_capabilities = "linux,cpu=8"
 task_timeout_seconds = 600
 
@@ -285,7 +282,7 @@ With this single file, starting your entire stack is simple and consistent:
 ```bash
 scaler_object_storage_server tcp://127.0.0.1:6379 --config example_config.toml &
 scaler_scheduler tcp://127.0.0.1:6378 --config example_config.toml &
-scaler_cluster tcp://127.0.0.1:6378 --config example_config.toml &
+scaler_worker_manager native tcp://127.0.0.1:6378 --config example_config.toml &
 scaler_ui tcp://127.0.0.1:6380 --config example_config.toml &
 ```
 
@@ -295,12 +292,12 @@ You can override any value from the TOML file by providing it as a command-line
 example_config.toml file but test the cluster with 12 workers instead of 8:
 
 ```bash
-# The --num-of-workers flag will take precedence over the [cluster] section
-scaler_cluster tcp://127.0.0.1:6378 --config example_config.toml --num-of-workers 12
+# The --max-task-concurrency flag will take precedence over the [native_worker_manager] section
+scaler_worker_manager native tcp://127.0.0.1:6378 --config example_config.toml --max-task-concurrency 12
```
 
 The cluster will start with 12 workers, but all other settings (like `task_timeout_seconds`) will still be loaded from the
-`[cluster]` section of example_config.toml.
+`[native_worker_manager]` section of example_config.toml.
 
 ## Nested computations
@@ -351,7 +348,7 @@ When starting a cluster of workers, you can define the capabilities available on
 capabilities these provide.
 ```bash
-$ scaler_cluster -n 4 --per-worker-capabilities "gpu,linux" tcp://127.0.0.1:2345
+$ scaler_worker_manager native --mode fixed --max-task-concurrency 4 --per-worker-capabilities "gpu,linux" tcp://127.0.0.1:2345
 ```
 
 ### Specifying Capability Requirements for Tasks
@@ -380,7 +377,7 @@ might be added in the future.
 A Scaler scheduler can interface with IBM Spectrum Symphony to provide distributed computing across Symphony clusters.
 
 ```bash
-$ scaler_worker_manager_symphony tcp://127.0.0.1:2345 --service-name ScalerService --base-concurrency 4
+$ scaler_worker_manager symphony tcp://127.0.0.1:2345 --service-name ScalerService --base-concurrency 4
 ```
 
 This will start a Scaler worker that connects to the Scaler scheduler at `tcp://127.0.0.1:2345` and uses the Symphony
@@ -465,6 +462,26 @@ where `deepest_nesting_level` is the deepest nesting level a task has in your wo
 workload that has a base task that calls a nested task that calls another nested task, then the deepest nesting level
 is 2.
 
+## ORB (AWS EC2) integration
+
+A Scaler scheduler can interface with ORB (Open Resource Broker) to dynamically provision and manage workers on AWS EC2 instances.
+
+```bash
+$ scaler_worker_manager orb tcp://127.0.0.1:2345 --image-id ami-0528819f94f4f5fa5
+```
+
+This will start an ORB worker adapter that connects to the Scaler scheduler at `tcp://127.0.0.1:2345`. The scheduler can then request new workers from this adapter, which will be launched as EC2 instances.
+
+### Configuration
+
+The ORB adapter requires `orb-py` and `boto3` to be installed. You can install them with:
+
+```bash
+$ pip install "opengris-scaler[orb]"
+```
+
+For more details on configuring ORB, including AWS credentials and instance templates, please refer to the [ORB Worker Adapter documentation](https://finos.github.io/opengris-scaler/tutorials/worker_manager_adapter/orb.html).
+
 ## Worker Manager usage
 
 > **Note**: This feature is experimental and may change in future releases.
@@ -480,7 +497,7 @@ specification [here](https://github.com/finos/opengris).
 Start a Native Worker Manager and connect it to the scheduler:
 
 ```bash
-$ scaler_worker_manager_baremetal_native tcp://127.0.0.1:2345
+$ scaler_worker_manager native tcp://127.0.0.1:2345
 ```
 
 To check that the Worker Manager is working, you can bring up `scaler_top` to see workers spawning and terminating as
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 318e67504..0d6a12373 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -29,8 +29,11 @@ Content
    tutorials/scaling
    tutorials/worker_manager_adapter/index
    tutorials/worker_manager_adapter/native
+   tutorials/worker_manager_adapter/orb
    tutorials/worker_manager_adapter/aws_hpc/index
    tutorials/worker_manager_adapter/common_parameters
+   tutorials/worker_manager_adapter/worker_manager
+   tutorials/worker_manager_adapter/aio
    tutorials/compatibility/ray
    tutorials/configuration
    tutorials/examples
diff --git a/docs/source/tutorials/compatibility/ray.rst b/docs/source/tutorials/compatibility/ray.rst
index 8a2ac3d5d..0f59cce34 100644
--- a/docs/source/tutorials/compatibility/ray.rst
+++ b/docs/source/tutorials/compatibility/ray.rst
@@ -6,7 +6,7 @@ Ray Compatibility Layer
 Scaler is a lightweight distributed computation engine similar to Ray. Scaler supports many of the same concepts as Ray
 including remote functions (known as tasks in Scaler), futures, cluster object storage, labels (known as capabilities in Scaler),
 and it comes with comparable monitoring tools.
-Unlike Ray, Scaler supports both local clusters and also easily integrates with multiple cloud providers out of the box, including AWS EC2 and IBM Symphony,
+Unlike Ray, Scaler supports local clusters and also integrates easily with multiple cloud providers out of the box, including ORB (AWS EC2) and IBM Symphony,
 with more providers planned for the future. You can view our `roadmap on GitHub `_ for details on upcoming cloud integrations.
diff --git a/docs/source/tutorials/configuration.rst b/docs/source/tutorials/configuration.rst
index 6db0b8e3a..49a0d39a4 100644
--- a/docs/source/tutorials/configuration.rst
+++ b/docs/source/tutorials/configuration.rst
@@ -78,14 +78,14 @@ For the list of available settings, use the CLI command:
 
 .. code:: bash
 
-    scaler_cluster -h
+    scaler_worker_manager native --help
 
 **Preload Hook**
 
 Workers can execute an optional initialization function before processing tasks using the ``--preload`` option. This enables workers to:
 
 * Set up environments on demand
-* Preload data, libraries, or models 
+* Preload data, libraries, or models
 * Initialize connections or state
 
 The preload specification follows the format ``module.path:function(args, kwargs)`` where:
@@ -97,10 +97,10 @@ The preload specification follows the format ``module.path:function(args, kwargs
 .. code:: bash
 
     # Simple function call with no arguments
-    scaler_cluster tcp://127.0.0.1:8516 --preload "mypackage.init:setup"
-
+    scaler_worker_manager native tcp://127.0.0.1:8516 --preload "mypackage.init:setup"
+
     # Function call with arguments
-    scaler_cluster tcp://127.0.0.1:8516 --preload "mypackage.init:configure('production', debug=False)"
+    scaler_worker_manager native tcp://127.0.0.1:8516 --preload "mypackage.init:configure('production', debug=False)"
 
 The preload function is executed once per processor during initialization, before any tasks are processed. If the preload function fails, the error is logged and the processor will terminate.
@@ -127,7 +127,7 @@ This can be set using the CLI:
 
 .. code:: bash
 
-    scaler_cluster -n 10 tcp://127.0.0.1:8516 -dts 300
+    scaler_worker_manager native --mode fixed --max-task-concurrency 10 tcp://127.0.0.1:8516 -dts 300
 
 Or through the programmatic API:
@@ -185,22 +185,22 @@ The following table maps each Scaler command to its corresponding section name i
      - TOML Section Name
    * - ``scaler_scheduler``
      - ``[scheduler]``
-   * - ``scaler_cluster``
-     - ``[cluster]``
    * - ``scaler_object_storage_server``
      - ``[object_storage_server]``
    * - ``scaler_ui``
      - ``[webui]``
    * - ``scaler_top``
      - ``[top]``
-   * - ``scaler_worker_manager_baremetal_native``
+   * - ``scaler_worker_manager native``
      - ``[native_worker_manager]``
-   * - ``scaler_worker_manager_symphony``
+   * - ``scaler_worker_manager symphony``
      - ``[symphony_worker_manager]``
-   * - ``scaler_worker_manager_aws_raw_ecs``
+   * - ``scaler_worker_manager ecs``
      - ``[ecs_worker_manager]``
-   * - ``python -m scaler.entry_points.worker_manager_aws_hpc_batch``
+   * - ``scaler_worker_manager hpc``
      - ``[aws_hpc_worker_manager]``
+   * - ``scaler_worker_manager orb``
+     - ``[orb_worker_adapter]``
 
 
 Practical Scenarios & Examples
@@ -224,7 +224,8 @@ Here is an example of a single ``example_config.toml`` file that configures mult
     policy_engine_type = "simple"
     policy_content = "allocate=even_load; scaling=no"
 
-    [cluster]
+    [native_worker_manager]
+    mode = "fixed"
     max_task_concurrency = 8
     per_worker_capabilities = "linux,cpu=8"
     task_timeout_seconds = 600
@@ -241,7 +242,7 @@ With this single file, starting your entire stack is simple and consistent:
 
     scaler_object_storage_server tcp://127.0.0.1:6379 --config example_config.toml &
     scaler_scheduler tcp://127.0.0.1:6378 --config example_config.toml &
-    scaler_cluster tcp://127.0.0.1:6378 --config example_config.toml &
+    scaler_worker_manager native tcp://127.0.0.1:6378 --config example_config.toml &
     scaler_ui tcp://127.0.0.1:6380 --config example_config.toml &
 
 
@@ -251,10 +252,10 @@ You can override any value from the TOML file by providing it as a command-line
 .. code-block:: bash
 
-    # The --max-task-concurrency flag will take precedence over the [cluster] section
-    scaler_cluster tcp://127.0.0.1:6378 --config example_config.toml --max-task-concurrency 12
+    # The --max-task-concurrency flag will take precedence over the [native_worker_manager] section
+    scaler_worker_manager native tcp://127.0.0.1:6378 --config example_config.toml --max-task-concurrency 12
 
-The cluster will start with **12 workers**, but all other settings (like ``task_timeout_seconds``) will still be loaded from the ``[cluster]`` section of ``example_config.toml``.
+The cluster will start with **12 workers**, but all other settings (like ``task_timeout_seconds``) will still be loaded from the ``[native_worker_manager]`` section of ``example_config.toml``.
 
 **Scenario 3: Waterfall Scaling Configuration**
 
@@ -276,10 +277,10 @@ To use the ``waterfall_v1`` policy engine for priority-based scaling across mult
         2, ECS, 50
         """
 
-    [native_worker_adapter]
+    [native_worker_manager]
     max_task_concurrency = 8
 
-    [ecs_worker_adapter]
+    [ecs_worker_manager]
     max_task_concurrency = 50
 
 Then start the scheduler and worker adapters:
 
@@ -287,7 +288,7 @@ Then start the scheduler and worker adapters:
 
     scaler_scheduler tcp://127.0.0.1:8516 --config waterfall_config.toml &
-    scaler_worker_adapter_native tcp://127.0.0.1:8516 --config waterfall_config.toml &
-    scaler_worker_adapter_ecs tcp://127.0.0.1:8516 --config waterfall_config.toml &
+    scaler_worker_manager native tcp://127.0.0.1:8516 --config waterfall_config.toml &
+    scaler_worker_manager ecs tcp://127.0.0.1:8516 --config waterfall_config.toml &
 
 Local ``NAT`` workers will scale up first. When they reach capacity, ``ECS`` workers will begin scaling. On scale-down, ECS workers drain before local workers.
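The preload format ``module.path:function(args, kwargs)`` documented in the configuration changes above can be illustrated with a tiny parser. This is a sketch of the documented format only, not Scaler's actual implementation:

```python
import ast

def parse_preload(spec: str):
    """Split a preload spec into (module, function, args, kwargs).

    Sketch of the documented format, e.g.
    "mypackage.init:configure('production', debug=False)".
    """
    module_path, _, call_expr = spec.partition(":")
    if "(" not in call_expr:
        # Bare function name, no arguments
        return module_path, call_expr, (), {}
    call = ast.parse(call_expr, mode="eval").body  # an ast.Call node
    args = tuple(ast.literal_eval(a) for a in call.args)
    kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
    return module_path, call.func.id, args, kwargs

print(parse_preload("mypackage.init:setup"))
# ('mypackage.init', 'setup', (), {})
print(parse_preload("mypackage.init:configure('production', debug=False)"))
# ('mypackage.init', 'configure', ('production',), {'debug': False})
```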
diff --git a/docs/source/tutorials/examples.rst b/docs/source/tutorials/examples.rst
index c01f15284..b211baa47 100644
--- a/docs/source/tutorials/examples.rst
+++ b/docs/source/tutorials/examples.rst
@@ -15,6 +15,14 @@ Shows how to send a basic task to scheduler
 .. literalinclude:: ../../../examples/simple_client.py
    :language: python
 
+Submit Tasks
+~~~~~~~~~~~~
+
+Shows various ways to submit tasks (submit, map, starmap)
+
+.. literalinclude:: ../../../examples/submit_tasks.py
+   :language: python
+
 Client Mapping Tasks
 ~~~~~~~~~~~~~~~~~~~~
diff --git a/docs/source/tutorials/quickstart.rst b/docs/source/tutorials/quickstart.rst
index 35dcd9bbb..6c89b4a18 100644
--- a/docs/source/tutorials/quickstart.rst
+++ b/docs/source/tutorials/quickstart.rst
@@ -137,7 +137,7 @@ Here we use localhost addresses for demonstration, however the scheduler and wor
 
 .. code:: bash
 
-    scaler_cluster -n 10 tcp://127.0.0.1:8516
+    scaler_worker_manager native --mode fixed --max-task-concurrency 10 tcp://127.0.0.1:8516
 
 From here, connect the Python Client and begin submitting tasks:
diff --git a/docs/source/tutorials/worker_manager_adapter/aio.rst b/docs/source/tutorials/worker_manager_adapter/aio.rst
new file mode 100644
index 000000000..6f4c686b1
--- /dev/null
+++ b/docs/source/tutorials/worker_manager_adapter/aio.rst
@@ -0,0 +1,77 @@
+All-in-One Entry Point (scaler_aio)
+====================================
+
+``scaler_aio`` boots the full Scaler stack — scheduler and one or more worker managers — from a single TOML
+configuration file.
+
+Usage
+-----
+
+.. code-block:: bash
+
+    scaler_aio --config
+
+Each recognised section in the config file spawns a separate process. Unrecognised sections are ignored.
+If no recognised sections are found, ``scaler_aio`` exits with an error.
+
+Example Configuration
+---------------------
+
+.. code-block:: toml
+
+    [scheduler]
+    object_storage_address = "tcp://127.0.0.1:6379"
+    logging_level = "INFO"
+    policy_content = "allocate=even_load; scaling=vanilla"
+
+    [native_worker_manager]
+    max_task_concurrency = 4
+
+With this file:
+
+.. code-block:: bash
+
+    scaler_aio --config stack.toml
+
+This starts the scheduler and one native worker manager as separate processes.
+
+Array-of-Tables (Multiple Managers of the Same Type)
+------------------------------------------------------
+
+Use the TOML ``[[section]]`` array-of-tables syntax to spawn multiple instances of the same adapter type:
+
+.. code-block:: toml
+
+    [scheduler]
+    object_storage_address = "tcp://127.0.0.1:6379"
+
+    [[native_worker_manager]]
+    max_task_concurrency = 2
+
+    [[native_worker_manager]]
+    max_task_concurrency = 4
+
+This spawns two native worker manager processes with different concurrency limits.
+
+Recognised Section Names
+-------------------------
+
+The following section names are recognised:
+
+.. list-table::
+   :header-rows: 1
+
+   * - TOML Section
+     - Component started
+   * - ``[scheduler]``
+     - Scaler scheduler
+   * - ``[native_worker_manager]``
+     - Native (local subprocess) manager
+   * - ``[symphony_worker_manager]``
+     - IBM Symphony manager
+   * - ``[ecs_worker_manager]``
+     - AWS ECS manager
+   * - ``[aws_hpc_worker_manager]``
+     - AWS Batch manager
+   * - ``[orb_worker_adapter]``
+     - ORB (EC2) adapter
diff --git a/docs/source/tutorials/worker_manager_adapter/aws_hpc/setup.rst b/docs/source/tutorials/worker_manager_adapter/aws_hpc/setup.rst
index 654f58c06..c63860180 100644
--- a/docs/source/tutorials/worker_manager_adapter/aws_hpc/setup.rst
+++ b/docs/source/tutorials/worker_manager_adapter/aws_hpc/setup.rst
@@ -276,7 +276,7 @@ All commands in the **same terminal** inside the container:
     source tests/worker_manager_adapter/aws_hpc/.scaler_aws_hpc.env
 
     # Start AWS Batch worker in background (default job timeout: 60 minutes)
-    python -m scaler.entry_points.worker_manager_aws_hpc_batch \
+    scaler_worker_manager hpc \
         tcp://127.0.0.1:2345 \
         --job-queue $SCALER_JOB_QUEUE \
         --job-definition $SCALER_JOB_DEFINITION \
@@ -286,7 +286,7 @@ All commands in the **same terminal** inside the container:
         --logging-level INFO &
 
     # To override job timeout (e.g., 10 minutes):
-    # python -m scaler.entry_points.worker_manager_aws_hpc_batch \
+    # scaler_worker_manager hpc \
     #     tcp://127.0.0.1:2345 \
     #     ... \
     #     --job-timeout-minutes 10 &
@@ -460,7 +460,7 @@ Provisioner Options
 | ``--job-timeout``    | 60             | Job timeout in minutes (default: 1 hour, overridden by worker at runtime)|
 +----------------------+----------------+--------------------------------------------------------------------------+
 
-AWS HPC Batch Options (``worker_manager_aws_hpc_batch``)
+AWS HPC Batch Options (``scaler_worker_manager hpc``)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 +------------------------------------+----------------+-------------------------------------------------------------------+
diff --git a/docs/source/tutorials/worker_manager_adapter/common_parameters.rst b/docs/source/tutorials/worker_manager_adapter/common_parameters.rst
index 4b65e8700..ebaaad068 100644
--- a/docs/source/tutorials/worker_manager_adapter/common_parameters.rst
+++ b/docs/source/tutorials/worker_manager_adapter/common_parameters.rst
@@ -1,7 +1,7 @@
 Common Worker Manager Parameters
 ================================
 
-All worker managers in Scaler share a set of common configuration parameters for connecting to the cluster, configuring the internal web server, and managing worker behavior.
+All worker managers in Scaler share a set of common configuration parameters for connecting to the cluster and managing worker behavior.
 
 .. note::
     For more details on how to configure Scaler, see the :doc:`../configuration` section.
diff --git a/docs/source/tutorials/worker_manager_adapter/index.rst b/docs/source/tutorials/worker_manager_adapter/index.rst
index b4dcddc15..2df18eb67 100644
--- a/docs/source/tutorials/worker_manager_adapter/index.rst
+++ b/docs/source/tutorials/worker_manager_adapter/index.rst
@@ -22,7 +22,7 @@ Once the scheduler is running with this policy, start a worker manager (e.g., th
 
 .. code-block:: bash
 
-    scaler_worker_manager_baremetal_native tcp://127.0.0.1:8516 --max-task-concurrency 8
+    scaler_worker_manager native tcp://127.0.0.1:8516 --max-task-concurrency 8
 
 The vanilla policy will then automatically scale workers up and down based on the task-to-worker ratio. For a full
 description of available scaling policies and their parameters, see :doc:`../scaling`.
@@ -43,14 +43,28 @@ AWS HPC
 The :doc:`AWS HPC ` worker manager allows Scaler to offload task execution to cloud environments, currently supporting AWS Batch. It is ideal for bursting workloads to the cloud or utilizing specific hardware not available locally.
 
+ORB (AWS EC2)
+~~~~~~~~~~~~~
+
+The :doc:`ORB ` worker adapter allows Scaler to dynamically provision workers on AWS EC2 instances. This is ideal for scaling workloads that require significant cloud compute resources or specialized hardware like GPUs.
+
 Common Parameters
 ~~~~~~~~~~~~~~~~~
 
 All worker managers share a set of :doc:`common configuration parameters ` for networking, worker behavior, and logging.
 
+Unified Entry Points
+~~~~~~~~~~~~~~~~~~~~
+
+The :doc:`scaler_worker_manager ` command provides a single entry point for all worker manager
+adapters. The :doc:`scaler_aio ` command boots the full stack from a single TOML config file.
+
 .. toctree::
    :hidden:
 
    native
+   orb
    aws_hpc/index
    common_parameters
+   worker_manager
+   aio
diff --git a/docs/source/tutorials/worker_manager_adapter/native.rst b/docs/source/tutorials/worker_manager_adapter/native.rst
index dfacb87a3..b42e22cc4 100644
--- a/docs/source/tutorials/worker_manager_adapter/native.rst
+++ b/docs/source/tutorials/worker_manager_adapter/native.rst
@@ -6,13 +6,13 @@ The Native worker manager provisions workers as local subprocesses on the same m
 Getting Started
 ---------------
 
-To start the Native worker manager, use the ``scaler_worker_manager_baremetal_native`` command.
+To start the Native worker manager, use the ``scaler_worker_manager native`` command.
 
 Example command:
 
 .. code-block:: bash
 
-    scaler_worker_manager_baremetal_native tcp://:8516 \
+    scaler_worker_manager native tcp://:8516 \
         --max-task-concurrency 4 \
         --logging-level INFO \
         --task-timeout-seconds 60
@@ -21,7 +21,7 @@ Equivalent configuration using a TOML file:
 
 .. code-block:: bash
 
-    scaler_worker_manager_baremetal_native tcp://:8516 --config config.toml
+    scaler_worker_manager native tcp://:8516 --config config.toml
 
 .. code-block:: toml
diff --git a/docs/source/tutorials/worker_manager_adapter/orb.rst b/docs/source/tutorials/worker_manager_adapter/orb.rst
new file mode 100644
index 000000000..ab9c333e2
--- /dev/null
+++ b/docs/source/tutorials/worker_manager_adapter/orb.rst
@@ -0,0 +1,137 @@
+ORB Worker Adapter
+==================
+
+The ORB worker adapter allows Scaler to dynamically provision workers on AWS EC2 instances using the ORB (Open Resource Broker) system. This is particularly useful for scaling workloads that require significant compute resources or specialized hardware available in the cloud.
+
+This tutorial describes the steps required to get up and running with the ORB adapter.
+
+Requirements
+------------
+
+Before using the ORB worker adapter, ensure the following requirements are met on the machine that will run the adapter:
+
+1. **orb-py and boto3**: The ``orb-py`` and ``boto3`` packages must be installed. These can be installed using the ``orb`` optional dependency of Scaler:
+
+   .. code-block:: bash
+
+       pip install "opengris-scaler[orb]"
+
+2. **AWS CLI**: The AWS Command Line Interface must be installed and configured with a default profile that has permissions to launch, describe, and terminate EC2 instances.
+
+3. **Network Connectivity**: The adapter must be able to communicate with AWS APIs and the Scaler scheduler.
+
+Getting Started
+---------------
+
+To start the ORB worker adapter, use the ``scaler_worker_manager orb`` command.
+
+Example command:
+
+.. code-block:: bash
+
+    scaler_worker_manager orb tcp://:8516 \
+        --object-storage-address tcp://:8517 \
+        --image-id ami-0528819f94f4f5fa5 \
+        --instance-type t3.medium \
+        --aws-region us-east-1 \
+        --logging-level INFO \
+        --task-timeout-seconds 60
+
+Equivalent configuration using a TOML file:
+
+.. code-block:: bash
+
+    scaler_worker_manager orb tcp://:8516 --config config.toml
+
+.. code-block:: toml
+
+    # config.toml
+
+    [orb_worker_adapter]
+    object_storage_address = "tcp://:8517"
+    image_id = "ami-0528819f94f4f5fa5"
+    instance_type = "t3.medium"
+    aws_region = "us-east-1"
+    logging_level = "INFO"
+    task_timeout_seconds = 60
+
+* ``tcp://:8516`` is the address workers will use to connect to the scheduler.
+* ``tcp://:8517`` is the address workers will use to connect to the object storage server.
+* New workers will be launched using the specified AMI and instance type.
+
+Networking Configuration
+------------------------
+
+Workers launched by the ORB adapter are EC2 instances and require an externally-reachable IP address for the scheduler.
+
+* **Internal Communication**: If the machine running the scheduler is another EC2 instance in the same VPC, you can use EC2 private IP addresses.
+* **Public Internet**: If communicating over the public internet, it is highly recommended to set up robust security rules and/or a VPN to protect the cluster.
+
+Publicly Available AMIs
+-----------------------
+
+We regularly publish publicly available Amazon Machine Images (AMIs) with Python and ``opengris-scaler`` pre-installed.
+
+.. list-table:: Available Public AMIs
+   :widths: 15 15 20 20 30
+   :header-rows: 1
+
+   * - Scaler Version
+     - Python Version
+     - Amazon Linux 2023 Version
+     - Date (MM/DD/YYYY)
+     - AMI ID (us-east-1)
+   * - 1.14.2
+     - 3.13
+     - 2023.10.20260120
+     - 01/30/2026
+     - ``ami-0528819f94f4f5fa5``
+   * - 1.15.0
+     - 3.13
+     - 2023.10.20260302.1
+     - 03/16/2026
+     - ``ami-044265172bea55d51``
+
+New AMIs will be added to this list as they become available.
+
+Supported Parameters
+--------------------
+
+.. note::
+    For more details on how to configure Scaler, see the :doc:`../configuration` section.
+
+The ORB worker adapter supports ORB-specific configuration parameters as well as common worker adapter parameters.
+
+ORB Template Configuration
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+* ``--image-id`` (Required): AMI ID for the worker instances.
+* ``--instance-type``: EC2 instance type (default: ``t2.micro``).
+* ``--aws-region``: AWS region (default: ``us-east-1``).
+* ``--key-name``: AWS key pair name for the instances. If not provided, a temporary key pair will be created and deleted on cleanup.
+* ``--subnet-id``: AWS subnet ID where the instances will be launched. If not provided, it attempts to discover the default subnet in the default VPC.
+* ``--security-group-ids``: Comma-separated list of AWS security group IDs.
+* ``--allowed-ip``: IP address to allow in the security group (if created automatically). Defaults to the adapter's external IP.
+* ``--orb-config-path``: Path to the ORB root directory (default: ``src/scaler/drivers/orb``).
+
+Common Parameters
+~~~~~~~~~~~~~~~~~
+
+For a full list of common parameters including networking, worker configuration, and logging, see :doc:`common_parameters`.
+
+Cleanup
+-------
+
+The ORB worker adapter is designed to be self-cleaning, but it is important to be aware of the resources it manages:
+
+* **Key Pairs**: If a ``--key-name`` is not provided, the adapter creates a temporary AWS key pair.
+* **Security Groups**: If ``--security-group-ids`` are not provided, the adapter creates a temporary security group to allow communication.
+* **Launch Templates**: ORB may additionally create EC2 Launch Templates as part of the machine provisioning process.
+
+The adapter attempts to delete these temporary resources and terminate all launched EC2 instances when it shuts down gracefully. However, in the event of an ungraceful crash or network failure, some resources may persist in your AWS account.
+
+.. tip::
+    It is recommended to periodically check your AWS console for any orphaned resources (instances, security groups, key pairs, or launch templates) and clean them up manually if necessary to avoid unexpected costs.
+
+.. warning::
+    **Subnet and Security Groups**: Currently, specifying ``--subnet-id`` or ``--security-group-ids`` via configuration might not have the intended effect as the adapter is designed to auto-discover or create these resources. Specifically, the adapter may still attempt to use default subnets or create its own temporary security groups regardless of these parameters.
diff --git a/docs/source/tutorials/worker_manager_adapter/worker_manager.rst b/docs/source/tutorials/worker_manager_adapter/worker_manager.rst
new file mode 100644
index 000000000..3da1573b1
--- /dev/null
+++ b/docs/source/tutorials/worker_manager_adapter/worker_manager.rst
@@ -0,0 +1,96 @@
+Unified Worker Manager Entry Point
+===================================
+
+``scaler_worker_manager`` is a single command that replaces the individual per-adapter entry points
+(``scaler_worker_manager_baremetal_native``, ``scaler_worker_manager_symphony``, etc.).
+
+.. note::
+    ``scaler_cluster`` is no longer available. Users should migrate to
+    ``scaler_worker_manager native --mode fixed``.
+
+Usage
+-----
+
+.. code-block:: bash
+
+    scaler_worker_manager [options]
+
+Available sub-commands: ``native``, ``symphony``, ``ecs``, ``hpc``, ``orb``.
+
+The ``--config`` flag may appear before or after the sub-command name:
+
+.. code-block:: bash
+
+    scaler_worker_manager native --config cluster.toml tcp://127.0.0.1:8516
+    scaler_worker_manager --config cluster.toml native tcp://127.0.0.1:8516
+
+Sub-commands
+------------
+
+native
+~~~~~~
+
+Provisions workers as local subprocesses. Supports dynamic (default) and fixed-pool modes.
+
+.. code-block:: bash
+
+    # Dynamic mode (auto-scaling)
+    scaler_worker_manager native tcp://127.0.0.1:8516 --max-task-concurrency 4
+
+    # Fixed mode (pre-spawned workers)
+    scaler_worker_manager native --mode fixed --max-task-concurrency 4 tcp://127.0.0.1:8516
+
+    # Using a TOML config file
+    scaler_worker_manager native tcp://127.0.0.1:8516 --config config.toml
+
+See :doc:`native` for full parameter details.
+
+symphony
+~~~~~~~~
+
+Integrates with IBM Spectrum Symphony.
+
+.. code-block:: bash
+
+    scaler_worker_manager symphony tcp://127.0.0.1:8516 \
+        --service-name ScalerService \
+        --base-concurrency 4
+
+ecs
+~~~
+
+Provisions workers as AWS ECS Fargate tasks.
+
+.. code-block:: bash
+
+    scaler_worker_manager ecs tcp://127.0.0.1:8516 \
+        --ecs-cluster my-cluster \
+        --ecs-task-image my-image:latest \
+        --aws-region us-east-1
+
+hpc
+~~~
+
+Provisions workers via AWS Batch.
+
+.. code-block:: bash
+
+    scaler_worker_manager hpc tcp://127.0.0.1:8516 \
+        --job-queue my-queue \
+        --job-definition my-job-def \
+        --s3-bucket my-bucket
+
+See :doc:`aws_hpc/index` for full setup details.
+
+orb
+~~~
+
+Provisions workers as AWS EC2 instances via ORB.
+
+.. code-block:: bash
+
+    scaler_worker_manager orb tcp://127.0.0.1:8516 \
+        --image-id ami-0528819f94f4f5fa5 \
+        --instance-type t3.medium
+
+See :doc:`orb` for full parameter details.
diff --git a/pyproject.toml b/pyproject.toml
index e71a98cdf..b88cd4a91 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,18 +19,15 @@ classifiers = [
 dynamic = ["version"]
 dependencies = [
     "bidict",
-    "configargparse==1.7.5",
-    "toml",  # required by configargparse
     "cloudpickle",
     "psutil==7.2.2",
     "pycapnp==2.2.2",
     "pyzmq",
     "sortedcontainers==2.4.0",
     "tblib",
-    "aiohttp",
     "graphlib-backport; python_version < '3.9'",
     "typing-extensions>=4.0; python_version < '3.10'",
-    "tomli; python_version < '3.11'",
+    "tomli>=2.0; python_version < '3.11'",
 ]
 
 [project.optional-dependencies]
@@ -53,10 +50,15 @@ graphblas = [
 aws = [
     "boto3",
 ]
+orb = [
+    "orb-py~=1.3; python_version >= '3.10'",
+    "boto3; python_version >= '3.10'",
+]
 all = [
     "opengris-scaler[aws]",
     "opengris-scaler[graphblas]",
     "opengris-scaler[gui]",
+    "opengris-scaler[orb]",
     "opengris-scaler[uvloop]",
 ]
 
@@ -85,14 +87,11 @@ Home = "https://github.com/finos/opengris-scaler"
 [project.scripts]
 scaler_scheduler = "scaler.entry_points.scheduler:main"
-scaler_cluster = "scaler.entry_points.cluster:main"
 scaler_top = "scaler.entry_points.top:main"
 scaler_ui = "scaler.entry_points.webui:main"
 scaler_object_storage_server = "scaler.entry_points.object_storage_server:main"
-scaler_worker_manager_baremetal_native = "scaler.entry_points.worker_manager_baremetal_native:main"
-scaler_worker_manager_symphony = "scaler.entry_points.worker_manager_symphony:main"
-scaler_worker_manager_aws_raw_ecs = "scaler.entry_points.worker_manager_aws_raw_ecs:main"
-scaler_worker_manager_aws_hpc_batch = "scaler.entry_points.worker_manager_aws_hpc_batch:main"
+scaler_worker_manager = "scaler.entry_points.worker_manager:main"
+scaler_aio = "scaler.entry_points.aio:main"
 
 [tool.scikit-build]
 cmake.source-dir = "."
diff --git a/src/run_cluster.py b/src/run_cluster.py
deleted file mode 100644
index f53b02a0a..000000000
--- a/src/run_cluster.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from scaler.entry_points.cluster import main
-
-if __name__ == "__main__":
-    main()
diff --git a/src/run_worker_manager_symphony.py b/src/run_worker_manager.py
similarity index 61%
rename from src/run_worker_manager_symphony.py
rename to src/run_worker_manager.py
index 7c856e174..5f9358643 100644
--- a/src/run_worker_manager_symphony.py
+++ b/src/run_worker_manager.py
@@ -1,4 +1,4 @@
-from scaler.entry_points.worker_manager_symphony import main
+from scaler.entry_points.worker_manager import main
 from scaler.utility.debug import pdb_wrapped
 
 if __name__ == "__main__":
diff --git a/src/run_worker_manager_aws_hpc_batch.py b/src/run_worker_manager_aws_hpc_batch.py
deleted file mode 100644
index 802bb63cd..000000000
--- a/src/run_worker_manager_aws_hpc_batch.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from scaler.entry_points.worker_manager_aws_hpc_batch import main
-from scaler.utility.debug import pdb_wrapped
-
-if __name__ == "__main__":
-    pdb_wrapped(main)()
diff --git a/src/run_worker_manager_aws_raw_ecs.py b/src/run_worker_manager_aws_raw_ecs.py
deleted file mode 100644
index 6b1699efb..000000000
--- a/src/run_worker_manager_aws_raw_ecs.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from scaler.entry_points.worker_manager_aws_raw_ecs import main
-from scaler.utility.debug import pdb_wrapped
-
-if __name__ == "__main__":
-    pdb_wrapped(main)()
diff --git 
a/src/run_worker_manager_baremetal_native.py b/src/run_worker_manager_baremetal_native.py deleted file mode 100644 index 604cf8618..000000000 --- a/src/run_worker_manager_baremetal_native.py +++ /dev/null @@ -1,5 +0,0 @@ -from scaler.entry_points.worker_manager_baremetal_native import main -from scaler.utility.debug import pdb_wrapped - -if __name__ == "__main__": - pdb_wrapped(main)() diff --git a/src/scaler/config/config_class.py b/src/scaler/config/config_class.py index ac23c8f03..ec074bc9e 100644 --- a/src/scaler/config/config_class.py +++ b/src/scaler/config/config_class.py @@ -1,15 +1,10 @@ +import argparse import dataclasses import enum +import os +import sys import typing -from typing import Any, Dict, List, Optional, OrderedDict, Type, TypeVar - -from configargparse import ( - ArgParser, - ArgumentDefaultsHelpFormatter, - ArgumentTypeError, - ConfigFileParser, - TomlConfigParser, -) +from typing import Any, Dict, List, Optional, Type, TypeVar from scaler.config.mixins import ConfigType @@ -52,11 +47,12 @@ class MyConfig(ConfigClass): ## Environment Variables - Any parameter can be configured to read from an environment variable by adding `env="NAME"` to the field metadata. + Any parameter can be configured to read from an environment variable by adding `env_var="NAME"` to the field + metadata. 
```python # can be set as --my-field on the command line or in a config file, or using the environment variable `NAME` - my_field: int = dataclasses.field(metadata=dict(env="NAME")) + my_field: int = dataclasses.field(metadata=dict(env_var="NAME")) ``` ## Precedence @@ -123,6 +119,18 @@ class MyConfig(ConfigClass): field_two: int = dataclasses.field(metadata=dict(positional=True)) ``` + ## Sub-commands + + A field with `subcommand=""` in its metadata declares a sub-command: + - **Field name** -> CLI sub-command name + - **`subcommand` value** -> TOML section to read when this sub-command is active + - **Field type** -> `Optional[SomeConfigClass]` with `default=None` + + ## Section Fields + + A field with `section=""` in its metadata is populated from the TOML file only + (no CLI argument). The field type may be `Optional[SomeConfigClass]` or `List[SomeConfigClass]`. + ## Composition Config classes can be composed. If a config class has fields that are config classes, @@ -192,49 +200,43 @@ class MyConfigType(ConfigType): """ @classmethod - def configure_parser(cls: type, parser: ArgParser): - fields = dataclasses.fields(cls) + def configure_parser(cls: type, parser: argparse.ArgumentParser) -> None: + for field in dataclasses.fields(cls): # type: ignore[arg-type] + if "subcommand" in field.metadata or "section" in field.metadata: + continue # handled by parse() - for field in fields: if is_config_class(field.type): field.type.configure_parser(parser) # type: ignore[union-attr] continue - kwargs = dict(field.metadata) + # Strip keys that argparse doesn't understand + kwargs = { + k: v + for k, v in field.metadata.items() + if k not in ("env_var", "positional", "long", "short", "name", "subcommand", "section") + } - # usually command line options use hyphens instead of underscores - - if kwargs.pop("positional", False): - args = [kwargs.pop("name", field.name)] + if field.metadata.get("positional", False): + args = [field.metadata.get("name", field.name)] else: - 
long_name = kwargs.pop("long", f"--{field.name.replace('_', '-')}") - if "short" in kwargs: - args = [long_name, kwargs.pop("short")] - else: - args = [long_name] - - # this sets the key given back when args are parsed + long_name = field.metadata.get("long", f"--{field.name.replace('_', '-')}") + args = [long_name, field.metadata["short"]] if "short" in field.metadata else [long_name] kwargs["dest"] = field.name if "default" in kwargs: raise TypeError("'default' cannot be provided in field metadata") - if field.default != dataclasses.MISSING: + if field.default is not dataclasses.MISSING: kwargs["default"] = field.default - if field.default_factory != dataclasses.MISSING: + if field.default_factory is not dataclasses.MISSING: # type: ignore[misc] kwargs["default"] = field.default_factory() - # when store true or store false is set, setting the type raises a type error + # when store_true or store_false is set, setting the type raises a type error if kwargs.get("action") not in ("store_true", "store_false"): - # sometimes the user will set the type manually - # this is required for types such as `Option[T]`, where they cannt be directly constructed from a string if "type" not in kwargs: - opts = get_type_args(field.type) - - # set all of the options, except where already set - for key, value in opts.items(): + for key, value in get_type_args(field.type).items(): if key not in kwargs: kwargs[key] = value @@ -242,73 +244,293 @@ def configure_parser(cls: type, parser: ArgParser): @classmethod def parse(cls: Type[T], program_name: str, section: str) -> T: - parser = ArgParser( - program_name, - formatter_class=ArgumentDefaultsHelpFormatter, - config_file_parser_class=UnderscoreTomlConfigParser.with_sections([section]), - ) + subcommand_fields = [ + field for field in dataclasses.fields(cls) if "subcommand" in field.metadata # type: ignore[arg-type] + ] - parser.add_argument("--config", "-c", is_config_file=True, help="Path to the TOML configuration file.") + if 
subcommand_fields: + # Root parser: routes to sub-commands. + parser = argparse.ArgumentParser(prog=program_name, formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--config", "-c", metavar="FILE", help="Path to the TOML configuration file.") + + # Pre-scan for --config before building the subparser tree so TOML + # defaults can be injected into each subparser during construction. + config_path = _find_config_arg(sys.argv) + toml_data = _load_toml(config_path) if config_path else {} + + _build_subparser_tree(cls, parser, parent_cls=cls, sections=[section], dest="_sub", toml_data=toml_data) + + kwargs = vars(parser.parse_args()) + kwargs.pop("config", None) + + return _reconstruct_recursive(cls, kwargs, dest="_sub") + + # Normal path. + parser = argparse.ArgumentParser(prog=program_name, formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("--config", "-c", metavar="FILE", help="Path to the TOML configuration file.") cls.configure_parser(parser) + # Pass 1: locate --config without failing on unrecognised args. + config_path_str: Optional[str] = vars(parser.parse_known_args()[0]).get("config") + + # Load TOML and inject section values as defaults (below CLI, above hardcoded). + toml_data: Dict[str, Any] = {} # type: ignore[no-redef] + if config_path_str: + toml_data = _load_toml(config_path_str) + toml_defs = _toml_section_defaults(toml_data.get(section, {}), cls) + if toml_defs: + parser.set_defaults(**toml_defs) + + # Env-var defaults override TOML but are overridden by CLI. + env_defs = _env_defaults(cls) + if env_defs: + parser.set_defaults(**env_defs) + + # Pass 2: full parse — CLI wins. kwargs = vars(parser.parse_args()) + kwargs.pop("config", None) - # remove this from the args - kwargs.pop("config") + # Populate section= fields from the raw TOML (not from parsed args). 
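The two-pass precedence scheme implemented above can be sketched in isolation. The `--port` option and the injected values below are hypothetical stand-ins; the real code obtains them via `_load_toml` and `_env_defaults`:

```python
import argparse

# Sketch of the two-pass parse: hardcoded default < TOML value < env var < CLI flag.
parser = argparse.ArgumentParser()
parser.add_argument("--config", "-c")
parser.add_argument("--port", type=int, default=1000)  # hardcoded default

# Pass 1: locate --config without failing on unrecognised arguments.
known, _extras = parser.parse_known_args(["-c", "cfg.toml", "--bogus"])
assert known.config == "cfg.toml"

parser.set_defaults(port=2000)  # stub: value read from the TOML section
parser.set_defaults(port=2500)  # stub: env-var value overrides the TOML value

# Pass 2: full parse — an explicit CLI flag beats every injected default.
assert parser.parse_args(["--port", "9"]).port == 9
assert parser.parse_args([]).port == 2500
```

Because `set_defaults` only changes defaults, later injections simply shadow earlier ones, which is what gives env vars priority over TOML without any special-case merging.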
+ section_fields = [f for f in dataclasses.fields(cls) if "section" in f.metadata] # type: ignore[arg-type] + full_toml = toml_data if section_fields else None - # we need to manually handle any ConfigClass fields - for field in dataclasses.fields(cls): # type: ignore[arg-type] - if is_config_class(field.type): + return _reconstruct_config(cls, kwargs, full_toml) - # steal arguments for the config class - inner_kwargs = {} - for f in dataclasses.fields(field.type): # type: ignore[arg-type] - if f.name in kwargs: - inner_kwargs[f.name] = kwargs.pop(f.name) - # instantiate and update the args - kwargs[field.name] = field.type(**inner_kwargs) # type: ignore[operator] +# --------------------------------------------------------------------------- +# Module-level helpers: TOML loading and default injection +# --------------------------------------------------------------------------- - return cls(**kwargs) +def _find_config_arg(argv: List[str]) -> Optional[str]: + """Pre-scan argv for --config/-c without invoking the full parser.""" + for i, arg in enumerate(argv): + if arg in ("--config", "-c") and i + 1 < len(argv): + return argv[i + 1] + if arg.startswith("--config="): + return arg.split("=", 1)[1] + return None -class UnderscoreTomlConfigParser(ConfigFileParser): - """A TOML config parser that converts underscores to hyphens in key names.""" - _sections: Optional[List[str]] = None +def _load_toml(config_path: str) -> Dict[str, Any]: + """Parse a TOML file and return its full contents as a dict.""" + try: + import tomllib # type: ignore[import] # Python 3.11+ + except ImportError: + import tomli as tomllib # type: ignore[no-redef] + with open(config_path) as f: + return tomllib.loads(f.read()) - def __init__(self): - self._parser = TomlConfigParser(sections=self._sections) - @classmethod - def with_sections(cls, sections: List[str]) -> type: - """Return a subclass with the given sections baked in.""" - return type(cls.__name__, (cls,), {"_sections": sections}) +def 
_toml_section_defaults(section_data: Dict[str, Any], cls: type) -> Dict[str, Any]: + """Flatten a raw TOML section dict into argparse-compatible defaults. + + Normalises keys so that both 'my-field' and 'my_field' map to 'my_field'. + Only returns keys that correspond to actual fields on cls (and its nested + ConfigClass fields), so unrelated TOML keys are silently ignored. + """ + field_names: set = set() + for f in dataclasses.fields(cls): # type: ignore[arg-type] + if is_config_class(f.type): + for ff in dataclasses.fields(f.type): # type: ignore[arg-type] + field_names.add(ff.name) + elif "subcommand" not in f.metadata and "section" not in f.metadata: + field_names.add(f.name) + + return {key.replace("-", "_"): value for key, value in section_data.items() if key.replace("-", "_") in field_names} + + +def _env_defaults(cls: type) -> Dict[str, Any]: + """Collect env-var values for fields that declare env_var= metadata. + + Applies the same type coercion that argparse would use for CLI values, + so the value stored as a default is already the correct Python type. 
+ """ + result: Dict[str, Any] = {} + for field in dataclasses.fields(cls): # type: ignore[arg-type] + if is_config_class(field.type): + result.update(_env_defaults(field.type)) # type: ignore[arg-type] + continue + env_name = field.metadata.get("env_var") + if env_name and env_name in os.environ: + raw = os.environ[env_name] + type_func = field.metadata.get("type") or get_type_args(field.type).get("type") + result[field.name] = type_func(raw) if type_func else raw + return result + + +# --------------------------------------------------------------------------- +# Module-level helpers: reconstruction from TOML / parsed-args dicts +# --------------------------------------------------------------------------- + + +def _reconstruct_from_toml(config_cls: Type[T], data: Dict[str, Any]) -> T: + """Build a ConfigClass instance from a raw TOML section dict.""" + result_kwargs: Dict[str, Any] = {} + for field in dataclasses.fields(config_cls): # type: ignore[arg-type] + if field.name not in data: + continue + if is_config_class(field.type): + result_kwargs[field.name] = _reconstruct_from_toml(field.type, data[field.name]) # type: ignore[arg-type] + else: + result_kwargs[field.name] = data[field.name] + return config_cls(**result_kwargs) + + +def _reconstruct_config(config_cls: Type[T], kwargs: Dict[str, Any], toml_data: Optional[Dict[str, Any]] = None) -> T: + """Build a ConfigClass instance from a flat parsed-args dict. + + Pops only the fields belonging to config_cls from kwargs. + Any remaining entries in kwargs are left for the caller. + + toml_data is the full parsed TOML dict; required when config_cls has + fields with section= metadata. 
+ """ + result_kwargs: Dict[str, Any] = {} + for field in dataclasses.fields(config_cls): # type: ignore[arg-type] + if "section" in field.metadata: + section_name = field.metadata["section"] + raw = (toml_data or {}).get(section_name) + inner_cls = ( + get_list_type(field.type) + if is_list(field.type) + else get_optional_type(field.type) if is_optional(field.type) else field.type + ) + if raw is None: + result_kwargs[field.name] = [] if is_list(field.type) else None + elif isinstance(raw, dict): + instance = _reconstruct_from_toml(inner_cls, raw) # type: ignore[arg-type] + result_kwargs[field.name] = [instance] if is_list(field.type) else instance + else: # list of dicts — TOML [[array of tables]] + result_kwargs[field.name] = [ + _reconstruct_from_toml(inner_cls, item) # type: ignore[arg-type] + for item in raw + ] + elif is_config_class(field.type): + inner_kwargs: Dict[str, Any] = {} + for f in dataclasses.fields(field.type): # type: ignore[arg-type] + if f.name in kwargs: + inner_kwargs[f.name] = kwargs.pop(f.name) + result_kwargs[field.name] = field.type(**inner_kwargs) # type: ignore[operator] + elif field.name in kwargs: + result_kwargs[field.name] = kwargs.pop(field.name) + return config_cls(**result_kwargs) + + +# --------------------------------------------------------------------------- +# Module-level helpers: subparser tree construction and recursive reconstruction +# --------------------------------------------------------------------------- + + +def _build_subparser_tree( + cls: type, + parser: argparse.ArgumentParser, + parent_cls: type, + sections: List[str], + dest: str, + toml_data: Dict[str, Any], +) -> None: + """Recursively add subparsers to *parser* for every subcommand field in *cls*. + + Args: + cls: The config class whose subcommand fields are being registered. + parser: The (sub)parser to attach the new subparsers to. 
+ parent_cls: The config class that owns *parser*; its non-subcommand fields + are registered on each subparser so they are available at every level. + sections: Accumulated TOML section names from all ancestor levels. + dest: Unique argparse dest name for this level's chosen subcommand. + toml_data: Full parsed TOML dict loaded from --config (may be empty). + """ + subcommand_fields = [f for f in dataclasses.fields(cls) if "subcommand" in f.metadata] # type: ignore[arg-type] + if not subcommand_fields: + return + + subparsers = parser.add_subparsers(dest=dest, required=True) + + for field in subcommand_fields: + sub_section = field.metadata["subcommand"] + config_cls = get_optional_type(field.type) if is_optional(field.type) else field.type + level_sections = [s for s in sections + [sub_section] if s] + + subparser = subparsers.add_parser(field.name, formatter_class=argparse.ArgumentDefaultsHelpFormatter) + subparser.add_argument("--config", "-c", metavar="FILE", help="Path to the TOML configuration file.") + parent_cls.configure_parser(subparser) # type: ignore[attr-defined] # parent-level fields + config_cls.configure_parser(subparser) # type: ignore[union-attr] # this sub-command's own fields + + # Inject TOML defaults: merge all ancestor sections + this sub-command's section. + combined: Dict[str, Any] = {} + for s in level_sections: + combined.update(toml_data.get(s, {})) + toml_defs = _toml_section_defaults(combined, config_cls) # type: ignore[arg-type] + if toml_defs: + subparser.set_defaults(**toml_defs) + + # Env-var defaults override TOML. 
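The dest-chaining and per-subparser default injection can be sketched with a cut-down parser (the value 2 below is a stub standing in for a TOML section entry):

```python
import argparse

# Each nesting level stores its chosen sub-command under a unique dest so the
# flat parsed-args dict can be walked back into nested config objects.
root = argparse.ArgumentParser(prog="wm")
subs = root.add_subparsers(dest="_sub", required=True)

native = subs.add_parser("native")
native.add_argument("--max-task-concurrency", type=int, default=1)
native.set_defaults(max_task_concurrency=2)  # stub: injected from a TOML section

args = vars(root.parse_args(["native"]))
assert args["_sub"] == "native"
assert args["max_task_concurrency"] == 2  # TOML default used when no CLI flag

args = vars(root.parse_args(["native", "--max-task-concurrency", "8"]))
assert args["max_task_concurrency"] == 8  # CLI flag wins
```

Defaults are attached to each subparser at construction time, which is why the TOML file must be pre-scanned from `sys.argv` before the subparser tree is built.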
+ env_defs = _env_defaults(config_cls) # type: ignore[arg-type] + if env_defs: + subparser.set_defaults(**env_defs) + + _build_subparser_tree( + config_cls, # type: ignore[arg-type] + subparser, + config_cls, # type: ignore[arg-type] + sections=level_sections, + dest=f"{dest}_{field.name}", + toml_data=toml_data, + ) - def parse(self, stream) -> OrderedDict[str, Any]: - return OrderedDict((k.replace("_", "-"), v) for k, v in self._parser.parse(stream).items()) - def get_syntax_description(self) -> str: - return self._parser.get_syntax_description() +def _reconstruct_recursive(cls: Type[T], kwargs: Dict[str, Any], dest: str) -> T: + """Recursively reconstruct a ConfigClass from a flat parsed-args dict.""" + subcommand_fields = [f for f in dataclasses.fields(cls) if "subcommand" in f.metadata] # type: ignore[arg-type] + + if not subcommand_fields: + return _reconstruct_config(cls, kwargs) + + subcommand = kwargs.pop(dest) + sub_field = next(f for f in subcommand_fields if f.name == subcommand) + config_cls = get_optional_type(sub_field.type) if is_optional(sub_field.type) else sub_field.type + + selected_config = _reconstruct_recursive(config_cls, kwargs, f"{dest}_{subcommand}") # type: ignore[arg-type] + + wrapper_kwargs: Dict[str, Any] = {f.name: None for f in subcommand_fields} + wrapper_kwargs[subcommand] = selected_config + + for field in dataclasses.fields(cls): # type: ignore[arg-type] + if "subcommand" in field.metadata: + continue + if is_config_class(field.type): + inner_kwargs: Dict[str, Any] = {} + for f in dataclasses.fields(field.type): # type: ignore[arg-type] + if f.name in kwargs: + inner_kwargs[f.name] = kwargs.pop(f.name) + wrapper_kwargs[field.name] = field.type(**inner_kwargs) # type: ignore[operator] + elif field.name in kwargs: + wrapper_kwargs[field.name] = kwargs.pop(field.name) + + return cls(**wrapper_kwargs) + + +# --------------------------------------------------------------------------- +# Type-parsing helpers +# 
--------------------------------------------------------------------------- def parse_bool(s: str) -> bool: """parse a bool from a conventional string representation""" - lower = s.lower() if lower == "true": return True if lower == "false": return False - - raise ArgumentTypeError(f"'{s}' is not a valid bool") + raise argparse.ArgumentTypeError(f"'{s}' is not a valid bool") def parse_enum(s: str, enumm: Type[enum.Enum]) -> Any: try: return enumm[s] except KeyError as e: - raise ArgumentTypeError(f"'{s}' is not a valid {enumm.__name__}") from e + raise argparse.ArgumentTypeError(f"'{s}' is not a valid {enumm.__name__}") from e def is_optional(ty: Any) -> bool: diff --git a/src/scaler/config/section/orb_worker_adapter.py b/src/scaler/config/section/orb_worker_adapter.py new file mode 100644 index 000000000..a0f06ee34 --- /dev/null +++ b/src/scaler/config/section/orb_worker_adapter.py @@ -0,0 +1,50 @@ +import dataclasses +from typing import List, Optional + +from scaler.config import defaults +from scaler.config.common.logging import LoggingConfig +from scaler.config.common.worker import WorkerConfig +from scaler.config.common.worker_manager import WorkerManagerConfig +from scaler.config.config_class import ConfigClass +from scaler.utility.event_loop import EventLoopType + + +@dataclasses.dataclass +class ORBWorkerAdapterConfig(ConfigClass): + """Configuration for the ORB worker adapter.""" + + worker_manager_config: WorkerManagerConfig + + # ORB Template configuration + image_id: str = dataclasses.field(metadata=dict(help="AMI ID for the worker instances", required=True)) + key_name: Optional[str] = dataclasses.field( + default=None, metadata=dict(help="AWS key pair name for the instances (optional)") + ) + subnet_id: Optional[str] = dataclasses.field( + default=None, metadata=dict(help="AWS subnet ID where the instances will be launched (optional)") + ) + + worker_config: WorkerConfig = dataclasses.field(default_factory=WorkerConfig) + logging_config: LoggingConfig 
= dataclasses.field(default_factory=LoggingConfig) + event_loop: str = dataclasses.field( + default="builtin", + metadata=dict(short="-el", choices=EventLoopType.allowed_types(), help="select the event loop type"), + ) + + worker_io_threads: int = dataclasses.field( + default=defaults.DEFAULT_IO_THREADS, + metadata=dict(short="-wit", help="set the number of io threads for io backend per worker"), + ) + + instance_type: str = dataclasses.field(default="t2.micro", metadata=dict(help="EC2 instance type")) + aws_region: Optional[str] = dataclasses.field(default="us-east-1", metadata=dict(help="AWS region")) + security_group_ids: List[str] = dataclasses.field( + default_factory=list, + metadata=dict( + type=lambda s: [x for x in s.split(",") if x], help="Comma-separated list of AWS security group IDs" + ), + ) + + def __post_init__(self) -> None: + if self.worker_io_threads <= 0: + raise ValueError("worker_io_threads must be a positive integer.") diff --git a/src/scaler/entry_points/aio.py b/src/scaler/entry_points/aio.py new file mode 100644 index 000000000..b365743be --- /dev/null +++ b/src/scaler/entry_points/aio.py @@ -0,0 +1,113 @@ +import dataclasses +import multiprocessing +import sys +from typing import List, Optional + +from scaler.config.config_class import ConfigClass +from scaler.config.section.aws_hpc_worker_manager import AWSBatchWorkerManagerConfig +from scaler.config.section.ecs_worker_manager import ECSWorkerManagerConfig +from scaler.config.section.native_worker_manager import NativeWorkerManagerConfig +from scaler.config.section.orb_worker_adapter import ORBWorkerAdapterConfig +from scaler.config.section.scheduler import SchedulerConfig +from scaler.config.section.symphony_worker_manager import SymphonyWorkerManagerConfig +from scaler.utility.logging.utility import setup_logger + + +@dataclasses.dataclass +class _AIOConfig(ConfigClass): + # Declaration order = startup order (scheduler before workers). 
+ scheduler: Optional[SchedulerConfig] = dataclasses.field(default=None, metadata=dict(section="scheduler")) + native_worker_manager: List[NativeWorkerManagerConfig] = dataclasses.field( + default_factory=list, metadata=dict(section="native_worker_manager") + ) + symphony_worker_manager: List[SymphonyWorkerManagerConfig] = dataclasses.field( + default_factory=list, metadata=dict(section="symphony_worker_manager") + ) + ecs_worker_manager: List[ECSWorkerManagerConfig] = dataclasses.field( + default_factory=list, metadata=dict(section="ecs_worker_manager") + ) + aws_hpc_worker_manager: List[AWSBatchWorkerManagerConfig] = dataclasses.field( + default_factory=list, metadata=dict(section="aws_hpc_worker_manager") + ) + orb_worker_adapter: List[ORBWorkerAdapterConfig] = dataclasses.field( + default_factory=list, metadata=dict(section="orb_worker_adapter") + ) + + +# --- per-type runners (module-level for multiprocessing spawn compatibility) --- + + +def _run_scheduler(config: SchedulerConfig) -> None: + from scaler.entry_points.scheduler import main as _main + + _main(config) + + +def _run_native(config: NativeWorkerManagerConfig) -> None: + from scaler.worker_manager_adapter.baremetal.native import NativeWorkerManager + + setup_logger(config.logging_config.paths, config.logging_config.config_file, config.logging_config.level) + NativeWorkerManager(config).run() + + +def _run_symphony(config: SymphonyWorkerManagerConfig) -> None: + from scaler.worker_manager_adapter.symphony.worker_manager import SymphonyWorkerManager + + setup_logger(config.logging_config.paths, config.logging_config.config_file, config.logging_config.level) + SymphonyWorkerManager(config).run() + + +def _run_ecs(config: ECSWorkerManagerConfig) -> None: + from scaler.worker_manager_adapter.aws_raw.ecs import ECSWorkerManager + + setup_logger(config.logging_config.paths, config.logging_config.config_file, config.logging_config.level) + ECSWorkerManager(config).run() + + +def _run_hpc(config: 
AWSBatchWorkerManagerConfig) -> None: + from scaler.worker_manager_adapter.aws_hpc.worker_manager import AWSHPCWorkerManager + + setup_logger(config.logging_config.paths, config.logging_config.config_file, config.logging_config.level) + AWSHPCWorkerManager(config).run() + + +def _run_orb(config: ORBWorkerAdapterConfig) -> None: + from scaler.worker_manager_adapter.orb.worker_manager import ORBWorkerAdapter + + setup_logger(config.logging_config.paths, config.logging_config.config_file, config.logging_config.level) + ORBWorkerAdapter(config).run() + + +def main() -> None: + config = _AIOConfig.parse("scaler_aio", "aio") + + processes: List[multiprocessing.Process] = [] + + if config.scheduler is not None: + processes.append(multiprocessing.Process(target=_run_scheduler, args=(config.scheduler,), name="scheduler")) + for native_cfg in config.native_worker_manager: + processes.append(multiprocessing.Process(target=_run_native, args=(native_cfg,), name="native_worker_manager")) + for symphony_cfg in config.symphony_worker_manager: + processes.append( + multiprocessing.Process(target=_run_symphony, args=(symphony_cfg,), name="symphony_worker_manager") + ) + for ecs_cfg in config.ecs_worker_manager: + processes.append(multiprocessing.Process(target=_run_ecs, args=(ecs_cfg,), name="ecs_worker_manager")) + for hpc_cfg in config.aws_hpc_worker_manager: + processes.append(multiprocessing.Process(target=_run_hpc, args=(hpc_cfg,), name="aws_hpc_worker_manager")) + for orb_cfg in config.orb_worker_adapter: + processes.append(multiprocessing.Process(target=_run_orb, args=(orb_cfg,), name="orb_worker_adapter")) + + if not processes: + print("scaler_aio: no recognized sections found in config file", file=sys.stderr) + sys.exit(1) + + for process in processes: + process.start() + + for process in processes: + process.join() + + +if __name__ == "__main__": + main() diff --git a/src/scaler/entry_points/cluster.py b/src/scaler/entry_points/cluster.py deleted file mode 100644 index 
849d135e7..000000000 --- a/src/scaler/entry_points/cluster.py +++ /dev/null @@ -1,13 +0,0 @@ -import dataclasses - -from scaler.config.section.native_worker_manager import NativeWorkerManagerConfig, NativeWorkerManagerMode -from scaler.worker_manager_adapter.baremetal.native import NativeWorkerManager - - -def main() -> None: - config = NativeWorkerManagerConfig.parse("Scaler Cluster", "cluster") - config = dataclasses.replace(config, mode=NativeWorkerManagerMode.FIXED) - NativeWorkerManager(config).run() - - -__all__ = ["main"] diff --git a/src/scaler/entry_points/scheduler.py b/src/scaler/entry_points/scheduler.py index 129bbe8bd..4b3164af9 100644 --- a/src/scaler/entry_points/scheduler.py +++ b/src/scaler/entry_points/scheduler.py @@ -1,11 +1,14 @@ +from typing import Optional + from scaler.cluster.object_storage_server import ObjectStorageServerProcess from scaler.cluster.scheduler import SchedulerProcess from scaler.config.section.scheduler import SchedulerConfig from scaler.config.types.object_storage_server import ObjectStorageAddressConfig -def main(): - scheduler_config = SchedulerConfig.parse("Scaler Scheduler", "scheduler") +def main(scheduler_config: Optional[SchedulerConfig] = None) -> None: + if scheduler_config is None: + scheduler_config = SchedulerConfig.parse("Scaler Scheduler", "scheduler") object_storage_address = scheduler_config.object_storage_address object_storage = None diff --git a/src/scaler/entry_points/worker_manager.py b/src/scaler/entry_points/worker_manager.py new file mode 100644 index 000000000..614f0d453 --- /dev/null +++ b/src/scaler/entry_points/worker_manager.py @@ -0,0 +1,77 @@ +import dataclasses +from typing import Optional + +from scaler.config.config_class import ConfigClass +from scaler.config.section.aws_hpc_worker_manager import AWSBatchWorkerManagerConfig +from scaler.config.section.ecs_worker_manager import ECSWorkerManagerConfig +from scaler.config.section.native_worker_manager import NativeWorkerManagerConfig +from 
scaler.config.section.orb_worker_adapter import ORBWorkerAdapterConfig +from scaler.config.section.symphony_worker_manager import SymphonyWorkerManagerConfig +from scaler.utility.logging.utility import setup_logger + + +@dataclasses.dataclass +class _WorkerManagerConfig(ConfigClass): + native: Optional[NativeWorkerManagerConfig] = dataclasses.field( + default=None, metadata=dict(subcommand="native_worker_manager") + ) + symphony: Optional[SymphonyWorkerManagerConfig] = dataclasses.field( + default=None, metadata=dict(subcommand="symphony_worker_manager") + ) + ecs: Optional[ECSWorkerManagerConfig] = dataclasses.field( + default=None, metadata=dict(subcommand="ecs_worker_manager") + ) + hpc: Optional[AWSBatchWorkerManagerConfig] = dataclasses.field( + default=None, metadata=dict(subcommand="aws_hpc_worker_manager") + ) + orb: Optional[ORBWorkerAdapterConfig] = dataclasses.field( + default=None, metadata=dict(subcommand="orb_worker_adapter") + ) + + +def main() -> None: + config = _WorkerManagerConfig.parse("scaler_worker_manager", "") + + if config.native is not None: + from scaler.worker_manager_adapter.baremetal.native import NativeWorkerManager + + setup_logger( + config.native.logging_config.paths, + config.native.logging_config.config_file, + config.native.logging_config.level, + ) + NativeWorkerManager(config.native).run() + elif config.symphony is not None: + from scaler.worker_manager_adapter.symphony.worker_manager import SymphonyWorkerManager + + setup_logger( + config.symphony.logging_config.paths, + config.symphony.logging_config.config_file, + config.symphony.logging_config.level, + ) + SymphonyWorkerManager(config.symphony).run() + elif config.ecs is not None: + from scaler.worker_manager_adapter.aws_raw.ecs import ECSWorkerManager + + setup_logger( + config.ecs.logging_config.paths, config.ecs.logging_config.config_file, config.ecs.logging_config.level + ) + ECSWorkerManager(config.ecs).run() + elif config.hpc is not None: + from 
scaler.worker_manager_adapter.aws_hpc.worker_manager import AWSHPCWorkerManager + + setup_logger( + config.hpc.logging_config.paths, config.hpc.logging_config.config_file, config.hpc.logging_config.level + ) + AWSHPCWorkerManager(config.hpc).run() + elif config.orb is not None: + from scaler.worker_manager_adapter.orb.worker_manager import ORBWorkerAdapter + + setup_logger( + config.orb.logging_config.paths, config.orb.logging_config.config_file, config.orb.logging_config.level + ) + ORBWorkerAdapter(config.orb).run() + + +if __name__ == "__main__": + main() diff --git a/src/scaler/entry_points/worker_manager_aws_hpc_batch.py b/src/scaler/entry_points/worker_manager_aws_hpc_batch.py deleted file mode 100644 index db6008380..000000000 --- a/src/scaler/entry_points/worker_manager_aws_hpc_batch.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Entry point for AWS HPC Worker Manager. - -Supports multiple AWS HPC backends: -- batch: AWS Batch (EC2 compute environment) -- (future) parallelcluster: AWS ParallelCluster -- (future) lambda: AWS Lambda -""" - -import logging -import multiprocessing - -from scaler.config.section.aws_hpc_worker_manager import AWSBatchWorkerManagerConfig, AWSHPCBackend -from scaler.utility.logging.utility import setup_logger - - -def _create_batch_worker(config: AWSBatchWorkerManagerConfig) -> multiprocessing.Process: - from scaler.worker_manager_adapter.aws_hpc.worker import AWSBatchWorker - - logging.info(f" Job Queue: {config.job_queue}") - logging.info(f" Job Definition: {config.job_definition}") - logging.info(f" S3: s3://{config.s3_bucket}/{config.s3_prefix}") - logging.info(f" Max Concurrent Jobs: {config.max_concurrent_jobs}") - logging.info(f" Job Timeout: {config.job_timeout_minutes} minutes") - - return AWSBatchWorker( - name=config.name or "aws-batch-worker", - address=config.worker_manager_config.scheduler_address, - object_storage_address=config.worker_manager_config.object_storage_address, - job_queue=config.job_queue, - 
job_definition=config.job_definition, - aws_region=config.aws_region, - s3_bucket=config.s3_bucket, - s3_prefix=config.s3_prefix, - base_concurrency=config.max_concurrent_jobs, - heartbeat_interval_seconds=config.heartbeat_interval_seconds, - death_timeout_seconds=config.death_timeout_seconds, - task_queue_size=config.task_queue_size, - io_threads=config.worker_io_threads, - event_loop=config.event_loop, - job_timeout_seconds=config.job_timeout_minutes * 60, - worker_manager_id=config.worker_manager_id.encode(), - ) - - -def main(): - config = AWSBatchWorkerManagerConfig.parse("Scaler AWS HPC Worker Manager", "aws_hpc_worker_manager") - - setup_logger(config.logging_config.paths, config.logging_config.config_file, config.logging_config.level) - - logging.info(f"Starting AWS HPC Worker Manager (backend: {config.backend.name})") - logging.info(f" Scheduler: {config.worker_manager_config.scheduler_address}") - - if config.backend == AWSHPCBackend.batch: - worker = _create_batch_worker(config) - else: - raise NotImplementedError(f"backend {config.backend.name!r} is not yet implemented") - - worker.start() - worker.join() - - -if __name__ == "__main__": - main() diff --git a/src/scaler/entry_points/worker_manager_aws_raw_ecs.py b/src/scaler/entry_points/worker_manager_aws_raw_ecs.py deleted file mode 100644 index f296d0010..000000000 --- a/src/scaler/entry_points/worker_manager_aws_raw_ecs.py +++ /dev/null @@ -1,18 +0,0 @@ -from scaler.config.section.ecs_worker_manager import ECSWorkerManagerConfig -from scaler.utility.logging.utility import setup_logger -from scaler.worker_manager_adapter.aws_raw.ecs import ECSWorkerManager - - -def main(): - ecs_config = ECSWorkerManagerConfig.parse("Scaler ECS Worker Manager", "ecs_worker_manager") - - setup_logger( - ecs_config.logging_config.paths, ecs_config.logging_config.config_file, ecs_config.logging_config.level - ) - - ecs_worker_manager = ECSWorkerManager(ecs_config) - ecs_worker_manager.run() - - -if __name__ == 
"__main__": - main() diff --git a/src/scaler/entry_points/worker_manager_baremetal_native.py b/src/scaler/entry_points/worker_manager_baremetal_native.py deleted file mode 100644 index cbcbec6ef..000000000 --- a/src/scaler/entry_points/worker_manager_baremetal_native.py +++ /dev/null @@ -1,21 +0,0 @@ -from scaler.config.section.native_worker_manager import NativeWorkerManagerConfig -from scaler.utility.logging.utility import setup_logger -from scaler.worker_manager_adapter.baremetal.native import NativeWorkerManager - - -def main(): - native_manager_config = NativeWorkerManagerConfig.parse("Scaler Native Worker Manager", "native_worker_manager") - - setup_logger( - native_manager_config.logging_config.paths, - native_manager_config.logging_config.config_file, - native_manager_config.logging_config.level, - ) - - native_worker_manager = NativeWorkerManager(native_manager_config) - - native_worker_manager.run() - - -if __name__ == "__main__": - main() diff --git a/src/scaler/entry_points/worker_manager_symphony.py b/src/scaler/entry_points/worker_manager_symphony.py deleted file mode 100644 index 2d4913a39..000000000 --- a/src/scaler/entry_points/worker_manager_symphony.py +++ /dev/null @@ -1,19 +0,0 @@ -from scaler.config.section.symphony_worker_manager import SymphonyWorkerManagerConfig -from scaler.utility.logging.utility import setup_logger -from scaler.worker_manager_adapter.symphony.worker_manager import SymphonyWorkerManager - - -def main(): - symphony_config = SymphonyWorkerManagerConfig.parse("Scaler Symphony Worker Manager", "symphony_worker_manager") - setup_logger( - symphony_config.logging_config.paths, - symphony_config.logging_config.config_file, - symphony_config.logging_config.level, - ) - - symphony_worker_manager = SymphonyWorkerManager(symphony_config) - symphony_worker_manager.run() - - -if __name__ == "__main__": - main() diff --git a/src/scaler/io/utility.py b/src/scaler/io/utility.py index d913ae7dc..fbe60aabc 100644 --- 
a/src/scaler/io/utility.py +++ b/src/scaler/io/utility.py @@ -32,6 +32,7 @@ def create_async_simple_context(): return zmq.asyncio.Context() elif type == NetworkBackend.ymq: from scaler.io.ymq import IOContext + return IOContext() raise ValueError("Unknown network backend") diff --git a/src/scaler/io/ymq/__init__.py b/src/scaler/io/ymq/__init__.py index dc8bb1541..b658f0caf 100644 --- a/src/scaler/io/ymq/__init__.py +++ b/src/scaler/io/ymq/__init__.py @@ -7,7 +7,6 @@ "ErrorCode", "IOContext", "Message", - # Exception types "YMQException", "ConnectorSocketClosedByRemoteEndError", @@ -18,21 +17,19 @@ "SysCallError", ] -from scaler.io.ymq._ymq import ( +from scaler.io.ymq._ymq import ( # Exception types Address, AddressType, Bytes, - ErrorCode, - IOContext, - Message, - - # Exception types - YMQException, ConnectorSocketClosedByRemoteEndError, + ErrorCode, InvalidAddressFormatError, InvalidPortFormatError, + IOContext, + Message, RemoteEndDisconnectedOnSocketWithoutGuaranteedDeliveryError, SocketStopRequestedError, SysCallError, + YMQException, ) from scaler.io.ymq.sockets import BinderSocket, ConnectorSocket diff --git a/src/scaler/io/ymq/utils.py b/src/scaler/io/ymq/utils.py index d125cff1a..14387ddf5 100644 --- a/src/scaler/io/ymq/utils.py +++ b/src/scaler/io/ymq/utils.py @@ -5,7 +5,7 @@ try: from typing import Concatenate, ParamSpec # type: ignore[attr-defined] except ImportError: - from typing_extensions import ParamSpec, Concatenate # type: ignore[assignment] + from typing_extensions import Concatenate, ParamSpec # type: ignore[assignment] P = ParamSpec("P") diff --git a/src/scaler/scheduler/controllers/worker_manager_controller.py b/src/scaler/scheduler/controllers/worker_manager_controller.py index 758f2673b..053a43d51 100644 --- a/src/scaler/scheduler/controllers/worker_manager_controller.py +++ b/src/scaler/scheduler/controllers/worker_manager_controller.py @@ -74,6 +74,12 @@ async def on_heartbeat(self, source: bytes, heartbeat: WorkerManagerHeartbeat): # 
Build cross-manager snapshots from all known managers worker_manager_snapshots = self._build_manager_snapshots() + + # Wait for the previous command to complete before sending another. + # Worker managers can take a long time to fulfill commands (e.g. ORB polls for instance IDs), + # so sending a new command before the response arrives causes duplicate work and errors. + if source in self._pending_commands: + return + commands = self._policy_controller.get_scaling_commands( information_snapshot, heartbeat, managed_worker_ids, managed_worker_capabilities, worker_manager_snapshots ) diff --git a/src/scaler/utility/dict_utils.py b/src/scaler/utility/dict_utils.py new file mode 100644 index 000000000..97a0b8429 --- /dev/null +++ b/src/scaler/utility/dict_utils.py @@ -0,0 +1,38 @@ +import re +from typing import Any + + +def to_camel_case(snake_str: str) -> str: + components = snake_str.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +def to_snake_case(camel_str: str) -> str: + pattern = re.compile(r"(?<!^)(?=[A-Z])") + return pattern.sub("_", camel_str).lower() + + +def camelcase_dict(d: Any) ->
Any: + if isinstance(d, dict): + new_d = {} + for k, v in d.items(): + new_key = to_camel_case(k) if isinstance(k, str) else k + new_d[new_key] = camelcase_dict(v) + return new_d + elif isinstance(d, list): + return [camelcase_dict(i) for i in d] + else: + return d + + +def snakecase_dict(d: Any) -> Any: + if isinstance(d, dict): + new_d = {} + for k, v in d.items(): + new_key = to_snake_case(k) if isinstance(k, str) else k + new_d[new_key] = snakecase_dict(v) + return new_d + elif isinstance(d, list): + return [snakecase_dict(i) for i in d] + else: + return d diff --git a/src/scaler/version.txt b/src/scaler/version.txt index 53cc1a6f9..f9e8384bb 100644 --- a/src/scaler/version.txt +++ b/src/scaler/version.txt @@ -1 +1 @@ -1.24.0 +1.24.1 diff --git a/src/scaler/worker/worker.py b/src/scaler/worker/worker.py index e420c1b17..204145254 100644 --- a/src/scaler/worker/worker.py +++ b/src/scaler/worker/worker.py @@ -65,6 +65,7 @@ def __init__( logging_paths: Tuple[str, ...], logging_level: str, worker_manager_id: bytes, + deterministic_worker_ids: bool = False, ): multiprocessing.Process.__init__(self, name="Agent") @@ -77,7 +78,10 @@ def __init__( self._io_threads = io_threads self._task_queue_size = task_queue_size - self._ident = WorkerID.generate_worker_id(name) # _identity is internal to multiprocessing.Process + if deterministic_worker_ids: + self._ident = WorkerID(name.encode()) + else: + self._ident = WorkerID.generate_worker_id(name) self._address_path_internal = os.path.join(tempfile.gettempdir(), f"scaler_worker_{uuid.uuid4().hex}") self._address_internal = ZMQConfig(ZMQType.ipc, host=self._address_path_internal) diff --git a/src/scaler/worker_manager_adapter/aws_hpc/worker_manager.py b/src/scaler/worker_manager_adapter/aws_hpc/worker_manager.py new file mode 100644 index 000000000..7e185b029 --- /dev/null +++ b/src/scaler/worker_manager_adapter/aws_hpc/worker_manager.py @@ -0,0 +1,36 @@ +import logging + +from scaler.config.section.aws_hpc_worker_manager 
import AWSBatchWorkerManagerConfig, AWSHPCBackend +from scaler.worker_manager_adapter.aws_hpc.worker import AWSBatchWorker + + +class AWSHPCWorkerManager: + def __init__(self, config: AWSBatchWorkerManagerConfig) -> None: + self._config = config + + def run(self) -> None: + config = self._config + logging.info(f"Starting AWS HPC Worker Manager (backend: {config.backend.name})") + if config.backend != AWSHPCBackend.batch: + raise NotImplementedError(f"backend {config.backend.name!r} is not yet implemented") + + worker = AWSBatchWorker( + name=config.name or "aws-batch-worker", + address=config.worker_manager_config.scheduler_address, + object_storage_address=config.worker_manager_config.object_storage_address, + job_queue=config.job_queue, + job_definition=config.job_definition, + aws_region=config.aws_region, + s3_bucket=config.s3_bucket, + s3_prefix=config.s3_prefix, + base_concurrency=config.max_concurrent_jobs, + heartbeat_interval_seconds=config.heartbeat_interval_seconds, + death_timeout_seconds=config.death_timeout_seconds, + task_queue_size=config.task_queue_size, + io_threads=config.worker_io_threads, + event_loop=config.event_loop, + job_timeout_seconds=config.job_timeout_minutes * 60, + worker_manager_id=config.worker_manager_id.encode(), + ) + worker.start() + worker.join() diff --git a/src/scaler/worker_manager_adapter/aws_raw/ecs.py b/src/scaler/worker_manager_adapter/aws_raw/ecs.py index 876715037..9d83ae354 100644 --- a/src/scaler/worker_manager_adapter/aws_raw/ecs.py +++ b/src/scaler/worker_manager_adapter/aws_raw/ecs.py @@ -217,7 +217,8 @@ async def _start_ecs_task(self) -> Tuple[List[bytes], Status]: return [], Status.TooManyWorkers command = ( - f"scaler_cluster {self._address.to_address()} " + f"scaler_worker_manager native {self._address.to_address()} " + f"--mode fixed " f"--worker-type ECS " f"--max-task-concurrency {self._ecs_task_cpu} " f"--per-worker-task-queue-size {self._per_worker_task_queue_size} " diff --git 
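As an aside on the new `dict_utils.py` helpers introduced earlier in this diff: their intended behavior can be sketched and sanity-checked standalone. The functions are re-implemented below rather than imported, so the sketch runs without scaler installed; it should mirror, not replace, the versions in the patch.

```python
import re
from typing import Any


def to_camel_case(snake_str: str) -> str:
    # "max_task_concurrency" -> "maxTaskConcurrency"
    head, *rest = snake_str.split("_")
    return head + "".join(part.title() for part in rest)


def to_snake_case(camel_str: str) -> str:
    # Insert "_" before every capital that is not at the start, then lowercase
    return re.sub(r"(?<!^)(?=[A-Z])", "_", camel_str).lower()


def camelcase_dict(d: Any) -> Any:
    # Recursively camel-case every string key in nested dicts/lists
    if isinstance(d, dict):
        return {to_camel_case(k) if isinstance(k, str) else k: camelcase_dict(v) for k, v in d.items()}
    if isinstance(d, list):
        return [camelcase_dict(i) for i in d]
    return d


payload = {"max_task_concurrency": 4, "worker_config": {"heartbeat_interval_seconds": 2}}
print(camelcase_dict(payload))
# {'maxTaskConcurrency': 4, 'workerConfig': {'heartbeatIntervalSeconds': 2}}
```

Note that the round trip is only lossless for plain snake_case keys; keys containing digits or consecutive capitals may not survive `to_camel_case` → `to_snake_case` unchanged.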
a/src/scaler/worker_manager_adapter/orb/__init__.py b/src/scaler/worker_manager_adapter/orb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/scaler/worker_manager_adapter/orb/ami/build.sh b/src/scaler/worker_manager_adapter/orb/ami/build.sh new file mode 100755 index 000000000..c073caa06 --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb/ami/build.sh @@ -0,0 +1,22 @@ +#!/bin/bash +set -e +set -x + +# This script builds the AMI for opengris-scaler using Packer +# It reads the version from the version.txt file and passes it as a variable + +# Get the directory where the script is located +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +VERSION_FILE="$SCRIPT_DIR/../../version.txt" + +if [ ! -f "$VERSION_FILE" ]; then + echo "Error: Version file not found at $VERSION_FILE" + exit 1 +fi + +VERSION=$(cat "$VERSION_FILE" | tr -d '[:space:]') + +echo "Building AMI for version: $VERSION" + +cd "$SCRIPT_DIR" +packer build -var "version=$VERSION" opengris-scaler.pkr.hcl diff --git a/src/scaler/worker_manager_adapter/orb/ami/opengris-scaler.pkr.hcl b/src/scaler/worker_manager_adapter/orb/ami/opengris-scaler.pkr.hcl new file mode 100644 index 000000000..8611a939f --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb/ami/opengris-scaler.pkr.hcl @@ -0,0 +1,69 @@ +packer { + required_plugins { + amazon = { + version = "~> 1" + source = "github.com/hashicorp/amazon" + } + } +} + +variable "aws_region" { + type = string + default = "us-east-1" +} + +variable "version" { + type = string +} + +variable "ami_regions" { + type = list(string) + default = [] + description = "A list of regions to copy the AMI to." +} + +variable "ami_groups" { + type = list(string) + default = ["all"] + description = "A list of groups to share the AMI with. Set to ['all'] to make public." 
+} + +variable "python_version" { + type = string + default = "3.13" +} + +source "amazon-ebs" "opengris-scaler" { + ami_name = "opengris-scaler-${var.version}-py${var.python_version}" + instance_type = "t2.small" + region = var.aws_region + ami_regions = var.ami_regions + ami_groups = var.ami_groups + source_ami_filter { + filters = { + name = "al2023-ami-2023.*-kernel-*-x86_64" + root-device-type = "ebs" + virtualization-type = "hvm" + } + most_recent = true + owners = ["amazon"] + } + ssh_username = "ec2-user" +} + +build { + name = "opengris-scaler-build" + sources = ["source.amazon-ebs.opengris-scaler"] + + provisioner "shell" { + inline = [ + "sudo dnf update -y", + "sudo dnf install -y python${var.python_version} python${var.python_version}-pip", + "sudo python${var.python_version} -m venv /opt/opengris-scaler", + "sudo /opt/opengris-scaler/bin/python -m pip install --upgrade pip", + "sudo /opt/opengris-scaler/bin/pip install opengris-scaler==${var.version}", + "sudo ln -sf /opt/opengris-scaler/bin/scaler_* /usr/local/bin/", + "sudo ln -sf /opt/opengris-scaler/bin/python /usr/local/bin/opengris-python" + ] + } +} diff --git a/src/scaler/worker_manager_adapter/orb/exception.py b/src/scaler/worker_manager_adapter/orb/exception.py new file mode 100644 index 000000000..9ae10cbe9 --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb/exception.py @@ -0,0 +1,9 @@ +from typing import Any + + +class ORBException(Exception): + """Exception raised for errors in ORB operations.""" + + def __init__(self, data: Any): + self.data = data + super().__init__(f"ORB Exception: {data}") diff --git a/src/scaler/worker_manager_adapter/orb/worker_manager.py b/src/scaler/worker_manager_adapter/orb/worker_manager.py new file mode 100644 index 000000000..22311369e --- /dev/null +++ b/src/scaler/worker_manager_adapter/orb/worker_manager.py @@ -0,0 +1,407 @@ +import asyncio +import logging +import os +import signal +import uuid +from typing import Any, Dict, List, Optional, Tuple + 
+import boto3 +import zmq +from orb import ORBClient as orb + +from scaler.config.section.orb_worker_adapter import ORBWorkerAdapterConfig +from scaler.io import ymq +from scaler.io.mixins import AsyncConnector +from scaler.io.utility import create_async_connector, create_async_simple_context +from scaler.protocol.python.message import ( + Message, + WorkerManagerCommand, + WorkerManagerCommandResponse, + WorkerManagerCommandType, + WorkerManagerHeartbeat, + WorkerManagerHeartbeatEcho, +) +from scaler.utility.event_loop import create_async_loop_routine, register_event_loop, run_task_forever +from scaler.utility.identifiers import WorkerID +from scaler.utility.logging.utility import setup_logger +from scaler.worker_manager_adapter.common import format_capabilities + +Status = WorkerManagerCommandResponse.Status +logger = logging.getLogger(__name__) + + +# Polling configuration for ORB machine requests +ORB_POLLING_INTERVAL_SECONDS = 5 +ORB_MAX_POLLING_ATTEMPTS = 60 + + +def get_orb_worker_name(instance_id: str) -> str: + """ + Returns the deterministic worker name for an ORB instance. + If instance_id is the bash variable '${INSTANCE_ID}', it returns a bash-compatible string. 
+ """ + if instance_id == "${INSTANCE_ID}": + return "Worker|ORB|${INSTANCE_ID}|${INSTANCE_ID//i-/}" + tag = instance_id.replace("i-", "") + return f"Worker|ORB|{instance_id}|{tag}" + + +class ORBWorkerAdapter: + _config: ORBWorkerAdapterConfig + _sdk: Optional[Any] + _workers: Dict[WorkerID, str] + _template_id: str + _created_security_group_id: Optional[str] + _created_key_name: Optional[str] + _ec2: Optional[Any] + + def __init__(self, config: ORBWorkerAdapterConfig): + self._config = config + self._address = config.worker_manager_config.scheduler_address + self._heartbeat_interval_seconds = config.worker_config.heartbeat_interval_seconds + self._capabilities = config.worker_config.per_worker_capabilities.capabilities + self._max_task_concurrency = config.worker_manager_config.max_task_concurrency + + self._event_loop = config.event_loop + self._logging_paths = config.logging_config.paths + self._logging_level = config.logging_config.level + self._logging_config_file = config.logging_config.config_file + + self._sdk: Optional[Any] = None + self._ec2: Optional[Any] = None + self._context = None + self._connector_external: Optional[AsyncConnector] = None + self._created_security_group_id: Optional[str] = None + self._created_key_name: Optional[str] = None + self._cleaned_up = False + self._workers: Dict[WorkerID, str] = {} + self._ident: bytes = b"worker_manager_orb|uninitialized" + self._subnet_id: Optional[str] = None + + def _build_app_config(self) -> dict: + region = self._config.aws_region or "us-east-1" + return { + "provider": { + "selection_policy": "FIRST_AVAILABLE", + "providers": [ + {"name": "aws-default", "type": "aws", "enabled": True, "priority": 1, "config": {"region": region}} + ], + }, + "storage": {"type": "json"}, + } + + async def __setup(self) -> None: + """Set up AWS resources and the ORB template after the SDK is initialised.""" + region = self._config.aws_region or "us-east-1" + self._ec2 = boto3.client("ec2", region_name=region) + 
self._subnet_id = self._config.subnet_id or self._discover_default_subnet() + self._template_id = os.urandom(8).hex() + + security_group_ids = self._config.security_group_ids + if not security_group_ids: + self._create_security_group() + security_group_ids = [self._created_security_group_id] + + key_name = self._config.key_name + if not key_name: + self._create_key_pair() + key_name = self._created_key_name + + user_data = self._create_user_data() + + create_result = await self._sdk.create_template( + template_id=self._template_id, + name=f"opengris-orb-{self._template_id}", + image_id=self._config.image_id, + provider_api="RunInstances", + instance_type=self._config.instance_type, + max_instances=self._config.worker_manager_config.max_task_concurrency, + provider_name="aws-default", + machine_types={self._config.instance_type: 1}, + subnet_ids=[self._subnet_id], + security_group_ids=security_group_ids, + key_name=key_name, + user_data=user_data, + ) + logger.info(f"create_template result: {create_result}") + + validate_result = await self._sdk.validate_template(template_id=self._template_id) + logger.info(f"validate_template result: {validate_result}") + + self._context = create_async_simple_context() + self._name = "worker_manager_orb" + self._ident = f"{self._name}|{uuid.uuid4().bytes.hex()}".encode() + + self._connector_external = create_async_connector( + self._context, + name=self._name, + socket_type=zmq.DEALER, + address=self._address, + bind_or_connect="connect", + callback=self.__on_receive_external, + identity=self._ident, + ) + + async def __terminate_all_workers(self) -> None: + """Return all active instances to ORB before the SDK context exits.""" + if not self._workers or self._sdk is None: + return + instance_ids = list(self._workers.values()) + logger.info(f"Terminating {len(instance_ids)} worker group(s)...") + try: + await self._sdk.create_return_request(machine_ids=instance_ids) + logger.info(f"Successfully requested termination of instances: 
{instance_ids}") + except Exception as e: + logger.warning(f"Failed to terminate instances during cleanup: {e}") + self._workers.clear() + + async def __on_receive_external(self, message: Message): + if isinstance(message, WorkerManagerCommand): + await self._handle_command(message) + elif isinstance(message, WorkerManagerHeartbeatEcho): + pass + else: + logging.warning(f"Received unknown message type: {type(message)}") + + async def _handle_command(self, command: WorkerManagerCommand): + cmd_type = command.command + response_status = Status.Success + worker_ids: List[bytes] = [] + capabilities: Dict[str, int] = {} + + if cmd_type == WorkerManagerCommandType.StartWorkers: + worker_ids, response_status = await self.start_worker() + if response_status == Status.Success: + capabilities = self._capabilities + elif cmd_type == WorkerManagerCommandType.ShutdownWorkers: + worker_ids, response_status = await self.shutdown_workers(list(command.worker_ids)) + else: + raise ValueError("Unknown Command") + + assert self._connector_external is not None + await self._connector_external.send( + WorkerManagerCommandResponse.new_msg( + command=cmd_type, status=response_status, worker_ids=worker_ids, capabilities=capabilities + ) + ) + + async def __send_heartbeat(self): + assert self._connector_external is not None + await self._connector_external.send( + WorkerManagerHeartbeat.new_msg( + max_task_concurrency=self._max_task_concurrency, + capabilities=self._capabilities, + worker_manager_id=self._ident, + ) + ) + + def run(self) -> None: + self._loop = asyncio.new_event_loop() + run_task_forever(self._loop, self._run(), cleanup_callback=self._cleanup) + + def __destroy(self): + print(f"Worker adapter {self._ident!r} received signal, shutting down") + self._task.cancel() + + def __register_signal(self): + self._loop.add_signal_handler(signal.SIGINT, self.__destroy) + self._loop.add_signal_handler(signal.SIGTERM, self.__destroy) + + async def _run(self) -> None: + 
register_event_loop(self._event_loop) + setup_logger(self._logging_paths, self._logging_config_file, self._logging_level) + + async with orb(app_config=self._build_app_config()) as sdk: + self._sdk = sdk + await self.__setup() + self._task = self._loop.create_task(self.__get_loops()) + self.__register_signal() + try: + await self._task + except asyncio.CancelledError: + pass + finally: + await self.__terminate_all_workers() + + self._sdk = None + + async def __get_loops(self): + assert self._connector_external is not None + loops = [ + create_async_loop_routine(self._connector_external.routine, 0), + create_async_loop_routine(self.__send_heartbeat, self._heartbeat_interval_seconds), + ] + + try: + await asyncio.gather(*loops) + except asyncio.CancelledError: + pass + except ymq.YMQException as e: + if e.code == ymq.ErrorCode.ConnectorSocketClosedByRemoteEnd: + pass + else: + logging.exception(f"{self._ident!r}: failed with unhandled exception:\n{e}") + + def _create_user_data(self) -> str: + worker_config = self._config.worker_config + adapter_config = self._config.worker_manager_config + + # NOTE: --num-of-workers is not passed; NativeWorkerManager defaults to cpu_count - 1 + # workers in fixed mode, where cpu_count is determined by the machine type configured by the user. 
+ script = f"""#!/bin/bash +nohup /usr/local/bin/scaler_worker_manager native {adapter_config.scheduler_address.to_address()} \ + --mode fixed \ + --worker-type ORB \ + --per-worker-task-queue-size {worker_config.per_worker_task_queue_size} \ + --heartbeat-interval-seconds {worker_config.heartbeat_interval_seconds} \ + --task-timeout-seconds {worker_config.task_timeout_seconds} \ + --garbage-collect-interval-seconds {worker_config.garbage_collect_interval_seconds} \ + --death-timeout-seconds {worker_config.death_timeout_seconds} \ + --trim-memory-threshold-bytes {worker_config.trim_memory_threshold_bytes} \ + --event-loop {self._config.event_loop} \ + --worker-io-threads {self._config.worker_io_threads}""" + + if worker_config.hard_processor_suspend: + script += " \ + --hard-processor-suspend" + + if adapter_config.object_storage_address: + script += f" \ + --object-storage-address {adapter_config.object_storage_address.to_string()}" + + capabilities = worker_config.per_worker_capabilities.capabilities + if capabilities: + cap_str = format_capabilities(capabilities) + if cap_str.strip(): + script += f" \ + --per-worker-capabilities {cap_str}" + + script += " > /var/log/opengris-scaler.log 2>&1 &\n" + + return script + + def _discover_default_subnet(self) -> str: + vpcs = self._ec2.describe_vpcs(Filters=[{"Name": "isDefault", "Values": ["true"]}]) + if not vpcs["Vpcs"]: + raise RuntimeError("No default VPC found, and no subnet_id provided.") + default_vpc_id = vpcs["Vpcs"][0]["VpcId"] + + subnets = self._ec2.describe_subnets(Filters=[{"Name": "vpc-id", "Values": [default_vpc_id]}]) + if not subnets["Subnets"]: + raise RuntimeError(f"No subnets found in default VPC {default_vpc_id}.") + + subnet_id = subnets["Subnets"][0]["SubnetId"] + logger.info(f"Auto-discovered subnet_id: {subnet_id}") + return subnet_id + + def _create_security_group(self): + # Get VPC ID from Subnet + subnet_response = self._ec2.describe_subnets(SubnetIds=[self._subnet_id]) + vpc_id = 
subnet_response["Subnets"][0]["VpcId"] + + # Create Security Group (outbound-only — workers connect out to scheduler via ZMQ) + group_name = f"opengris-orb-sg-{self._template_id}" + sg_response = self._ec2.create_security_group( + Description="Temporary security group created for OpenGRIS ORB worker adapter", + GroupName=group_name, + VpcId=vpc_id, + ) + self._created_security_group_id = sg_response["GroupId"] + logger.info(f"Created security group with ID: {self._created_security_group_id}") + + def _create_key_pair(self): + key_name = f"opengris-orb-key-{self._template_id}" + self._ec2.create_key_pair(KeyName=key_name) + self._created_key_name = key_name + logger.info(f"Created key pair: {key_name}") + + def _cleanup(self): + if self._cleaned_up: + return + self._cleaned_up = True + + if self._connector_external is not None: + self._connector_external.destroy() + + logger.info("Starting cleanup of AWS resources...") + + if self._created_security_group_id is not None: + try: + logger.info(f"Deleting AWS security group: {self._created_security_group_id}") + self._ec2.delete_security_group(GroupId=self._created_security_group_id) + except Exception as e: + logger.warning(f"Failed to delete security group {self._created_security_group_id}: {e}") + + if self._created_key_name is not None: + try: + logger.info(f"Deleting AWS key pair: {self._created_key_name}") + self._ec2.delete_key_pair(KeyName=self._created_key_name) + except Exception as e: + logger.warning(f"Failed to delete key pair {self._created_key_name}: {e}") + + logger.info("Cleanup completed.") + + def __del__(self): + self._cleanup() + + async def start_worker(self) -> Tuple[List[bytes], Status]: + if len(self._workers) >= self._max_task_concurrency != -1: + return [], Status.TooManyWorkers + + response = await self._sdk.create_request(template_id=self._template_id, count=1) + request_id = ( + response.get("created_request_id") or response.get("request_id") or response.get("id") + if isinstance(response, 
dict) + else None + ) + + if not request_id: + logger.error(f"ORB machine request failed to return a request ID. Response: {response}") + return [], Status.UnknownAction + + logger.info(f"ORB machine request {request_id} submitted, waiting for instance IDs...") + + timeout = float(ORB_MAX_POLLING_ATTEMPTS * ORB_POLLING_INTERVAL_SECONDS) + try: + final = await self._sdk.wait_for_request( + request_id, timeout=timeout, poll_interval=float(ORB_POLLING_INTERVAL_SECONDS) + ) + except TimeoutError: + logger.error(f"ORB machine request {request_id} timed out after {timeout:.0f}s.") + return [], Status.UnknownAction + + machines = final.get("machines", []) if isinstance(final, dict) else [] + instance_id = next( + (m.get("machine_id") or m.get("id") for m in machines if m.get("machine_id") or m.get("id")), None + ) + + if not instance_id: + status = final.get("status", "") if isinstance(final, dict) else "" + logger.error(f"ORB request {request_id} completed with status '{status}' but no instance ID found.") + return [], Status.UnknownAction + + logger.info(f"ORB request {request_id} fulfilled with instance ID: {instance_id}") + worker_id = WorkerID(get_orb_worker_name(instance_id).encode()) + self._workers[worker_id] = instance_id + return [bytes(worker_id)], Status.Success + + async def shutdown_workers(self, worker_ids: List[bytes]) -> Tuple[List[bytes], Status]: + if not worker_ids: + return [], Status.WorkerNotFound + + instance_ids = [] + affected_worker_ids = [] + for wid_bytes in worker_ids: + worker_id = WorkerID(wid_bytes) + if worker_id not in self._workers: + logger.warning(f"Worker with ID {wid_bytes!r} does not exist.") + return [], Status.WorkerNotFound + instance_ids.append(self._workers[worker_id]) + affected_worker_ids.append(wid_bytes) + + await self._sdk.create_return_request(machine_ids=instance_ids) + + for wid_bytes in affected_worker_ids: + del self._workers[WorkerID(wid_bytes)] + + return affected_worker_ids, Status.Success diff --git 
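The `start_worker` method above extracts an instance ID defensively, because ORB responses may be non-dicts and may expose the ID under either a `machine_id` or an `id` key. That extraction pattern, pulled out in isolation as a standalone sketch (the `first_machine_id` helper name is illustrative, not part of the patch):

```python
from typing import Any, Dict, List, Optional


def first_machine_id(final: Any) -> Optional[str]:
    # Tolerate non-dict responses and either "machine_id" or "id" keys per machine
    machines: List[Dict[str, Any]] = final.get("machines", []) if isinstance(final, dict) else []
    return next(
        (m.get("machine_id") or m.get("id") for m in machines if m.get("machine_id") or m.get("id")),
        None,
    )


print(first_machine_id({"machines": [{"state": "pending"}, {"id": "i-0abc123"}]}))  # i-0abc123
print(first_machine_id("unexpected"))  # None
```

Returning `None` rather than raising lets the caller map the failure to a `Status.UnknownAction` response, as the adapter does.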
a/tests/config/test_config_class.py b/tests/config/test_config_class.py index 59ffb97ea..bbbb00950 100644 --- a/tests/config/test_config_class.py +++ b/tests/config/test_config_class.py @@ -78,7 +78,7 @@ class MyConfig(ConfigClass): renamed: int = dataclasses.field(default=0, metadata=dict(long="--new-name")) parser = MockArgParser() - MyConfig.configure_parser(parser) + MyConfig.configure_parser(parser) # type: ignore[arg-type] args = parser.args # Q: What is the "dest" kwarg? @@ -136,13 +136,15 @@ class MyConfigClass(ConfigClass): @patch.dict("os.environ", {"ENV_VAR_ONE": "99", "ENV_VAR_TWO": "98"}) @patch( "builtins.open", - mock_open(read_data=""" + mock_open( + read_data=""" [my_config] config-file = 99 [unused_section] another-one = 97 - """), + """ + ), ) def test_precedence(self) -> None: @dataclasses.dataclass @@ -192,11 +194,13 @@ class MyConfig(ConfigClass): @patch("sys.argv", ["script", "--config", "file"]) @patch( "builtins.open", - mock_open(read_data=""" + mock_open( + read_data=""" [my_config] my_int = 10 my-other-int = 20 - """), + """ + ), ) def test_underscore_toml_parsing(self) -> None: @dataclasses.dataclass diff --git a/tests/entry_points/__init__.py b/tests/entry_points/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/entry_points/test_aio.py b/tests/entry_points/test_aio.py new file mode 100644 index 000000000..e09bbfa2d --- /dev/null +++ b/tests/entry_points/test_aio.py @@ -0,0 +1,142 @@ +import dataclasses +import unittest +from typing import List, Optional +from unittest.mock import MagicMock, patch + +from scaler.config.config_class import ConfigClass, _reconstruct_config + +# --------------------------------------------------------------------------- +# Minimal stub configs for section= tests +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class _SimpleSchedulerConfig(ConfigClass): + host: str = "localhost" + port: int = 8516 + + 
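The behavior these test stubs exercise — an absent TOML section yields `None` (or an empty list), a single table yields one element, and an array of tables yields multiple elements — can be illustrated with a minimal standalone analogue. This is an illustration of the expected semantics only, not scaler's actual `_reconstruct_config` implementation:

```python
import dataclasses
from typing import Any, Dict, List


@dataclasses.dataclass
class WorkerSection:
    workers: int = 1


def load_list_section(toml_data: Dict[str, Any], name: str) -> List[WorkerSection]:
    raw = toml_data.get(name)
    if raw is None:
        return []        # absent section -> empty list
    if isinstance(raw, dict):
        raw = [raw]      # a single [workers] table counts as one element
    return [WorkerSection(**table) for table in raw]


print(load_list_section({}, "workers"))                                   # []
print(load_list_section({"workers": {"workers": 4}}, "workers"))          # [WorkerSection(workers=4)]
print(load_list_section({"workers": [{"workers": 2}, {"workers": 8}]}, "workers"))
# [WorkerSection(workers=2), WorkerSection(workers=8)]
```

In TOML terms, `[workers]` produces the single-dict case, while `[[workers]] ... [[workers]]` produces the array-of-tables case.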
+@dataclasses.dataclass +class _SimpleWorkerConfig(ConfigClass): + workers: int = 1 + + +@dataclasses.dataclass +class _SectionTestConfig(ConfigClass): + scheduler: Optional[_SimpleSchedulerConfig] = dataclasses.field(default=None, metadata=dict(section="scheduler")) + workers: List[_SimpleWorkerConfig] = dataclasses.field(default_factory=list, metadata=dict(section="workers")) + + +class TestSectionMetadata(unittest.TestCase): + """Tests the section= metadata path in ConfigClass / _reconstruct_config.""" + + def _build(self, toml_data): + return _reconstruct_config(_SectionTestConfig, {}, toml_data) + + def test_single_table_populates_optional(self) -> None: + toml = {"scheduler": {"host": "192.168.1.1", "port": 9999}} + config = self._build(toml) + self.assertIsNotNone(config.scheduler) + self.assertEqual(config.scheduler.host, "192.168.1.1") + self.assertEqual(config.scheduler.port, 9999) + + def test_absent_section_gives_none(self) -> None: + config = self._build({}) + self.assertIsNone(config.scheduler) + + def test_absent_list_section_gives_empty_list(self) -> None: + config = self._build({}) + self.assertEqual(config.workers, []) + + def test_single_dict_in_list_section_gives_one_element(self) -> None: + toml = {"workers": {"workers": 4}} + config = self._build(toml) + self.assertEqual(len(config.workers), 1) + self.assertEqual(config.workers[0].workers, 4) + + def test_array_of_tables_gives_multiple_elements(self) -> None: + toml = {"workers": [{"workers": 2}, {"workers": 8}]} + config = self._build(toml) + self.assertEqual(len(config.workers), 2) + self.assertEqual(config.workers[0].workers, 2) + self.assertEqual(config.workers[1].workers, 8) + + def test_both_sections_populated(self) -> None: + toml = {"scheduler": {"host": "10.0.0.1", "port": 1234}, "workers": [{"workers": 3}]} + config = self._build(toml) + self.assertIsNotNone(config.scheduler) + self.assertEqual(config.scheduler.host, "10.0.0.1") + self.assertEqual(len(config.workers), 1) + 
self.assertEqual(config.workers[0].workers, 3) + + +# --------------------------------------------------------------------------- +# Minimal _AIOConfig-like stub for end-to-end tests +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class _MiniAIOConfig(ConfigClass): + scheduler: Optional[_SimpleSchedulerConfig] = dataclasses.field(default=None, metadata=dict(section="scheduler")) + workers: List[_SimpleWorkerConfig] = dataclasses.field(default_factory=list, metadata=dict(section="workers")) + + +class TestAIOEndToEnd(unittest.TestCase): + """End-to-end tests for the section= flow through ConfigClass.parse.""" + + def _make_toml_data(self, data): + """Return a mock _load_toml that yields the given data dict.""" + return patch("scaler.config.config_class._load_toml", return_value=data) + + @patch("sys.argv", ["scaler_aio", "--config", "test.toml"]) + def test_no_recognized_sections_exits(self) -> None: + with self._make_toml_data({}): + config = _MiniAIOConfig.parse("scaler_aio", "aio") + self.assertIsNone(config.scheduler) + self.assertEqual(config.workers, []) + + @patch("sys.argv", ["scaler_aio", "--config", "test.toml"]) + def test_scheduler_section_populated(self) -> None: + toml = {"scheduler": {"host": "127.0.0.1", "port": 8516}} + with self._make_toml_data(toml): + config = _MiniAIOConfig.parse("scaler_aio", "aio") + self.assertIsNotNone(config.scheduler) + self.assertEqual(config.scheduler.host, "127.0.0.1") + + @patch("sys.argv", ["scaler_aio", "--config", "test.toml"]) + def test_workers_list_populated(self) -> None: + toml = {"workers": [{"workers": 4}, {"workers": 8}]} + with self._make_toml_data(toml): + config = _MiniAIOConfig.parse("scaler_aio", "aio") + self.assertEqual(len(config.workers), 2) + + @patch("sys.argv", ["scaler_aio", "--help"]) + def test_help_exits(self) -> None: + with self.assertRaises(SystemExit): + _MiniAIOConfig.parse("scaler_aio", "aio") + + +class 
TestAIOMain(unittest.TestCase): + """Tests for scaler_aio main() process spawning logic.""" + + def _run_main_with_toml(self, toml_data): + from scaler.entry_points.aio import main + + with patch("scaler.config.config_class._load_toml", return_value=toml_data), patch( + "multiprocessing.Process" + ) as mock_process_cls: + mock_proc = MagicMock() + mock_process_cls.return_value = mock_proc + with patch("sys.argv", ["scaler_aio", "--config", "test.toml"]): + main() + return mock_process_cls, mock_proc + + def test_no_sections_exits_with_code_1(self) -> None: + from scaler.entry_points.aio import main + + with patch("scaler.config.config_class._load_toml", return_value={}), patch( + "sys.argv", ["scaler_aio", "--config", "test.toml"] + ): + with self.assertRaises(SystemExit) as ctx: + main() + self.assertEqual(ctx.exception.code, 1) diff --git a/tests/entry_points/test_worker_manager.py b/tests/entry_points/test_worker_manager.py new file mode 100644 index 000000000..1635811be --- /dev/null +++ b/tests/entry_points/test_worker_manager.py @@ -0,0 +1,153 @@ +import dataclasses +import unittest +from typing import Optional +from unittest.mock import mock_open, patch + +from scaler.config.config_class import ConfigClass + + +@dataclasses.dataclass +class _LeafConfig(ConfigClass): + value: int = 0 + name: str = "default" + + +@dataclasses.dataclass +class _RootConfig(ConfigClass): + foo: Optional[_LeafConfig] = dataclasses.field(default=None, metadata=dict(subcommand="foo_section")) + bar: Optional[_LeafConfig] = dataclasses.field(default=None, metadata=dict(subcommand="bar_section")) + + +@dataclasses.dataclass +class _RootWithCommonConfig(ConfigClass): + log_level: str = "INFO" + foo: Optional[_LeafConfig] = dataclasses.field(default=None, metadata=dict(subcommand="foo_section")) + bar: Optional[_LeafConfig] = dataclasses.field(default=None, metadata=dict(subcommand="bar_section")) + + +# Two-level nested subcommands for nesting tests +@dataclasses.dataclass +class 
_Level2Config(ConfigClass): + depth: int = 2 + + +@dataclasses.dataclass +class _Level1Config(ConfigClass): + inner: Optional[_Level2Config] = dataclasses.field(default=None, metadata=dict(subcommand="level2_section")) + + +@dataclasses.dataclass +class _NestedRootConfig(ConfigClass): + level1: Optional[_Level1Config] = dataclasses.field(default=None, metadata=dict(subcommand="level1_section")) + + +class TestWorkerManagerSubcommands(unittest.TestCase): + """Tests the subcommand= metadata path in ConfigClass.""" + + @patch("sys.argv", ["prog", "foo", "--value", "42"]) + def test_foo_subcommand_selected(self) -> None: + config = _RootConfig.parse("prog", "") + self.assertIsNotNone(config.foo) + self.assertIsNone(config.bar) + self.assertEqual(config.foo.value, 42) + + @patch("sys.argv", ["prog", "bar", "--value", "7"]) + def test_bar_subcommand_selected(self) -> None: + config = _RootConfig.parse("prog", "") + self.assertIsNone(config.foo) + self.assertIsNotNone(config.bar) + self.assertEqual(config.bar.value, 7) + + @patch("sys.argv", ["prog", "foo"]) + def test_default_values_used(self) -> None: + config = _RootConfig.parse("prog", "") + self.assertIsNotNone(config.foo) + self.assertEqual(config.foo.value, 0) + self.assertEqual(config.foo.name, "default") + + @patch("sys.argv", ["prog", "foo", "--log-level", "DEBUG"]) + def test_root_level_fields_populated(self) -> None: + config = _RootWithCommonConfig.parse("prog", "") + self.assertEqual(config.log_level, "DEBUG") + self.assertIsNotNone(config.foo) + + @patch("sys.argv", ["prog", "foo", "--value", "5"]) + @patch( + "builtins.open", + mock_open( + read_data=""" + [foo_section] + value = 99 + name = "from_toml" + """ + ), + ) + def test_cli_overrides_toml(self) -> None: + with patch("sys.argv", ["prog", "--config", "cfg.toml", "foo", "--value", "5"]): + config = _RootConfig.parse("prog", "") + self.assertIsNotNone(config.foo) + # CLI --value 5 should override TOML value 99 + self.assertEqual(config.foo.value, 5) + 
# name not provided on CLI → TOML value used + self.assertEqual(config.foo.name, "from_toml") + + @patch( + "builtins.open", + mock_open( + read_data=""" + [foo_section] + value = 77 + """ + ), + ) + def test_config_after_subcommand(self) -> None: + """--config appearing after the sub-command name must still be loaded.""" + with patch("sys.argv", ["prog", "foo", "--config", "cfg.toml"]): + config = _RootConfig.parse("prog", "") + self.assertIsNotNone(config.foo) + self.assertEqual(config.foo.value, 77) + + @patch( + "builtins.open", + mock_open( + read_data=""" + [foo_section] + value = 55 + """ + ), + ) + def test_config_before_subcommand(self) -> None: + """--config appearing before the sub-command name must still be loaded.""" + with patch("sys.argv", ["prog", "--config", "cfg.toml", "foo"]): + config = _RootConfig.parse("prog", "") + self.assertIsNotNone(config.foo) + self.assertEqual(config.foo.value, 55) + + @patch("sys.argv", ["prog"]) + def test_no_subcommand_exits(self) -> None: + with self.assertRaises(SystemExit): + _RootConfig.parse("prog", "") + + @patch("sys.argv", ["prog", "bad_cmd"]) + def test_unknown_subcommand_exits(self) -> None: + with self.assertRaises(SystemExit): + _RootConfig.parse("prog", "") + + @patch("sys.argv", ["prog", "--help"]) + def test_help_exits(self) -> None: + with self.assertRaises(SystemExit): + _RootConfig.parse("prog", "") + + @patch("sys.argv", ["prog", "level1", "inner", "--depth", "99"]) + def test_nested_subcommands_route_correctly(self) -> None: + config = _NestedRootConfig.parse("prog", "") + self.assertIsNotNone(config.level1) + self.assertIsNotNone(config.level1.inner) + self.assertEqual(config.level1.inner.depth, 99) + + @patch("sys.argv", ["prog", "level1", "inner"]) + def test_nested_subcommands_unselected_are_none(self) -> None: + config = _NestedRootConfig.parse("prog", "") + self.assertIsNotNone(config.level1) + self.assertIsNotNone(config.level1.inner) + self.assertEqual(config.level1.inner.depth, 2) # 
default diff --git a/tests/io/uv_ymq/test_sockets.py b/tests/io/uv_ymq/test_sockets.py index a649827ae..f30b63c5c 100644 --- a/tests/io/uv_ymq/test_sockets.py +++ b/tests/io/uv_ymq/test_sockets.py @@ -1,7 +1,7 @@ import asyncio import unittest -from scaler.io.ymq import BinderSocket, Bytes, ConnectorSocket, ErrorCode, IOContext, InvalidAddressFormatError +from scaler.io.ymq import BinderSocket, Bytes, ConnectorSocket, ErrorCode, InvalidAddressFormatError, IOContext class TestSockets(unittest.IsolatedAsyncioTestCase): diff --git a/tests/io/uv_ymq/test_types.py b/tests/io/uv_ymq/test_types.py index 37e8e7012..65c55682a 100644 --- a/tests/io/uv_ymq/test_types.py +++ b/tests/io/uv_ymq/test_types.py @@ -26,7 +26,7 @@ def test_error_code(self): self.assertTrue(issubclass(ErrorCode, IntEnum)) # type: ignore self.assertEqual( ErrorCode.InvalidAddressFormat.explanation(), - "Invalid address format, example input \"tcp://127.0.0.1:2345\" or \"ipc:///tmp/domain_socket_name.sock\"", + 'Invalid address format, example input "tcp://127.0.0.1:2345" or "ipc:///tmp/domain_socket_name.sock"', ) def test_bytes(self):