diff --git a/torchrl/collectors/_constants.py b/torchrl/collectors/_constants.py
index 1587d800166..1b632cfc1c8 100644
--- a/torchrl/collectors/_constants.py
+++ b/torchrl/collectors/_constants.py
@@ -28,6 +28,7 @@ def cudagraph_mark_step_begin():
     "INSTANTIATE_TIMEOUT",
     "_MIN_TIMEOUT",
     "_MAX_IDLE_COUNT",
+    "WEIGHT_SYNC_TIMEOUT",
     "DEFAULT_EXPLORATION_TYPE",
     "_is_osx",
     "_Interruptor",
@@ -38,6 +39,9 @@ def cudagraph_mark_step_begin():
 _TIMEOUT = 1.0
 INSTANTIATE_TIMEOUT = 20
 _MIN_TIMEOUT = 1e-3  # should be several orders of magnitude inferior wrt time spent collecting a trajectory
+# Timeout for weight synchronization during collector init.
+# Increase this when using many collectors across different CUDA devices.
+WEIGHT_SYNC_TIMEOUT = float(os.environ.get("TORCHRL_WEIGHT_SYNC_TIMEOUT", 120.0))
 # MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue.
 _MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max))
 
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 864d1cbd4f0..25cd9fdf9da 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -18,6 +18,7 @@
     cudagraph_mark_step_begin,
     DEFAULT_EXPLORATION_TYPE,
     INSTANTIATE_TIMEOUT,
+    WEIGHT_SYNC_TIMEOUT,
 )
 
 from torchrl.collectors._multi_async import MultiAsyncCollector, MultiaSyncDataCollector
@@ -50,6 +51,7 @@
     # Constants
     "_TIMEOUT",
     "INSTANTIATE_TIMEOUT",
+    "WEIGHT_SYNC_TIMEOUT",
     "_MIN_TIMEOUT",
     "_MAX_IDLE_COUNT",
     "DEFAULT_EXPLORATION_TYPE",
diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py
index 0b0fe54875f..cb575ae2e80 100644
--- a/torchrl/weight_update/_shared.py
+++ b/torchrl/weight_update/_shared.py
@@ -11,6 +11,7 @@
 from torch import multiprocessing as mp, nn
 
 from torchrl._utils import logger as torchrl_logger
+from torchrl.collectors._constants import WEIGHT_SYNC_TIMEOUT
 from torchrl.weight_update.utils import _resolve_model
 from torchrl.weight_update.weight_sync_schemes import (
@@ -97,7 +98,7 @@ def setup_connection_and_weights_on_receiver(
         weights: Any = None,
         model: Any = None,
         strategy: Any = None,
-        timeout: float = 10.0,
+        timeout: float = WEIGHT_SYNC_TIMEOUT,
     ) -> TensorDictBase:
         """Receive shared memory buffer reference from sender via their per-worker queues.
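
Usage note (a minimal sketch, not part of the patch): the new constant reads the
TORCHRL_WEIGHT_SYNC_TIMEOUT environment variable once, at module import time, so
an override must be set before torchrl is first imported. Importing from
torchrl.collectors._constants is grounded in the diff above; the patch also
re-exports the name via torchrl/collectors/collectors.py.

    import os

    # Must happen before torchrl is imported: the constant is evaluated once,
    # at module load, and defaults to 120.0 seconds otherwise.
    os.environ["TORCHRL_WEIGHT_SYNC_TIMEOUT"] = "300"

    from torchrl.collectors._constants import WEIGHT_SYNC_TIMEOUT

    assert WEIGHT_SYNC_TIMEOUT == 300.0  # the override was picked up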