Skip to content

Commit 6e7d70f

Browse files
authored
[bug] fix rank/local rank parsing for docker env (#1747)
1 parent 89f7c61 commit 6e7d70f

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

src/oumi/core/distributed.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,30 @@ def _get_use_orig_params(config: TrainingConfig) -> bool:
7474
#
7575
# Process Info
7676
#
77+
def _parse_rank(rank: Optional[str]) -> int:
78+
"""Parse the rank from the environment variable."""
79+
if not rank:
80+
return 0
81+
82+
# -1 is a special value that means "not set".
83+
# It's used by the Accelerate launcher.
84+
# Defaulting to 0.
85+
if rank.strip() == "-1":
86+
return 0
87+
88+
if not rank.isdigit():
89+
raise ValueError(f"Rank must be a number. Actual: {rank}.")
90+
91+
return int(rank)
92+
93+
7794
@functools.cache # same as @cache added in Python 3.9
7895
def get_device_rank_info() -> DeviceRankInfo:
7996
"""Returns device rank and world size."""
8097
world_size = int(os.environ.get("WORLD_SIZE", 1))
8198
if world_size <= 0:
8299
raise ValueError(f"WORLD_SIZE must be positive. Actual: {world_size}.")
83-
rank = int(os.environ.get("RANK", 0))
100+
rank = _parse_rank(os.environ.get("RANK"))
84101
if rank < 0 or rank >= world_size:
85102
raise ValueError(
86103
f"RANK must be within this range [0, {world_size}). Actual: {rank}."
@@ -94,7 +111,7 @@ def get_device_rank_info() -> DeviceRankInfo:
94111
# Per https://pytorch.org/docs/stable/elastic/run.html
95112
# NEVER hard code any assumptions about the stable-ness of ranks or
96113
# some correlation between RANK and LOCAL_RANK.
97-
local_rank = int(os.environ.get("LOCAL_RANK", 0))
114+
local_rank = _parse_rank(os.environ.get("LOCAL_RANK"))
98115
if local_rank < 0 or local_rank >= local_world_size:
99116
raise ValueError(
100117
f"LOCAL_RANK must be within this range [0, {local_world_size}). "

tests/unit/core/test_distributed.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from oumi.core.configs.params.training_params import TrainingParams
1616
from oumi.core.distributed import (
1717
DeviceRankInfo,
18+
_parse_rank,
1819
all_gather_object,
1920
estimate_dataloader_num_workers,
2021
get_accelerate_env_vars,
@@ -484,3 +485,33 @@ def test_prepare_accelerate_fsdp_run_override():
484485
"`EXISTING_VALUE`, overriding to new value `NO`."
485486
)
486487
assert env_vars == expected_env_vars
488+
489+
490+
@pytest.mark.parametrize(
    "rank_input,expected",
    [
        # Plain non-negative decimal strings parse to their integer value.
        ("1", 1),
        ("5", 5),
        ("42", 42),
        ("100", 100),
        ("0", 0),
        # The "-1" sentinel (with or without padding) defaults to 0.
        ("-1", 0),
        (" -1 ", 0),
        # An unset or empty variable defaults to 0.
        (None, 0),
        ("", 0),
    ],
)
def test_parse_rank(rank_input, expected):
    """Test that _parse_rank returns the expected integer for accepted inputs.

    Covers non-negative rank strings, the `-1` "not set" sentinel, and
    unset/empty values.
    """
    assert _parse_rank(rank_input) == expected
497+
498+
499+
def test_parse_rank_invalid_non_digit():
    """Check that strings containing non-digit characters raise ValueError."""
    # Each entry pairs an invalid input with the regex its error must match.
    invalid_cases = (
        ("abc", r"Rank must be a number\. Actual: abc\."),
        ("1a", r"Rank must be a number\. Actual: 1a\."),
        ("a1", r"Rank must be a number\. Actual: a1\."),
    )
    for bad_rank, message_pattern in invalid_cases:
        with pytest.raises(ValueError, match=message_pattern):
            _parse_rank(bad_rank)
509+
510+
511+
def test_parse_rank_invalid_negative():
    """Check that negative values other than -1 raise ValueError."""
    # Each entry pairs an invalid input with the regex its error must match.
    invalid_cases = (
        ("-2", r"Rank must be a number\. Actual: -2\."),
        ("-10", r"Rank must be a number\. Actual: -10\."),
    )
    for bad_rank, message_pattern in invalid_cases:
        with pytest.raises(ValueError, match=message_pattern):
            _parse_rank(bad_rank)

0 commit comments

Comments
 (0)