Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion verl/checkpoint_engine/hccl_checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metada

def _start_zmq_server(self):
self.ip = ray.util.get_node_ip_address().strip("[]")
self.zmq_port, self.listen_sock = get_free_port(self.ip)
self.zmq_port, _ = get_free_port(self.ip)

context = zmq.Context()
self.socket = context.socket(zmq.PUB)
Expand Down
2 changes: 1 addition & 1 deletion verl/checkpoint_engine/nccl_checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metada

def _start_zmq_server(self):
self.ip = ray.util.get_node_ip_address().strip("[]")
self.listen_port, self.listen_sock = get_free_port(self.ip)
self.listen_port, _ = get_free_port(self.ip)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should keep the listening socket alive here so the port stays reserved until it is actually used.


context = zmq.Context()
self.socket = context.socket(zmq.PUB)
Expand Down
2 changes: 1 addition & 1 deletion verl/checkpoint_engine/nixl_checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def get_agent_metadata(self) -> NixlAgentMetadata:

def start_zmq_server(self):
self.ip = ray.util.get_node_ip_address().strip("[]")
self.listen_port, self.listen_sock = get_free_port(self.ip)
self.listen_port, _ = get_free_port(self.ip)

context = zmq.asyncio.Context()
self.socket = context.socket(zmq.PULL)
Expand Down
21 changes: 14 additions & 7 deletions verl/utils/net_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,22 @@ def is_valid_ipv6_address(address: str) -> bool:
return False


def get_free_port(address: str) -> tuple[int, socket.socket]:
family = socket.AF_INET
if is_valid_ipv6_address(address):
family = socket.AF_INET6
def get_free_port(address: str, with_alive_sock: bool = False) -> tuple[int, socket.socket | None]:
    """Find a free TCP port on the given address.

    By default the probe socket is closed internally and only the port
    number is returned, suitable for immediate use. Set
    ``with_alive_sock=True`` to keep the socket open as a port
    reservation, preventing concurrent calls from being handed the same
    port. The caller is then responsible for closing the socket before
    the port is actually bound by the target service (e.g. NCCL,
    uvicorn).

    Args:
        address: IPv4 or IPv6 address to bind on.
        with_alive_sock: if True, return the still-open bound socket as
            a reservation; if False (default), close it and return None.

    Returns:
        Tuple ``(port, sock)`` where ``sock`` is the live socket when
        ``with_alive_sock`` is True, otherwise ``None``.
    """
    family = socket.AF_INET6 if is_valid_ipv6_address(address) else socket.AF_INET

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # SO_REUSEPORT lets the target service bind this port while the
        # reservation socket is still open, but the constant is not
        # defined on every platform (e.g. Windows) — guard it.
        if hasattr(socket, "SO_REUSEPORT"):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind((address, 0))
        port = sock.getsockname()[1]
    except Exception:
        # Don't leak the file descriptor if setsockopt/bind fails.
        sock.close()
        raise
    if with_alive_sock:
        return port, sock
    sock.close()
    return port, None
