Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions verl/checkpoint_engine/hccl_checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
self.topic = "bucket_metadata"
if self.is_master:
self._start_zmq_server()
self.dist_port, _ = get_free_port(self.ip)
self.dist_port = get_free_port(self.ip)

def prepare(self) -> MasterMetadata:
self.send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device="npu")
Expand Down Expand Up @@ -165,7 +165,7 @@ def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metada

def _start_zmq_server(self):
self.ip = ray.util.get_node_ip_address().strip("[]")
self.zmq_port, self.listen_sock = get_free_port(self.ip)
self.zmq_port = get_free_port(self.ip)

context = zmq.Context()
self.socket = context.socket(zmq.PUB)
Expand Down
2 changes: 1 addition & 1 deletion verl/checkpoint_engine/nccl_checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metada

def _start_zmq_server(self):
self.ip = ray.util.get_node_ip_address().strip("[]")
self.listen_port, self.listen_sock = get_free_port(self.ip)
self.listen_port = get_free_port(self.ip)

context = zmq.Context()
self.socket = context.socket(zmq.PUB)
Expand Down
2 changes: 1 addition & 1 deletion verl/checkpoint_engine/nixl_checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def get_agent_metadata(self) -> NixlAgentMetadata:

def start_zmq_server(self):
self.ip = ray.util.get_node_ip_address().strip("[]")
self.listen_port, self.listen_sock = get_free_port(self.ip)
self.listen_port = get_free_port(self.ip)

context = zmq.asyncio.Context()
self.socket = context.socket(zmq.PULL)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def launch_router_process(
timeout: int = 30,
) -> str:
router_ip = ray.util.get_node_ip_address().strip("[]")
router_port, _ = get_free_port(router_ip)
router_port = get_free_port(router_ip)
router_address = (
f"[{router_ip}]:{router_port}" if is_valid_ipv6_address(router_ip) else f"{router_ip}:{router_port}"
)
Expand Down
2 changes: 1 addition & 1 deletion verl/experimental/reward_loop/router/naive_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def launch_router_process(
worker_urls: list[str],
):
router_ip = ray.util.get_node_ip_address().strip("[]")
router_port, _ = get_free_port(router_ip)
router_port = get_free_port(router_ip)
router_address = (
f"[{router_ip}]:{router_port}" if is_valid_ipv6_address(router_ip) else f"{router_ip}:{router_port}"
)
Expand Down
32 changes: 22 additions & 10 deletions verl/utils/net_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ipaddress
import random
import socket


Expand Down Expand Up @@ -70,15 +71,26 @@ def is_valid_ipv6_address(address: str) -> bool:
return False


def get_free_port(address: str) -> tuple[int, socket.socket]:
family = socket.AF_INET
if is_valid_ipv6_address(address):
family = socket.AF_INET6
def get_free_port(address: str, random_seed: int | None = None) -> int:
family = socket.AF_INET6 if is_valid_ipv6_address(address) else socket.AF_INET

sock = socket.socket(family=family, type=socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.bind((address, 0))
# When a seed is provided, use it to deterministically pick ports from a wide range.
# Different seeds yield different port sequences,
# reducing conflicts when multiple servers launch concurrently.
if random_seed is not None:
rng = random.Random(random_seed)
for _ in range(10):
port = rng.randint(20000, 60000)
with socket.socket(family=family, type=socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
sock.bind((address, port))
return sock.getsockname()[1]
except OSError:
continue

port = sock.getsockname()[1]
return port, sock
# Fallback: let OS choose
with socket.socket(family=family, type=socket.SOCK_STREAM) as sock:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The open sock will be closed when return, then other process may choose the same port.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, and originally the reuse port flag was turned on for that. Now different async server proc will random choose from completely different port sequences to avoid collision. But still, Looks like still one ci remaining https://github.com/verl-project/verl/actions/runs/21977053742/job/63490946387?pr=5310

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only ensures multiple processes random choose port using this get_free_port, while there're other places which don't use this function, they may choose port conflict with port choose by this function.

For example, ray worker process random choose port for grpc. So we still need to bind the free port.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good. changed.

sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((address, 0))
return sock.getsockname()[1]
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def __init__(
# used for NCCL process group
if self.node_rank == 0:
self._master_address = self._server_address
self._master_port, self._master_sock = get_free_port(self._server_address)
# Seed with replica_rank + pid to avoid port conflicts across replicas and restarts
self._master_port = get_free_port(self._server_address, random_seed=os.getpid())
logger.info(
f"SGLangHttpServer, replica_rank: {self.replica_rank}, "
f"master address: {self._master_address}, port: {self._master_port}"
Expand Down
2 changes: 1 addition & 1 deletion verl/workers/rollout/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def run_unvicorn(app: FastAPI, server_args, server_address, max_retries=5)

for i in range(max_retries):
try:
server_port, sock = get_free_port(server_address)
server_port = get_free_port(server_address)
app.server_args = server_args
config = uvicorn.Config(app, host=server_address, port=server_port, log_level="warning")
server = uvicorn.Server(config)
Expand Down
8 changes: 4 additions & 4 deletions verl/workers/rollout/vllm_rollout/vllm_async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,12 @@ def __init__(
# used for data parallel: --data-parallel-address, --data-parallel-rpc-port
if self.node_rank == 0:
self._master_address = self._server_address
random_seed = os.getpid()
# used for torch.distributed.init_process_group
self._master_port, self._master_sock = get_free_port(self._server_address)
self._master_port = get_free_port(self._server_address, random_seed=random_seed)
# used for data parallel: --data-parallel-address, --data-parallel-rpc-port
self._dp_rpc_port, self._dp_rpc_sock = get_free_port(self._server_address)
self._dp_master_port, self._dp_master_sock = get_free_port(self._server_address)
self._dp_rpc_port = get_free_port(self._server_address, random_seed=random_seed + 1)
self._dp_master_port = get_free_port(self._server_address, random_seed=random_seed + 2)
else:
self._master_address = None
self._master_port = None
Expand Down Expand Up @@ -409,7 +410,6 @@ async def launch_server(self, master_address: str = None, master_port: int = Non

# 3. launch server
if self.node_rank == 0:
self._master_sock.close()
await self.run_server(server_args)
else:
# TODO: avoid connect before master_sock close
Expand Down
Loading