diff --git a/README.md b/README.md index d713eb77..07f5a1be 100644 --- a/README.md +++ b/README.md @@ -90,3 +90,4 @@ HF_HUB_ETAG_TIMEOUT=500 | `GLOBAL_PORT` | Port number of the global store. | `None` | | `GLOBAL_WORLD_SIZE` | The size of the global process group. | `1` | | `GLOBAL_RANK` | Rank of the process in the global process group. | `0` | +| `ZERO_BAND_GLOBAL_STORE_TIMEOUT_SECONDS` | Number of seconds before global store operations time out. | `300` | diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py index 6a7fdded..034f6fde 100644 --- a/src/zeroband/comms.py +++ b/src/zeroband/comms.py @@ -1,3 +1,4 @@ +import os from torch.distributed.device_mesh import init_device_mesh from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger @@ -8,7 +9,7 @@ from torch.testing._internal.distributed.fake_pg import FakeProcessGroup -TCPSTORE_TIMEOUT = timedelta(seconds=10) +TCPSTORE_TIMEOUT = timedelta(seconds=int(os.getenv("ZERO_BAND_GLOBAL_STORE_TIMEOUT_SECONDS", "300"))) MAX_JOINERS = 100 # Maximum number of nodes that can join in a single reinit MAX_LEAVERS = 100 # Maximum number of nodes that can leave in a single reinit