Draft
129 commits
316d43e
Miscellaneous infra.
pjin-nvidia Nov 13, 2025
dd953b0
Ray utils.
pjin-nvidia Nov 13, 2025
5717d7a
No cover.
pjin-nvidia Nov 13, 2025
4ecd8d3
Remove DEBUG. Comment.
pjin-nvidia Nov 15, 2025
8103dbf
Comment about ray package extra.
pjin-nvidia Nov 15, 2025
dc493d5
The.
pjin-nvidia Nov 15, 2025
f9e5d8f
Merge remote-tracking branch 'origin/main' into pjin/misc-infra
pjin-nvidia Nov 15, 2025
9502d82
Fix test (?).
pjin-nvidia Nov 15, 2025
0475d5e
Initial support for server pyproject.toml (WIP).
pjin-nvidia Nov 15, 2025
d86756b
Fix pyproject.toml check.
pjin-nvidia Nov 15, 2025
79028a6
Working directory Path.
pjin-nvidia Nov 15, 2025
7e62b1d
Install a server venv from pyproject.toml if available.
pjin-nvidia Nov 15, 2025
36efb94
Deprecated vllm_model requirements.txt.
pjin-nvidia Nov 15, 2025
8d49b95
Consistently use dashes in package names.
pjin-nvidia Nov 15, 2025
6fb0a95
Lint.
pjin-nvidia Nov 15, 2025
7231efa
Cleanup.
pjin-nvidia Nov 15, 2025
8fc0d9d
VLLM server spinup.
pjin-nvidia Nov 15, 2025
8975e98
VLLM server host and port.
pjin-nvidia Nov 15, 2025
51ba6fc
Allocate the free port for VLLM in the model server process.
pjin-nvidia Nov 16, 2025
aa97796
Type.
pjin-nvidia Nov 16, 2025
6ec9325
Fix for pyproject.toml (this works lol).
pjin-nvidia Nov 16, 2025
33ec3f9
VLLM server "routing" (just re-using the existing multiple clients).
pjin-nvidia Nov 16, 2025
77cda85
Better order.
pjin-nvidia Nov 16, 2025
44dcee1
Merge remote-tracking branch 'origin/main' into pjin/ray-utils
pjin-nvidia Nov 16, 2025
a85f4f0
WIP.
pjin-nvidia Nov 16, 2025
7201c8f
Comment.
pjin-nvidia Nov 16, 2025
834d9b9
Default to "mp" backend.
pjin-nvidia Nov 16, 2025
5ee8b57
Cleanup.
pjin-nvidia Nov 16, 2025
10b5295
Cleanup.
pjin-nvidia Nov 16, 2025
e4c5573
Non-async VLLM server heartbeat to avoid early asyncio event loop.
pjin-nvidia Nov 16, 2025
0a8da20
With pyproject.toml, no pre-install command needed.
pjin-nvidia Nov 16, 2025
85a09fe
Ray GPU node-related global config keys. Simplified spinup (WIP).
pjin-nvidia Nov 16, 2025
ad0e2fc
Improved server venv pyproject install that does not use editable.
pjin-nvidia Nov 17, 2025
5c1fe99
Querying ray state to find nodes with available and unused GPUs.
pjin-nvidia Nov 17, 2025
f32957e
Only use explicitly reserved ray GPU nodes if specified.
pjin-nvidia Nov 17, 2025
ef77c4c
Comment. Cleanup.
pjin-nvidia Nov 17, 2025
bbf4631
Cleanup.
pjin-nvidia Nov 17, 2025
531a61d
Type.
pjin-nvidia Nov 17, 2025
f88ec6a
No cover.
pjin-nvidia Nov 17, 2025
d819740
Type.
pjin-nvidia Nov 17, 2025
7640773
Rename reserved => allowed.
pjin-nvidia Nov 17, 2025
0436b47
Packaging and setup.
pjin-nvidia Nov 17, 2025
70670a2
Rename.
pjin-nvidia Nov 17, 2025
e61253c
VLLMModel local spinup (originally from PR #317).
pjin-nvidia Nov 17, 2025
854609f
Revert VLLMModel changes (moving to PR #318).
pjin-nvidia Nov 17, 2025
dc6ffef
One line uv pip install.
pjin-nvidia Nov 18, 2025
56b9bfa
VLLM spinup in a Ray worker.
pjin-nvidia Nov 20, 2025
e8afd2d
Print the names of servers yet to have finished spinning up.
pjin-nvidia Nov 20, 2025
0142784
Formatting.
pjin-nvidia Nov 20, 2025
04a97dd
Import.
pjin-nvidia Nov 20, 2025
70ac196
Do not count resources of ray actors in 'DEAD' state (these resources…
pjin-nvidia Nov 20, 2025
3e5c924
Support for specifying non-anonymous Ray namespace.
pjin-nvidia Nov 26, 2025
8bdcec0
Fix for starting nested Ray actors.
pjin-nvidia Nov 26, 2025
a0c0d19
Merge remote-tracking branch 'origin/main' into pjin/misc-infra
pjin-nvidia Nov 26, 2025
17f640f
Merge remote-tracking branch 'origin/main' into pjin/ray-utils
pjin-nvidia Nov 26, 2025
0a94c2d
Merge remote-tracking branch 'origin/main' into pjin/misc-infra
pjin-nvidia Dec 1, 2025
d4b8074
Merge remote-tracking branch 'origin/main' into pjin/ray-utils
pjin-nvidia Dec 1, 2025
8fe389f
Matching the misc infra PR.
pjin-nvidia Dec 1, 2025
613efb4
No cover.
pjin-nvidia Dec 1, 2025
7575eb6
Global scheduling helper to track free GPUs of schedulable ray nodes.
pjin-nvidia Dec 2, 2025
d7e1683
Rename.
pjin-nvidia Dec 2, 2025
f7c1937
Print.
pjin-nvidia Dec 2, 2025
2d37d17
Avoid an unnecessary ray import.
pjin-nvidia Dec 2, 2025
a35f58d
Try to pass the linter.
pjin-nvidia Dec 2, 2025
1b53089
Test.
pjin-nvidia Dec 2, 2025
6327760
Tests.
pjin-nvidia Dec 2, 2025
f5466f9
Fix test.
pjin-nvidia Dec 2, 2025
7a7e952
Fix test.
pjin-nvidia Dec 2, 2025
eab68a0
Unfix test.
pjin-nvidia Dec 2, 2025
66b788d
Revert to just cd into working dir.
pjin-nvidia Dec 2, 2025
a78f226
Deduplicate.
pjin-nvidia Dec 2, 2025
fdb54fe
Also add explicit check for requirements.txt.
pjin-nvidia Dec 2, 2025
3fb2911
Revert format.
pjin-nvidia Dec 2, 2025
d62ab6c
VLLMModel refresh.
pjin-nvidia Dec 2, 2025
7809170
Add vllm_model pyproject.toml (depends on PR #317).
pjin-nvidia Dec 3, 2025
156f039
Unpin vllm version.
pjin-nvidia Dec 3, 2025
21ba79e
Consolidated ray actor env vars setup.
pjin-nvidia Dec 3, 2025
74acc72
Fix.
pjin-nvidia Dec 4, 2025
6c59909
Fix.
pjin-nvidia Dec 4, 2025
b595ce2
Format.
pjin-nvidia Dec 5, 2025
93c0273
Pin ray version to NeMo RL version (branch: yifu/nemotron).
pjin-nvidia Dec 8, 2025
7580f50
Pick misc infra ray extras (fix test).
pjin-nvidia Dec 8, 2025
bf0ccfe
Use a scheduling coordination helper.
pjin-nvidia Dec 9, 2025
f655a8c
Merge remote-tracking branch 'origin/main' into pjin/ray-utils
pjin-nvidia Dec 9, 2025
b99a5c4
Merge remote-tracking branch 'origin/main' into pjin/misc-infra
pjin-nvidia Dec 9, 2025
fd98595
Sync vllm_model pyproject.toml.
pjin-nvidia Dec 9, 2025
de2dd33
Sync vllm_model pyproject.toml.
pjin-nvidia Dec 9, 2025
7d399f3
This is just a list of node IDs (as of RL commit: 07a71f7b1656adb99f6…
pjin-nvidia Dec 9, 2025
017be87
Merge remote-tracking branch 'origin/pjin/misc-infra' into pjin/nemot…
pjin-nvidia Dec 9, 2025
310effe
Merge remote-tracking branch 'origin/pjin/ray-utils' into pjin/nemotr…
pjin-nvidia Dec 9, 2025
5a8bb0c
Minimum version of vllm >= 0.11.2.
pjin-nvidia Dec 9, 2025
35d5445
Un-pin ray version.
pjin-nvidia Dec 9, 2025
3f50cf1
Merge remote-tracking branch 'origin/pjin/ray-utils' into pjin/nemotr…
pjin-nvidia Dec 9, 2025
330d9af
Pin vllm version to the nemo RL version.
pjin-nvidia Dec 9, 2025
474608a
Merge remote-tracking branch 'origin/pjin/nano-v3-main-dev-20251207' …
pjin-nvidia Dec 9, 2025
987cf5c
Minimum version of vllm >= 0.11.2.
pjin-nvidia Dec 9, 2025
9058505
Fix for recent VLLM.
pjin-nvidia Dec 10, 2025
ff120a7
Merge remote-tracking branch 'origin/pjin/misc-infra' into pjin/nemot…
pjin-nvidia Dec 10, 2025
a0b6507
Merge remote-tracking branch 'origin/pjin/ray-utils' into pjin/nemotr…
pjin-nvidia Dec 10, 2025
a119214
Pick fix from https://github.com/NVIDIA-NeMo/Gym/pull/359 and revert …
pjin-nvidia Dec 10, 2025
9ef1475
Merge remote-tracking branch 'origin/pjin/nano-v3-main-dev-20251207' …
pjin-nvidia Dec 10, 2025
7e5aa85
Merge remote-tracking branch 'origin/main' into pjin/nano-v3-main-dev…
pjin-nvidia Dec 10, 2025
923b891
Merge remote-tracking branch 'origin/main' into pjin/nemotron-ray-dev…
pjin-nvidia Dec 10, 2025
74d8bb8
Merge remote-tracking branch 'origin/main' into pjin/nano-v3-main-dev…
pjin-nvidia Dec 11, 2025
abad135
Rollect collection iterator returns both the input row and output res…
pjin-nvidia Dec 11, 2025
35e18c2
Type annotation.
pjin-nvidia Dec 11, 2025
25996f8
Pin vllm == 0.11.2.
pjin-nvidia Dec 11, 2025
e24e3af
Merge remote-tracking branch 'origin/pjin/nano-v3-main-dev-20251207' …
pjin-nvidia Dec 11, 2025
54c6812
Merge remote-tracking branch 'origin/main' into pjin/nemotron-ray-dev…
pjin-nvidia Dec 11, 2025
9625d9d
add genrm rlhf
HeyyyyyyG Dec 15, 2025
73f95ea
Merge remote-tracking branch 'origin/pjin/nemotron-ray-dev-20251208' …
HeyyyyyyG Dec 16, 2025
6e93c21
Re-pin ray.
pjin-nvidia Dec 17, 2025
10af204
add long context env
fayejf Dec 17, 2025
8226f5c
fix
fayejf Dec 17, 2025
c524692
Rename lc => lc_judge.
pjin-nvidia Dec 17, 2025
9c88189
Revert swallowing judge model endpoint errors.
pjin-nvidia Dec 17, 2025
44c2e15
Merge remote-tracking branch 'origin/fjia/lc_judge' into pjin/nemotro…
pjin-nvidia Dec 18, 2025
6f29347
Fix missing config field.
pjin-nvidia Dec 18, 2025
f50ec7c
Merge remote-tracking branch 'github/fjia/lc_judge' into pjin/nemotro…
pjin-nvidia Dec 18, 2025
5e36ca1
Merge remote-tracking branch 'origin/main' into pjin/nemotron-ray-dev…
pjin-nvidia Dec 19, 2025
e5fd045
Cherrypick: fix args bugs.
pjin-nvidia Dec 20, 2025
63a13eb
Fallback in case json.loads fails (???).
pjin-nvidia Dec 21, 2025
d4878bc
fix genrm comparison strategy and calendar env
pjin-nvidia Dec 22, 2025
8e28efc
Log server stdout/err to logfiles in the server working dirs.
pjin-nvidia Dec 30, 2025
72242eb
Block off more port ranges to avoid potential conflicts.
pjin-nvidia Dec 30, 2025
f706463
Backport https://github.com/NVIDIA-NeMo/Gym/pull/552.
pjin-nvidia Jan 6, 2026
dc048a3
Fix for ray state API usage (list_nodes with limit).
pjin-nvidia Jan 27, 2026
6734cfb
VLLMModel spinup support for setting env vars.
pjin-nvidia Jan 27, 2026
4ec3002
Redirect server stdout/stderr not to files by default. Create unique …
pjin-nvidia Feb 3, 2026
45 changes: 38 additions & 7 deletions nemo_gym/cli.py
@@ -28,7 +28,7 @@
from subprocess import Popen
from threading import Thread
from time import sleep, time
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import psutil
import rich
@@ -49,6 +49,9 @@
GlobalConfigDictParserConfig,
get_global_config_dict,
)
from nemo_gym.ray_utils import (
_start_global_ray_gpu_scheduling_helper,
)
from nemo_gym.server_status import StatusCommand
from nemo_gym.server_utils import (
HEAD_SERVER_KEY_NAME,
@@ -60,10 +63,15 @@
)


def _setup_env_command(dir_path: Path, global_config_dict: DictConfig) -> str: # pragma: no cover
def _setup_env_command(dir_path: Path, global_config_dict: DictConfig, top_level_name: Optional[str] = None) -> str: # pragma: no cover
head_server_deps = global_config_dict[HEAD_SERVER_DEPS_KEY_NAME]

uv_venv_cmd = f"uv venv --seed --allow-existing --python {global_config_dict[PYTHON_VERSION_KEY_NAME]} .venv"
if top_level_name is not None:
venv = f".venv-{top_level_name}"
else:
venv = ".venv"

uv_venv_cmd = f"uv venv --seed --allow-existing --python {global_config_dict[PYTHON_VERSION_KEY_NAME]} {venv}"

has_pyproject_toml = exists(f"{dir_path / 'pyproject.toml'}")
has_requirements_txt = exists(f"{dir_path / 'requirements.txt'}")
@@ -74,29 +82,49 @@ def _setup_env_command(dir_path: Path, global_config_dict: DictConfig) -> str:
)
elif has_pyproject_toml:
install_cmd = f"""uv pip install '-e .' {" ".join(head_server_deps)}"""
if dir_path.name == "vllm_model":
# NB: --no-deps is a workaround for installing vllm (current version: 0.11.2) on a cpu target,
# b/c `uv pip install` resolves dependencies differently vs `pip install`.
install_cmd = f"""uv pip install --no-deps 'vllm==0.11.2' && {install_cmd}"""
elif has_requirements_txt:
install_cmd = f"""uv pip install -r requirements.txt {" ".join(head_server_deps)}"""
else:
raise RuntimeError(f"Missing pyproject.toml or requirements.txt for uv venv setup in server dir: {dir_path}")

if top_level_name is not None:
uv_venv_cmd = f"{uv_venv_cmd} > >(sed 's/^/({top_level_name}) /') 2> >(sed 's/^/({top_level_name}) /' >&2)"
install_cmd = f"{install_cmd} > >(sed 's/^/({top_level_name}) /') 2> >(sed 's/^/({top_level_name}) /' >&2)"

cmd = f"""cd {dir_path} \\
&& {uv_venv_cmd} \\
&& source .venv/bin/activate \\
&& source {venv}/bin/activate \\
&& {install_cmd} \\
"""

return cmd
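
For orientation, here is a minimal standalone sketch of the shell command this function assembles for a hypothetical server directory that ships a pyproject.toml. The directory layout, Python version, and extra deps below are illustrative, not taken from the PR, and the real code additionally pipes output through the sed process substitutions shown above to prefix each line with the server name:

# Sketch of the per-server venv setup command built by _setup_env_command.
# All concrete values here are assumptions for illustration.
from pathlib import Path
from typing import List, Optional


def sketch_setup_env_command(
    dir_path: Path,
    python_version: str,
    head_server_deps: List[str],
    top_level_name: Optional[str] = None,
) -> str:
    # Per-top-level venv, matching the ".venv-{top_level_name}" scheme above.
    venv = f".venv-{top_level_name}" if top_level_name else ".venv"
    uv_venv = f"uv venv --seed --allow-existing --python {python_version} {venv}"
    install = f"uv pip install '-e .' {' '.join(head_server_deps)}"
    if dir_path.name == "vllm_model":
        # Mirrors the --no-deps workaround for installing vllm on a CPU target.
        install = f"uv pip install --no-deps 'vllm==0.11.2' && {install}"
    return f"cd {dir_path} && {uv_venv} && source {venv}/bin/activate && {install}"


print(sketch_setup_env_command(Path("servers/vllm_model"), "3.12", ["fastapi"], "policy_model"))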


def _run_command(command: str, working_dir_path: Path) -> Popen: # pragma: no cover
def _run_command(command: str, working_dir_path: Path, top_level_name: Optional[str] = None) -> Popen: # pragma: no cover
work_dir = f"{working_dir_path.absolute()}"
custom_env = environ.copy()
py_path = custom_env.get("PYTHONPATH", None)
if py_path is not None:
custom_env["PYTHONPATH"] = f"{work_dir}:{py_path}"
else:
custom_env["PYTHONPATH"] = work_dir
return Popen(command, executable="/bin/bash", shell=True, env=custom_env)
redirect_stdout = sys.stdout
redirect_stderr = sys.stderr
if top_level_name is not None:
redirect_stdout = open(f"{work_dir}/run-{top_level_name}.out.log", "a")
redirect_stderr = open(f"{work_dir}/run-{top_level_name}.err.log", "a")
return Popen(
command,
executable="/bin/bash",
shell=True,
env=custom_env,
stdout=redirect_stdout,
stderr=redirect_stderr,
)
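
The redirection added here can be summarized with a small self-contained sketch (hypothetical paths; note the logs are opened in append mode, so repeated runs accumulate in the same per-server files):

# Sketch of _run_command's per-server logfile redirection.
import sys
from subprocess import Popen


def run_with_logs(command: str, work_dir: str, top_level_name: str | None = None) -> Popen:
    stdout, stderr = sys.stdout, sys.stderr
    if top_level_name is not None:
        # Same naming as run-{top_level_name}.out.log / .err.log above.
        stdout = open(f"{work_dir}/run-{top_level_name}.out.log", "a")
        stderr = open(f"{work_dir}/run-{top_level_name}.err.log", "a")
    return Popen(command, executable="/bin/bash", shell=True, stdout=stdout, stderr=stderr)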


class RunConfig(BaseNeMoGymCLIConfig):
@@ -152,6 +180,7 @@ class RunHelper: # pragma: no cover
_head_server: uvicorn.Server
_head_server_thread: Thread
_head_server_instance: HeadServer
_head_ray_gpu_helper: Any

_processes: Dict[str, Popen]
_server_instance_display_configs: List[ServerInstanceDisplayConfig]
@@ -164,6 +193,8 @@ def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig)
# Note: This function will modify the global config dict - update `ray_head_node_address`
initialize_ray()

self._head_ray_gpu_helper = _start_global_ray_gpu_scheduling_helper()

# Assume Nemo Gym Run is for a single agent.
escaped_config_dict_yaml_str = shlex.quote(OmegaConf.to_yaml(global_config_dict))

@@ -201,7 +232,7 @@ def start(self, global_config_dict_parser_config: GlobalConfigDictParserConfig)

dir_path = PARENT_DIR / Path(first_key, second_key)

command = f"""{_setup_env_command(dir_path, global_config_dict)} \\
command = f"""{_setup_env_command(dir_path, global_config_dict, top_level_path)} \\
&& {NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME}={escaped_config_dict_yaml_str} \\
{NEMO_GYM_CONFIG_PATH_ENV_VAR_NAME}={shlex.quote(top_level_path)} \\
python {str(entrypoint_fpath)}"""
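
_start_global_ray_gpu_scheduling_helper is internal to this PR and its interface is not shown in this diff. As a rough, hypothetical illustration of the kind of Ray state query such a helper performs (cf. commits 5c1fe99 "Querying ray state to find nodes with available and unused GPUs" and dc048a3), alive GPU nodes can be enumerated like this:

# Hypothetical sketch only: list ALIVE nodes exposing GPUs via the Ray state
# API. list_nodes is a real Ray API; everything else here is illustrative.
import ray
from ray.util.state import list_nodes

ray.init(ignore_reinit_error=True)

# Pass limit explicitly; the default caps the number of returned nodes
# (cf. the "list_nodes with limit" fix, commit dc048a3).
nodes = list_nodes(detail=True, limit=10_000)
gpus_by_node = {
    n.node_id: n.resources_total.get("GPU", 0)
    for n in nodes
    if n.state == "ALIVE" and n.resources_total.get("GPU", 0) > 0
}
print(gpus_by_node)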
150 changes: 150 additions & 0 deletions nemo_gym/comparison_strategies.py
@@ -0,0 +1,150 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Comparison strategies for multi-generation reward computation.
"""
import hashlib
import json
from typing import Any, Dict, List, Optional, Protocol, Tuple, runtime_checkable

from pydantic import BaseModel, Field

from nemo_gym.server_utils import ServerClient, raise_for_status


@runtime_checkable
class ComparisonStrategy(Protocol):
"""Protocol for comparison strategies that compute rewards from multiple generations."""

agent_names: List[str]
num_generations_per_prompt: int
policy_model_server_name: str

async def compare(
self,
conversation_history: List[Dict[str, str]],
responses: List[str],
server_client: ServerClient,
principle: Optional[str] = None,
) -> Tuple[List[float], Dict[str, float]]:
"""Compare N responses and return (rewards, metrics)."""
...


class GenRMStrategyConfig(BaseModel):
"""Configuration for GenRM comparison strategy."""
agent_names: List[str] = Field(default_factory=lambda: ["genrm_simple_agent"])
num_generations_per_prompt: int = 16
genrm_compare_server_name: str = "genrm_compare"
policy_model_server_name: str = "policy_model"


class GenRMStrategy:
"""GenRM comparison strategy using pairwise comparisons."""

def __init__(self, config: GenRMStrategyConfig):
self.config = config
self.agent_names = config.agent_names
self.num_generations_per_prompt = config.num_generations_per_prompt
self.policy_model_server_name = config.policy_model_server_name

async def compare(
self,
conversation_history: List[Dict[str, str]],
response_objs: List[Dict],
server_client: ServerClient,
principle: Optional[str] = None,
) -> Tuple[List[float], Dict[str, float]]:
"""Call genrm_compare server to get rewards for each response.

Args:
conversation_history: The conversation context
response_objs: List of raw Response API objects
server_client: The server client for making requests
principle: Optional principle for principle-based GenRM comparison

Returns:
Tuple of (rewards, metrics) from GenRM comparison
"""
payload = {
"conversation_history": conversation_history,
"response_objs": response_objs,
}

if principle is not None:
payload["principle"] = principle

res = await server_client.post(
server_name=self.config.genrm_compare_server_name,
url_path="/compare",
json=payload,
)
await raise_for_status(res)
result = await res.json()

rewards = result.get("rewards", [0.0] * len(response_objs))
metrics = result.get("metrics", {})

return rewards, metrics


def get_prompt_key(example: Dict) -> str:
"""Get stable key for grouping examples by prompt and principle.

Examples with the same conversation history but different principles
should be in separate groups, so we include principle in the hash.
"""
if "prompt_id" in example:
# If prompt_id exists, combine it with principle for uniqueness
prompt_id = str(example["prompt_id"])
principle = example.get("principle")
if principle is not None:
return f"{prompt_id}::{principle}"
return prompt_id

# Hash both conversation history and principle together
conv = extract_conversation_history(example)
principle = example.get("principle")
key_data = {
"conversation": conv,
"principle": principle,
}
return hashlib.sha256(json.dumps(key_data, sort_keys=True).encode()).hexdigest()
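
A quick runnable check of the grouping contract: the same prompt_id under two different principles yields two distinct keys (values below are illustrative):

from collections import defaultdict

from nemo_gym.comparison_strategies import get_prompt_key

examples = [
    {"prompt_id": 7, "principle": "helpfulness"},
    {"prompt_id": 7, "principle": "harmlessness"},
    {"prompt_id": 7, "principle": "helpfulness"},
]
groups = defaultdict(list)
for ex in examples:
    groups[get_prompt_key(ex)].append(ex)
# Two groups: "7::helpfulness" (2 examples) and "7::harmlessness" (1 example).
assert sorted(len(g) for g in groups.values()) == [1, 2]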


def extract_conversation_history(example: Dict) -> List[Dict]:
"""Extract conversation history from example.

Gym examples store history in responses_create_params.input
"""
responses_create_params = example.get("responses_create_params")
if responses_create_params is None:
raise ValueError(f"Example missing 'responses_create_params': {list(example.keys())}")
if "input" not in responses_create_params:
raise ValueError(f"responses_create_params missing 'input': {list(responses_create_params.keys())}")
return responses_create_params["input"]


def extract_generated_text(gen_result: Dict) -> str:
"""Extract generated text from generation result."""
if not isinstance(gen_result, dict):
raise ValueError(f"Expected dict, got {type(gen_result)}")
if "output" in gen_result:
output = gen_result["output"]
if isinstance(output, list) and output:
return output[0].get("content", "")
if isinstance(output, str):
return output
if "response" in gen_result:
return gen_result["response"]
raise ValueError(f"Cannot extract generated text from: {list(gen_result.keys())}")


async def generate_response(example: Dict, server_client: ServerClient, model_server: str) -> Dict:
"""Generate a single response using the policy model."""
params = example.get("responses_create_params")
if params is None:
raise ValueError(f"Example missing 'responses_create_params': {list(example.keys())}")
res = await server_client.post(server_name=model_server, url_path="/v1/responses", json=params)
await raise_for_status(res)
return await res.json()
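
Hypothetical wiring of the strategy (a real ServerClient comes from the Gym runtime, so it is left abstract here); this also shows that GenRMStrategy structurally satisfies the runtime-checkable ComparisonStrategy protocol:

from nemo_gym.comparison_strategies import (
    ComparisonStrategy,
    GenRMStrategy,
    GenRMStrategyConfig,
)

strategy = GenRMStrategy(GenRMStrategyConfig(num_generations_per_prompt=4))
assert isinstance(strategy, ComparisonStrategy)  # attribute/method presence check


async def score(server_client, history, response_objs):
    # compare() POSTs to the genrm_compare server's /compare endpoint.
    rewards, metrics = await strategy.compare(history, response_objs, server_client)
    return rewards, metrics

In the Gym runtime, score(...) would be awaited once per group of num_generations_per_prompt responses.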
1 change: 1 addition & 0 deletions nemo_gym/config_types.py
@@ -393,6 +393,7 @@ class Domain(str, Enum):
GAMES = "games"
TRANSLATION = "translation"
E2E = "e2e"
RLHF = "rlhf"
OTHER = "other"


18 changes: 17 additions & 1 deletion nemo_gym/global_config.py
@@ -45,6 +45,10 @@
DISALLOWED_PORTS_KEY_NAME = "disallowed_ports"
HEAD_SERVER_DEPS_KEY_NAME = "head_server_deps"
PYTHON_VERSION_KEY_NAME = "python_version"
RAY_HEAD_NODE_ADDRESS_KEY_NAME = "ray_head_node_address"
RAY_NAMESPACE_KEY_NAME = "ray_namespace"
RAY_GPU_NODES_KEY_NAME = "ray_gpu_nodes"
RAY_NUM_GPUS_PER_NODE_KEY_NAME = "ray_num_gpus_per_node"
USE_ABSOLUTE_IP = "use_absolute_ip"
NEMO_GYM_RESERVED_TOP_LEVEL_KEYS = [
CONFIG_PATHS_KEY_NAME,
@@ -54,6 +58,10 @@
DISALLOWED_PORTS_KEY_NAME,
HEAD_SERVER_DEPS_KEY_NAME,
PYTHON_VERSION_KEY_NAME,
RAY_HEAD_NODE_ADDRESS_KEY_NAME,
RAY_NAMESPACE_KEY_NAME,
RAY_GPU_NODES_KEY_NAME,
RAY_NUM_GPUS_PER_NODE_KEY_NAME,
USE_ABSOLUTE_IP,
]

@@ -371,11 +379,19 @@ def get_first_server_config_dict(global_config_dict: DictConfig, top_level_path:

def find_open_port(
disallowed_ports: Optional[List[int]] = None,
max_retries: int = 50,
max_retries: int = 100,
) -> int: # pragma: no cover
if disallowed_ports is None:
disallowed_ports = []

default_disallowed_ports = set(
list(range(53000, 53010+1)) +
list(range(54000, 60000+1)) +
[10001, 8265, 52365, 52365+1]
)

disallowed_ports = default_disallowed_ports | set(disallowed_ports)

# Find an open port that doesn't conflict with disallowed ports.
for _ in range(max_retries):
with socket() as s:
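
The loop body is cut off in this diff, so the following self-contained sketch of the pattern fills in the retry/raise tail as an assumption; the blocked ranges match the defaults added above (8265, 10001, and 52365/52366 are, plausibly, Ray dashboard, client, and dashboard-agent ports):

from socket import socket


def find_open_port_sketch(disallowed_ports=None, max_retries=100):
    blocked = set(range(53000, 53011)) | set(range(54000, 60001)) | {10001, 8265, 52365, 52366}
    blocked |= set(disallowed_ports or [])
    for _ in range(max_retries):
        with socket() as s:
            s.bind(("", 0))  # port 0: the kernel assigns a currently free port
            port = s.getsockname()[1]
        if port not in blocked:
            return port
    raise RuntimeError(f"No open port found after {max_retries} attempts")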