torchrl/collectors/_constants.py (4 changes: 4 additions & 0 deletions)

```diff
@@ -28,6 +28,7 @@ def cudagraph_mark_step_begin():
     "INSTANTIATE_TIMEOUT",
     "_MIN_TIMEOUT",
     "_MAX_IDLE_COUNT",
+    "WEIGHT_SYNC_TIMEOUT",
     "DEFAULT_EXPLORATION_TYPE",
     "_is_osx",
     "_Interruptor",
@@ -38,6 +39,9 @@ def cudagraph_mark_step_begin():
 _TIMEOUT = 1.0
 INSTANTIATE_TIMEOUT = 20
 _MIN_TIMEOUT = 1e-3  # should be several orders of magnitude smaller than the time spent collecting a trajectory
+# Timeout for weight synchronization during collector init.
+# Increase this when using many collectors across different CUDA devices.
+WEIGHT_SYNC_TIMEOUT = float(os.environ.get("TORCHRL_WEIGHT_SYNC_TIMEOUT", 120.0))
 # MAX_IDLE_COUNT is the maximum number of times a DataLoader worker can time out on its queue.
 _MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max))
```
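Because `WEIGHT_SYNC_TIMEOUT` is computed from the environment when `torchrl.collectors._constants` is first imported, an override must be in place before anything imports torchrl. A minimal usage sketch, assuming a fresh interpreter and an arbitrary 300-second value:

```python
import os

# Must happen before torchrl is imported: WEIGHT_SYNC_TIMEOUT is read from
# the environment at module-import time, not per call.
os.environ["TORCHRL_WEIGHT_SYNC_TIMEOUT"] = "300"

from torchrl.collectors._constants import WEIGHT_SYNC_TIMEOUT

print(WEIGHT_SYNC_TIMEOUT)  # 300.0
```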
torchrl/collectors/collectors.py (2 changes: 2 additions & 0 deletions)

```diff
@@ -18,6 +18,7 @@
     cudagraph_mark_step_begin,
     DEFAULT_EXPLORATION_TYPE,
     INSTANTIATE_TIMEOUT,
+    WEIGHT_SYNC_TIMEOUT,
 )
 
 from torchrl.collectors._multi_async import MultiAsyncCollector, MultiaSyncDataCollector
@@ -50,6 +51,7 @@
     # Constants
     "_TIMEOUT",
     "INSTANTIATE_TIMEOUT",
+    "WEIGHT_SYNC_TIMEOUT",
     "_MIN_TIMEOUT",
     "_MAX_IDLE_COUNT",
     "DEFAULT_EXPLORATION_TYPE",
```
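Since `collectors.py` both imports the constant and lists it in `__all__`, it is re-exported from that module as well. A quick sanity check, assuming both import paths stay public:

```python
# Both names should be bound to the same object: `from ... import ...`
# re-binds an existing object, it does not copy it.
from torchrl.collectors._constants import WEIGHT_SYNC_TIMEOUT as defined
from torchrl.collectors.collectors import WEIGHT_SYNC_TIMEOUT as reexported

assert defined is reexported
```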
torchrl/weight_update/_shared.py (3 changes: 2 additions & 1 deletion)

```diff
@@ -11,6 +11,7 @@
 from torch import multiprocessing as mp, nn
 
 from torchrl._utils import logger as torchrl_logger
+from torchrl.collectors._constants import WEIGHT_SYNC_TIMEOUT
 
 from torchrl.weight_update.utils import _resolve_model
 from torchrl.weight_update.weight_sync_schemes import (
@@ -97,7 +98,7 @@ def setup_connection_and_weights_on_receiver(
     weights: Any = None,
     model: Any = None,
     strategy: Any = None,
-    timeout: float = 10.0,
+    timeout: float = WEIGHT_SYNC_TIMEOUT,
 ) -> TensorDictBase:
     """Receive shared memory buffer reference from sender via their per-worker queues.
 
```
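One subtlety of replacing the literal `10.0` with a module-level constant as the default: Python evaluates default values once, at function-definition time, so setting `TORCHRL_WEIGHT_SYNC_TIMEOUT` after `_shared.py` has been imported does not change the default. A self-contained sketch of that behavior, using hypothetical names rather than the torchrl API:

```python
import os

# Same pattern as the change above: the default is frozen when the module
# defining the function is imported.
WEIGHT_SYNC_TIMEOUT = float(os.environ.get("TORCHRL_WEIGHT_SYNC_TIMEOUT", 120.0))

def receive_weights(timeout: float = WEIGHT_SYNC_TIMEOUT) -> float:
    """Stand-in for the receiver setup; returns the timeout it would use."""
    return timeout

os.environ["TORCHRL_WEIGHT_SYNC_TIMEOUT"] = "300"  # too late for the default
print(receive_weights())       # 120.0 (assuming the env var was unset at import)
print(receive_weights(300.0))  # an explicit per-call override still works
```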