45 changes: 28 additions & 17 deletions verl/workers/rollout/sglang_rollout/async_sglang_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,19 @@ def __init__(
profiler_config = None
self.profiler_controller = DistProfiler(self.replica_rank, config=profiler_config, tool_config=tool_config)

# used for NCCL process group
if self.node_rank == 0:
# For multi-node, we need dist_init_addr so nodes can coordinate NCCL init.
# For single-node, let SGLang handle port selection internally via nccl_port,
# which also avoids port conflicts.
self._master_address = None
self._master_port = None
self._master_sock = None
if self.nnodes > 1 and self.node_rank == 0:
self._master_address = self._server_address
self._master_port, self._master_sock = get_free_port(self._server_address)
self._master_port, self._master_sock = get_free_port(self._server_address, with_alive_sock=True)
logger.info(
f"SGLangHttpServer, replica_rank: {self.replica_rank}, "
f"master address: {self._master_address}, port: {self._master_port}"
)
else:
self._master_address = None
self._master_port = None

def get_master_address(self):
"""Get master address and port for init NCCL process group."""
Expand All @@ -145,10 +147,13 @@ def get_server_address(self):
return self._server_address, self._server_port

async def launch_server(self, master_address: str = None, master_port: int = None):
if self.node_rank != 0:
assert master_address and master_port, "non-master node should provide master address and port"
self._master_address = master_address
self._master_port = master_port
if self.nnodes > 1:
if self.node_rank != 0:
assert master_address and master_port, "non-master node should provide master address and port"
self._master_address = master_address
self._master_port = master_port
else:
self._master_sock.close()

engine_kwargs = self.config.get("engine_kwargs", {}).get("sglang", {}) or {}
attention_backend = engine_kwargs.pop("attention_backend", None)
Expand All @@ -167,11 +172,6 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS)
else:
raise ValueError(f"Currently only support fp8 quantization, got: {quantization}")
dist_init_addr = (
f"[{self._master_address}]:{self._master_port}"
if is_valid_ipv6_address(self._master_address)
else f"{self._master_address}:{self._master_port}"
)
infer_tp = self.config.tensor_model_parallel_size * self.config.data_parallel_size
args = {
"model_path": self.model_config.local_path,
Expand All @@ -186,7 +186,6 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
"ep_size": self.config.expert_parallel_size,
"node_rank": self.node_rank,
"load_format": self.config.load_format,
"dist_init_addr": dist_init_addr,
"nnodes": self.nnodes,
"trust_remote_code": self.model_config.trust_remote_code,
"max_running_requests": self.config.get("max_num_seqs", None),
Expand All @@ -202,6 +201,16 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
**engine_kwargs,
}

# Only set dist_init_addr for multi-node; for single-node, let SGLang
# handle port selection internally via nccl_port to avoid conflicts.
if self.nnodes > 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we unify the single-node and multi-node paths by always choosing a free port for dist_init_addr?

dist_init_addr = (
f"[{self._master_address}]:{self._master_port}"
if is_valid_ipv6_address(self._master_address)
else f"{self._master_address}:{self._master_port}"
)
args["dist_init_addr"] = dist_init_addr

if self.config.prometheus.enable:
if self.config.prometheus.served_model_name:
# Extract model name from path if it's a full path
Expand Down Expand Up @@ -510,7 +519,9 @@ async def launch_servers(self):
self.servers.append(server)

# launch http server in each node
master_address, master_port = await self.servers[0].get_master_address.remote()
master_address, master_port = None, None
if self.nnodes > 1:
master_address, master_port = await self.servers[0].get_master_address.remote()
await asyncio.gather(
*[
server.launch_server.remote(master_address=master_address, master_port=master_port)
Expand Down
2 changes: 1 addition & 1 deletion verl/workers/rollout/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def run_unvicorn(app: FastAPI, server_args, server_address, max_retries=5)

for i in range(max_retries):
try:
server_port, sock = get_free_port(server_address)
server_port, _ = get_free_port(server_address)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should keep the listening socket alive here so the port stays reserved until uvicorn binds it.

app.server_args = server_args
config = uvicorn.Config(app, host=server_address, port=server_port, log_level="warning")
server = uvicorn.Server(config)
Expand Down
8 changes: 5 additions & 3 deletions verl/workers/rollout/vllm_rollout/vllm_async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,10 @@ def __init__(
if self.node_rank == 0:
self._master_address = self._server_address
# used for torch.distributed.init_process_group
self._master_port, self._master_sock = get_free_port(self._server_address)
self._master_port, self._master_sock = get_free_port(self._server_address, with_alive_sock=True)
# used for data parallel: --data-parallel-address, --data-parallel-rpc-port
self._dp_rpc_port, self._dp_rpc_sock = get_free_port(self._server_address)
self._dp_master_port, self._dp_master_sock = get_free_port(self._server_address)
self._dp_rpc_port, self._dp_rpc_sock = get_free_port(self._server_address, with_alive_sock=True)
self._dp_master_port, self._dp_master_sock = get_free_port(self._server_address, with_alive_sock=True)
else:
self._master_address = None
self._master_port = None
Expand Down Expand Up @@ -410,6 +410,8 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
# 3. launch server
if self.node_rank == 0:
self._master_sock.close()
self._dp_rpc_sock.close()
self._dp_master_sock.close()
await self.run_server(server_args)
else:
# TODO: avoid connect before master_sock close
Expand Down
Loading