diff --git a/python/ray/scripts/symmetric_run.py b/python/ray/scripts/symmetric_run.py index f1bf2a5987e9..468e7bf740e7 100644 --- a/python/ray/scripts/symmetric_run.py +++ b/python/ray/scripts/symmetric_run.py @@ -1,4 +1,9 @@ -"""Symmetric Run for Ray.""" +"""Symmetric Run for Ray. + +This script is intended for environments where the same command is executed on every node +(e.g. SLURM). It decides whether a given node should be the head or a worker, starts Ray, runs +the entrypoint on the head, and then shuts the cluster down. +""" import socket import subprocess @@ -11,13 +16,18 @@ import ray from ray._private.ray_constants import env_integer from ray._raylet import GcsClient -from ray.exceptions import RpcError import psutil CLUSTER_WAIT_TIMEOUT = env_integer("RAY_SYMMETRIC_RUN_CLUSTER_WAIT_TIMEOUT", 30) +class SymmetricRunCommand(click.Command): + def parse_args(self, ctx, args): + ctx.meta["raw_args"] = list(args) + return super().parse_args(ctx, args) + + def check_ray_already_started() -> bool: import ray._private.services as services @@ -26,14 +36,14 @@ def check_ray_already_started() -> bool: return len(running_gcs_addresses) > 0 -def check_cluster_ready(nnodes, timeout=CLUSTER_WAIT_TIMEOUT): +def check_cluster_ready(address, nnodes, timeout=CLUSTER_WAIT_TIMEOUT): """Wait for all nodes to start. Raises an exception if the nodes don't start in time. """ start_time = time.time() current_nodes = 1 - ray.init(ignore_reinit_error=True) + ray.init(address=address, ignore_reinit_error=True) while time.time() - start_time < timeout: time.sleep(5) @@ -48,14 +58,19 @@ def check_cluster_ready(nnodes, timeout=CLUSTER_WAIT_TIMEOUT): def check_head_node_ready(address: str, timeout=CLUSTER_WAIT_TIMEOUT): + from ray.exceptions import RpcError + start_time = time.time() gcs_client = GcsClient(address=address) while time.time() - start_time < timeout: try: + # check_alive returns a list of dead node IDs from the input list. + # If it runs without raising an exception, the GCS is ready. gcs_client.check_alive([], timeout=1) click.echo("Ray cluster is ready!") return True except RpcError: + # GCS not ready yet, keep retrying pass time.sleep(5) return False @@ -67,15 +82,16 @@ def curate_and_validate_ray_start_args(run_and_start_args: List[str]) -> List[st cleaned_args = list(ctx.params["ray_args_and_entrypoint"]) for arg in cleaned_args: - if arg == "--head": + normalized = arg.split("=", 1)[0] + if normalized == "--head": raise click.ClickException("Cannot use --head option in symmetric_run.") - if arg == "--node-ip-address": + if normalized == "--node-ip-address": raise click.ClickException( "Cannot use --node-ip-address option in symmetric_run." ) - if arg == "--port": + if normalized == "--port": raise click.ClickException("Cannot use --port option in symmetric_run.") - if arg == "--block": + if normalized == "--block": raise click.ClickException("Cannot use --block option in symmetric_run.") return cleaned_args @@ -83,6 +99,7 @@ def curate_and_validate_ray_start_args(run_and_start_args: List[str]) -> List[st @click.command( name="symmetric_run", + cls=SymmetricRunCommand, context_settings={"ignore_unknown_options": True, "allow_extra_args": True}, help="""Command to start Ray across all nodes and execute an entrypoint command. @@ -135,24 +152,20 @@ def curate_and_validate_ray_start_args(run_and_start_args: List[str]) -> List[st help="If provided, wait for this number of nodes to start.", ) @click.argument("ray_args_and_entrypoint", nargs=-1, type=click.UNPROCESSED) -def symmetric_run(address, min_nodes, ray_args_and_entrypoint): - all_args = sys.argv[1:] - - if all_args and all_args[0] == "symmetric-run": - all_args = all_args[1:] - +@click.pass_context +def symmetric_run(ctx, address, min_nodes, ray_args_and_entrypoint): + raw_args = ctx.meta.get("raw_args", []) try: - separator = all_args.index("--") + separator = raw_args.index("--") except ValueError: raise click.ClickException( "No separator '--' found in arguments. Please use '--' to " "separate Ray start arguments and the entrypoint command." + f" Got arguments: {raw_args}" ) - run_and_start_args, entrypoint_on_head = ( - all_args[:separator], - all_args[separator + 1 :], - ) + run_and_start_args = raw_args[:separator] + entrypoint_on_head = raw_args[separator + 1 :] ray_start_args = curate_and_validate_ray_start_args(run_and_start_args) @@ -219,7 +232,7 @@ def symmetric_run(address, min_nodes, ray_args_and_entrypoint): subprocess.run(ray_start_cmd, check=True, capture_output=True) click.echo("Head node started.") click.echo("=======================") - if min_nodes > 1 and not check_cluster_ready(min_nodes): + if min_nodes > 1 and not check_cluster_ready(address, min_nodes): raise click.ClickException( "Timed out waiting for other nodes to start." ) diff --git a/python/ray/tests/symmetric_run_test_entrypoint.py b/python/ray/tests/symmetric_run_test_entrypoint.py new file mode 100644 index 000000000000..0e0936315292 --- /dev/null +++ b/python/ray/tests/symmetric_run_test_entrypoint.py @@ -0,0 +1,24 @@ +"""Test entrypoint for symmetric_run multi-node integration test. + +This script is executed by symmetric_run on the head node during +test_symmetric_run_three_node_cluster_simulated. It connects to the +Ray cluster and verifies that all expected nodes have joined. +""" +import sys +import time + +import ray + +EXPECTED_NODES = 3 +TIMEOUT_SECONDS = 60 + +ray.init(address="auto", ignore_reinit_error=True, log_to_driver=False) + +for _ in range(TIMEOUT_SECONDS): + if len(ray.nodes()) >= EXPECTED_NODES: + print("ENTRYPOINT_SUCCESS") + sys.exit(0) + time.sleep(1) + +print(f"ENTRYPOINT_FAILED: cluster_size={len(ray.nodes())}") +sys.exit(1) diff --git a/python/ray/tests/symmetric_run_wrapper.py b/python/ray/tests/symmetric_run_wrapper.py new file mode 100644 index 000000000000..55802d111159 --- /dev/null +++ b/python/ray/tests/symmetric_run_wrapper.py @@ -0,0 +1,43 @@ +"""Wrapper to run symmetric_run with mocked network interfaces. + +This wrapper is needed because symmetric_run spawns subprocesses (ray start, +entrypoint), so mocks applied in pytest don't propagate. This wrapper applies +patches before invoking symmetric_run, allowing us to control what each +'node' sees as its local IP. + +If MOCK_HIDE_IP is set, filter out this IP from network interfaces so the +process thinks it's not on the head node. +""" +import os +from unittest.mock import patch + +import ray # noqa: F401 - must import ray first, psutil is vendored +from ray.scripts.symmetric_run import symmetric_run + +import psutil + +_real_net_if_addrs = psutil.net_if_addrs + + +def _mocked_net_if_addrs(): + """Return network interfaces, optionally hiding the head IP.""" + real_addrs = _real_net_if_addrs() + hide_ip = os.environ.get("MOCK_HIDE_IP") + if not hide_ip: + return real_addrs + + new_addrs = {} + for iface, addrs in real_addrs.items(): + # Filter out the IP to be hidden from the list of addresses for this interface. + filtered_addrs = [addr for addr in addrs if addr.address != hide_ip] + if filtered_addrs: + new_addrs[iface] = filtered_addrs + return new_addrs + + +if __name__ == "__main__": + with ( + patch("ray._private.services.find_gcs_addresses", return_value=[]), + patch("psutil.net_if_addrs", side_effect=_mocked_net_if_addrs), + ): + symmetric_run() diff --git a/python/ray/tests/test_symmetric_run.py b/python/ray/tests/test_symmetric_run.py index 6e84dc1b3d2a..a4ac0a32c9a4 100644 --- a/python/ray/tests/test_symmetric_run.py +++ b/python/ray/tests/test_symmetric_run.py @@ -1,16 +1,48 @@ +import os +import signal +import subprocess import sys +import time from contextlib import contextmanager -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from click.testing import CliRunner import ray -import ray.scripts.scripts as scripts +from ray._common.test_utils import wait_for_condition +from ray._private.test_utils import get_current_unused_port + + +def _get_non_loopback_ip() -> str: + import socket + + import psutil + + for addrs in psutil.net_if_addrs().values(): + for addr in addrs: + if addr.family == socket.AF_INET and not addr.address.startswith("127."): + return addr.address + return "127.0.0.1" + + +def _kill_process_group(proc: subprocess.Popen) -> None: + """Best-effort kill of a process and its children.""" + if proc.poll() is not None: + return + try: + os.killpg(proc.pid, signal.SIGTERM) + proc.wait(timeout=5) + except (OSError, subprocess.TimeoutExpired): + try: + os.killpg(proc.pid, signal.SIGKILL) + except OSError: + # The process may have already died. + pass @contextmanager -def _setup_mock_network_utils(curr_ip, head_ip): +def _setup_mock_network_utils(curr_ip: str, head_ip: str): import socket # Mock socket.getaddrinfo to return a valid IP @@ -33,14 +65,13 @@ def _setup_mock_network_utils(curr_ip, head_ip): @pytest.fixture def cleanup_ray(): - """Shutdown all ray instances""" + """Shutdown all ray instances.""" yield - runner = CliRunner() - runner.invoke(scripts.stop, ["--force"]) + subprocess.run(["ray", "stop", "--force"], capture_output=True) ray.shutdown() -def test_symmetric_run_basic_interface(monkeypatch, cleanup_ray): +def test_symmetric_run_basic_interface(cleanup_ray): """Test basic symmetric_run interface with minimal arguments.""" from ray.scripts.symmetric_run import symmetric_run @@ -52,10 +83,9 @@ def test_symmetric_run_basic_interface(monkeypatch, cleanup_ray): with _setup_mock_network_utils("127.0.0.1", "127.0.0.1"): args = ["--address", "127.0.0.1:6379", "--", "echo", "test"] - with patch("sys.argv", ["/bin/ray", "symmetric-run", *args]): - # Test basic symmetric_run call using CliRunner - result = runner.invoke(symmetric_run, args) - assert result.exit_code == 0 + # Test basic symmetric_run call using CliRunner + result = runner.invoke(symmetric_run, args) + assert result.exit_code == 0 # Verify that subprocess.run was called for ray start assert mock_run.called @@ -74,7 +104,7 @@ def test_symmetric_run_basic_interface(monkeypatch, cleanup_ray): assert len(ray_stop_calls) > 0 -def test_symmetric_run_worker_node_behavior(monkeypatch, cleanup_ray): +def test_symmetric_run_worker_node_behavior(cleanup_ray): """Test symmetric_run behavior when not on the head node.""" from ray.scripts.symmetric_run import symmetric_run @@ -84,97 +114,204 @@ def test_symmetric_run_worker_node_behavior(monkeypatch, cleanup_ray): mock_run.return_value.returncode = 0 with _setup_mock_network_utils("192.168.1.100", "192.168.1.101"): - # Mock socket connection check to simulate head node ready - with patch("socket.socket") as mock_socket: - mock_socket_instance = MagicMock() - mock_socket_instance.connect_ex.return_value = 0 - mock_socket.return_value.__enter__.return_value = mock_socket_instance + # Pretend head is ready so worker proceeds. + with patch("ray.scripts.symmetric_run.check_head_node_ready") as ready: + ready.return_value = True - # Test worker node behavior args = ["--address", "192.168.1.100:6379", "--", "echo", "test"] - with patch("sys.argv", ["/bin/ray", "symmetric-run", *args]): - with patch( - "ray.scripts.symmetric_run.check_head_node_ready" - ) as mock_check_head_node_ready: - mock_check_head_node_ready.return_value = True - result = runner.invoke(symmetric_run, args) - assert result.exit_code == 0 - - # Verify that subprocess.run was called - assert mock_run.called - calls = mock_run.call_args_list - - # Should have called ray start with --address (worker mode) - ray_start_calls = [ - call - for call in calls - if "ray" in str(call) and "start" in str(call) - ] - assert len(ray_start_calls) > 0 + result = runner.invoke(symmetric_run, args) + assert result.exit_code == 0 - # Check that it's in worker mode (--address instead of --head) - start_call = ray_start_calls[0] - start_args = start_call[0][0] - assert "--address" in start_args - assert "192.168.1.100:6379" in start_args - assert "--head" not in start_args - assert "--block" in start_args # Worker nodes should block + # Verify that subprocess.run was called + assert mock_run.called + calls = mock_run.call_args_list + + # Should have called ray start with --address (worker mode) + ray_start_calls = [ + call for call in calls if "ray" in str(call) and "start" in str(call) + ] + assert len(ray_start_calls) > 0 + + # Check that it's in worker mode (--address instead of --head) + start_call = ray_start_calls[0] + start_args = start_call[0][0] + assert "--address" in start_args + assert "192.168.1.100:6379" in start_args + assert "--head" not in start_args + assert "--block" in start_args # Worker nodes should block -def test_symmetric_run_arg_validation(monkeypatch, cleanup_ray): +def test_symmetric_run_arg_validation(cleanup_ray): """Test that symmetric_run validates arguments.""" from ray.scripts.symmetric_run import symmetric_run runner = CliRunner() - # Mock subprocess.run to avoid actually starting Ray with _setup_mock_network_utils("127.0.0.1", "127.0.0.1"): - + # Mock subprocess.run to avoid actually starting Ray with patch("subprocess.run") as mock_run: mock_run.return_value.returncode = 0 args = ["--address", "127.0.0.1:6379", "--", "echo", "test"] - with patch("sys.argv", ["/bin/ray", "symmetric-run", *args]): - # Test basic symmetric_run call using CliRunner - result = runner.invoke(symmetric_run, args) - assert result.exit_code == 0 + result = runner.invoke(symmetric_run, args) + assert result.exit_code == 0 # Test that invalid arguments are rejected with patch("subprocess.run") as mock_run: mock_run.return_value.returncode = 0 args = ["--address", "127.0.0.1:6379", "echo", "test"] - with patch("sys.argv", ["/bin/ray", "symmetric-run", *args]): - result = runner.invoke(symmetric_run, args) - assert result.exit_code == 1 - assert "No separator" in result.output + result = runner.invoke(symmetric_run, args) + assert result.exit_code == 1 + assert "No separator" in result.output # Test that invalid arguments are rejected with patch("subprocess.run") as mock_run: mock_run.return_value.returncode = 0 args = ["--address", "127.0.0.1:6379", "--head", "--", "echo", "test"] - with patch("sys.argv", ["/bin/ray", "symmetric-run", *args]): - result = runner.invoke(symmetric_run, args) - assert result.exit_code == 1 - assert "Cannot use --head option in symmetric_run." in result.output + result = runner.invoke(symmetric_run, args) + assert result.exit_code == 1 + assert "Cannot use --head option" in result.output with patch("subprocess.run") as mock_run: mock_run.return_value.returncode = 0 # Test args with "=" are passed to ray start args = ["--address", "127.0.0.1:6379", "--num-cpus=4", "--", "echo", "test"] - with patch("sys.argv", ["/bin/ray", "symmetric-run", *args]): - result = runner.invoke(symmetric_run, args) - assert result.exit_code == 0 + result = runner.invoke(symmetric_run, args) + assert result.exit_code == 0 - ray_start_calls = [ - call - for call in mock_run.call_args_list - if "ray" in str(call) and "start" in str(call) - ] - assert len(ray_start_calls) > 0 - assert "--num-cpus=4" in ray_start_calls[0][0][0] + ray_start_calls = [ + call + for call in mock_run.call_args_list + if "ray" in str(call) and "start" in str(call) + ] + assert len(ray_start_calls) > 0 + assert "--num-cpus=4" in ray_start_calls[0][0][0] + + +class PortAllocator: + """Allocate unique ports for a test, avoiding collisions.""" + + def __init__(self): + self._allocated: set[int] = set() + + def allocate(self) -> int: + for _ in range(100): + port = get_current_unused_port() + if port not in self._allocated: + self._allocated.add(port) + return port + raise RuntimeError("Could not allocate unique port after 100 attempts") + + def build_ray_port_args(self) -> list[str]: + """Build Ray port arguments for a single node (needs unique ports per node).""" + worker_ports = ",".join(str(self.allocate()) for _ in range(5)) + return [ + f"--node-manager-port={self.allocate()}", + f"--object-manager-port={self.allocate()}", + f"--dashboard-port={self.allocate()}", + f"--worker-port-list={worker_ports}", + "--disable-usage-stats", + "--num-cpus=0", + ] + + +def test_symmetric_run_three_node_cluster_simulated(cleanup_ray): + """Simulate multi-node symmetric_run on a single machine.""" + + # Clean slate (important when iterating locally). + subprocess.run(["ray", "stop", "--force"], capture_output=True) + + ports = PortAllocator() + + # We use the REAL IP for the head, so actual Ray processes can bind to it. + # For workers, we mock psutil to HIDE this IP so they think they are remote. + head_ip = _get_non_loopback_ip() + gcs_port = ports.allocate() + address = f"{head_ip}:{gcs_port}" + + base_env = os.environ.copy() + base_env["RAY_ENABLE_WINDOWS_OR_OSX_CLUSTER"] = "1" + base_env["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0" + + test_dir = os.path.dirname(__file__) + wrapper_script = os.path.join(test_dir, "symmetric_run_wrapper.py") + entrypoint_script = os.path.join(test_dir, "symmetric_run_test_entrypoint.py") + + # Symmetric commands on all nodes + def build_cmd(port_args: list[str]) -> list[str]: + return [ + sys.executable, + wrapper_script, + "--address", + address, + "--min-nodes", + "3", + *port_args, + "--", + sys.executable, + entrypoint_script, + ] + + head_cmd = build_cmd(ports.build_ray_port_args()) + worker_cmds = [build_cmd(ports.build_ray_port_args()) for _ in range(2)] + + head_env = {**base_env, "RAY_ADDRESS": address} + # So that workers don't think they have the head node's IP. + worker_env = {**base_env, "MOCK_HIDE_IP": head_ip} + + worker_procs = [] + head_proc = None + try: + head_proc = subprocess.Popen( + head_cmd, + env=head_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + start_new_session=True, + ) + + # Wait for head to start listening before starting workers. + def _check_head_ready() -> bool: + import socket + + host, port_str = address.split(":") + try: + with socket.create_connection((host, int(port_str)), timeout=1): + return True + except (socket.timeout, ConnectionRefusedError, OSError): + return False + + wait_for_condition(_check_head_ready, timeout=30, retry_interval_ms=250) + + for cmd in worker_cmds: + worker_procs.append( + subprocess.Popen( + cmd, + env=worker_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + start_new_session=True, + ) + ) + time.sleep(2) + + head_stdout_b, head_stderr_b = head_proc.communicate(timeout=150) + head_stdout = head_stdout_b.decode("utf-8", "replace") + head_stderr = head_stderr_b.decode("utf-8", "replace") + head_rc = head_proc.returncode + + finally: + for p in worker_procs: + _kill_process_group(p) + if head_proc: + _kill_process_group(head_proc) + + assert head_rc == 0, f"Head failed: rc={head_rc}\nstderr:\n{head_stderr[-2000:]}" + assert "On head node" in head_stdout + assert "ENTRYPOINT_SUCCESS" in head_stdout if __name__ == "__main__":