Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 33 additions & 20 deletions python/ray/scripts/symmetric_run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
"""Symmetric Run for Ray."""
"""Symmetric Run for Ray.

This script is intended for environments where the same command is executed on every node
(e.g. SLURM). It decides whether a given node should be the head or a worker, starts Ray, runs
the entrypoint on the head, and then shuts the cluster down.
"""

import socket
import subprocess
Expand All @@ -11,13 +16,18 @@
import ray
from ray._private.ray_constants import env_integer
from ray._raylet import GcsClient
from ray.exceptions import RpcError

import psutil

CLUSTER_WAIT_TIMEOUT = env_integer("RAY_SYMMETRIC_RUN_CLUSTER_WAIT_TIMEOUT", 30)


class SymmetricRunCommand(click.Command):
def parse_args(self, ctx, args):
ctx.meta["raw_args"] = list(args)
return super().parse_args(ctx, args)


def check_ray_already_started() -> bool:
import ray._private.services as services

Expand All @@ -26,14 +36,14 @@ def check_ray_already_started() -> bool:
return len(running_gcs_addresses) > 0


def check_cluster_ready(nnodes, timeout=CLUSTER_WAIT_TIMEOUT):
def check_cluster_ready(address, nnodes, timeout=CLUSTER_WAIT_TIMEOUT):
"""Wait for all nodes to start.

Raises an exception if the nodes don't start in time.
"""
start_time = time.time()
current_nodes = 1
ray.init(ignore_reinit_error=True)
ray.init(address=address, ignore_reinit_error=True)

while time.time() - start_time < timeout:
time.sleep(5)
Expand All @@ -48,14 +58,19 @@ def check_cluster_ready(nnodes, timeout=CLUSTER_WAIT_TIMEOUT):


def check_head_node_ready(address: str, timeout=CLUSTER_WAIT_TIMEOUT):
from ray.exceptions import RpcError

start_time = time.time()
gcs_client = GcsClient(address=address)
while time.time() - start_time < timeout:
try:
# check_alive returns a list of dead node IDs from the input list.
# If it runs without raising an exception, the GCS is ready.
gcs_client.check_alive([], timeout=1)
click.echo("Ray cluster is ready!")
return True
except RpcError:
# GCS not ready yet, keep retrying
pass
time.sleep(5)
return False
Expand All @@ -67,22 +82,24 @@ def curate_and_validate_ray_start_args(run_and_start_args: List[str]) -> List[st
cleaned_args = list(ctx.params["ray_args_and_entrypoint"])

for arg in cleaned_args:
if arg == "--head":
normalized = arg.split("=", 1)[0]
if normalized == "--head":
raise click.ClickException("Cannot use --head option in symmetric_run.")
if arg == "--node-ip-address":
if normalized == "--node-ip-address":
raise click.ClickException(
"Cannot use --node-ip-address option in symmetric_run."
)
if arg == "--port":
if normalized == "--port":
raise click.ClickException("Cannot use --port option in symmetric_run.")
if arg == "--block":
if normalized == "--block":
raise click.ClickException("Cannot use --block option in symmetric_run.")

return cleaned_args


@click.command(
name="symmetric_run",
cls=SymmetricRunCommand,
context_settings={"ignore_unknown_options": True, "allow_extra_args": True},
help="""Command to start Ray across all nodes and execute an entrypoint command.

Expand Down Expand Up @@ -135,24 +152,20 @@ def curate_and_validate_ray_start_args(run_and_start_args: List[str]) -> List[st
help="If provided, wait for this number of nodes to start.",
)
@click.argument("ray_args_and_entrypoint", nargs=-1, type=click.UNPROCESSED)
def symmetric_run(address, min_nodes, ray_args_and_entrypoint):
all_args = sys.argv[1:]

if all_args and all_args[0] == "symmetric-run":
all_args = all_args[1:]

@click.pass_context
def symmetric_run(ctx, address, min_nodes, ray_args_and_entrypoint):
raw_args = ctx.meta.get("raw_args", [])
try:
separator = all_args.index("--")
separator = raw_args.index("--")
except ValueError:
raise click.ClickException(
"No separator '--' found in arguments. Please use '--' to "
"separate Ray start arguments and the entrypoint command."
f" Got arguments: {raw_args}"
)

run_and_start_args, entrypoint_on_head = (
all_args[:separator],
all_args[separator + 1 :],
)
run_and_start_args = raw_args[:separator]
entrypoint_on_head = raw_args[separator + 1 :]

ray_start_args = curate_and_validate_ray_start_args(run_and_start_args)

Expand Down Expand Up @@ -219,7 +232,7 @@ def symmetric_run(address, min_nodes, ray_args_and_entrypoint):
subprocess.run(ray_start_cmd, check=True, capture_output=True)
click.echo("Head node started.")
click.echo("=======================")
if min_nodes > 1 and not check_cluster_ready(min_nodes):
if min_nodes > 1 and not check_cluster_ready(address, min_nodes):
raise click.ClickException(
"Timed out waiting for other nodes to start."
)
Expand Down
24 changes: 24 additions & 0 deletions python/ray/tests/symmetric_run_test_entrypoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Test entrypoint for symmetric_run multi-node integration test.

This script is executed by symmetric_run on the head node during
test_symmetric_run_three_node_cluster_simulated. It connects to the
Ray cluster and verifies that all expected nodes have joined.
"""
import sys
import time

import ray

EXPECTED_NODES = 3
TIMEOUT_SECONDS = 60

ray.init(address="auto", ignore_reinit_error=True, log_to_driver=False)

for _ in range(TIMEOUT_SECONDS):
if len(ray.nodes()) >= EXPECTED_NODES:
print("ENTRYPOINT_SUCCESS")
sys.exit(0)
time.sleep(1)

print(f"ENTRYPOINT_FAILED: cluster_size={len(ray.nodes())}")
sys.exit(1)
43 changes: 43 additions & 0 deletions python/ray/tests/symmetric_run_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Wrapper to run symmetric_run with mocked network interfaces.

This wrapper is needed because symmetric_run spawns subprocesses (ray start,
entrypoint), so mocks applied in pytest don't propagate. This wrapper applies
patches before invoking symmetric_run, allowing us to control what each
'node' sees as its local IP.

If MOCK_HIDE_IP is set, filter out this IP from network interfaces so the
process thinks it's not on the head node.
"""
import os
from unittest.mock import patch

import ray # noqa: F401 - must import ray first, psutil is vendored
from ray.scripts.symmetric_run import symmetric_run

import psutil

_real_net_if_addrs = psutil.net_if_addrs


def _mocked_net_if_addrs():
"""Return network interfaces, optionally hiding the head IP."""
real_addrs = _real_net_if_addrs()
hide_ip = os.environ.get("MOCK_HIDE_IP")
if not hide_ip:
return real_addrs

new_addrs = {}
for iface, addrs in real_addrs.items():
# Filter out the IP to be hidden from the list of addresses for this interface.
filtered_addrs = [addr for addr in addrs if addr.address != hide_ip]
if filtered_addrs:
new_addrs[iface] = filtered_addrs
return new_addrs


if __name__ == "__main__":
with (
patch("ray._private.services.find_gcs_addresses", return_value=[]),
patch("psutil.net_if_addrs", side_effect=_mocked_net_if_addrs),
):
symmetric_run()
Loading