5 changes: 5 additions & 0 deletions test/test_collectors.py
@@ -616,6 +616,11 @@ def make_env():
reason="Nested spawned multiprocessed is currently failing in python 3.11. "
"See https://github.com/python/cpython/pull/108568 for info and fix.",
)
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.8.0"),
reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
"when using spawn multiprocessing start method with file_system sharing strategy.",
)
@pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv")
@pytest.mark.parametrize("static_seed", [True, False])
def test_collector_vecnorm_envcreator(self, static_seed):
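A minimal sketch of the version gate used in these tests, assuming TORCH_VERSION is the parsed torch version defined elsewhere in the test module (that definition is not part of this diff):

import pytest
import torch
from packaging import version

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

requires_torch_2_8 = pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.8.0"),
    reason="VecNorm shared memory synchronization requires PyTorch >= 2.8",
)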
20 changes: 20 additions & 0 deletions test/test_transforms.py
@@ -9673,6 +9673,11 @@ def _test_vecnorm_subproc_auto(
def rename_t(self):
return RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")])

@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.8.0"),
reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
"when using spawn multiprocessing start method.",
)
@retry(AssertionError, tries=10, delay=0)
@pytest.mark.parametrize("nprc", [2, 5])
def test_vecnorm_parallel_auto(self, nprc):
@@ -9785,6 +9790,11 @@ def _run_parallelenv(parallel_env, queue_in, queue_out):
reason="Nested spawned multiprocessed is currently failing in python 3.11. "
"See https://github.com/python/cpython/pull/108568 for info and fix.",
)
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.8.0"),
reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
"when using spawn multiprocessing start method.",
)
def test_parallelenv_vecnorm(self):
if _has_gym:
make_env = EnvCreator(
@@ -10051,6 +10061,11 @@ def _test_vecnorm_subproc_auto(
def rename_t(self):
return RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")])

@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.8.0"),
reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
"when using spawn multiprocessing start method.",
)
@retry(AssertionError, tries=10, delay=0)
@pytest.mark.parametrize("nprc", [2, 5])
def test_vecnorm_parallel_auto(self, nprc):
@@ -10170,6 +10185,11 @@ def _run_parallelenv(parallel_env, queue_in, queue_out):
reason="Nested spawned multiprocessed is currently failing in python 3.11. "
"See https://github.com/python/cpython/pull/108568 for info and fix.",
)
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.8.0"),
reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
"when using spawn multiprocessing start method.",
)
def test_parallelenv_vecnorm(self):
if _has_gym:
make_env = EnvCreator(
21 changes: 16 additions & 5 deletions torchrl/__init__.py
@@ -3,15 +3,27 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import warnings
import weakref
from warnings import warn

import torch

from tensordict import set_lazy_legacy
# Silence noisy dependency warning triggered at import time on older torch stacks.
# (Emitted by tensordict when registering pytree nodes.)
warnings.filterwarnings(
"ignore",
category=UserWarning,
message=r"torch\.utils\._pytree\._register_pytree_node is deprecated\.",
)

from tensordict import set_lazy_legacy # noqa: E402

from torch import multiprocessing as mp
from torch.distributions.transforms import _InverseTransform, ComposeTransform
from torch import multiprocessing as mp # noqa: E402
from torch.distributions.transforms import ( # noqa: E402
_InverseTransform,
ComposeTransform,
)

torch._C._log_api_usage_once("torchrl")

@@ -61,8 +73,7 @@

logger = logger

# TorchRL's multiprocessing default:
# We only force "spawn" on newer PyTorch versions (see `_get_default_mp_start_method`).
# TorchRL's multiprocessing default.
_preferred_start_method = _get_default_mp_start_method()
if _preferred_start_method == "spawn":
try:
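A quick way to see which start method TorchRL settled on after import; this is a usage sketch, not part of the diff:

import torch.multiprocessing as mp
import torchrl  # noqa: F401

print(mp.get_start_method(allow_none=True))  # expected to be "spawn" given the default above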
28 changes: 13 additions & 15 deletions torchrl/_utils.py
@@ -36,7 +36,6 @@
from torch._dynamo import is_compiling


@implement_for("torch", "2.5.0")
def _get_default_mp_start_method() -> str:
"""Returns TorchRL's preferred multiprocessing start method for this torch version.

@@ -46,20 +45,6 @@ def _get_default_mp_start_method() -> str:
return "spawn"


@implement_for("torch", None, "2.5.0")
def _get_default_mp_start_method() -> str: # noqa: F811
"""Returns TorchRL's preferred multiprocessing start method for this torch version.

On older PyTorch versions we prefer ``"fork"`` when available to avoid failures
when spawning workers with non-CPU storages that must be pickled at process start.
"""
try:
mp.get_context("fork")
except ValueError:
return "spawn"
return "fork"


def _get_mp_ctx(start_method: str | None = None):
"""Return a multiprocessing context with TorchRL's preferred start method.

@@ -108,6 +93,19 @@ def _set_mp_start_method_if_unset(start_method: str | None = None) -> str | None
return current


@implement_for("torch", None, "2.8")
def _mp_sharing_strategy_for_spawn() -> str | None:
# On older torch stacks, pickling Process objects for "spawn" can end up
# passing file descriptors for shared storages; using "file_system" reduces
# FD passing and avoids spawn-time failures on some old Python versions.
return "file_system"


@implement_for("torch", "2.8")
def _mp_sharing_strategy_for_spawn() -> str | None: # noqa: F811
return None


def strtobool(val: Any) -> bool:
"""Convert a string representation of truth to a boolean.

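How the new helper is meant to be consumed (it mirrors the call site added to _run_processes in torchrl/collectors/_multi_base.py further down this diff); a sketch using the private helpers introduced here:

from torch import multiprocessing as mp
from torchrl._utils import _mp_sharing_strategy_for_spawn

strategy = _mp_sharing_strategy_for_spawn()
if strategy is not None:
    # Only torch < 2.8 returns "file_system"; newer stacks return None and
    # leave the default sharing strategy untouched.
    mp.set_sharing_strategy(strategy)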
43 changes: 38 additions & 5 deletions torchrl/collectors/_multi_base.py
@@ -4,6 +4,7 @@
import abc

import contextlib
import sys
import warnings
from collections import OrderedDict
from collections.abc import Callable, Mapping, Sequence
@@ -20,6 +21,7 @@
_check_for_faulty_process,
_get_mp_ctx,
_make_process_no_warn_cls,
_mp_sharing_strategy_for_spawn,
_set_mp_start_method_if_unset,
RL_WARNINGS,
)
@@ -33,7 +35,7 @@
)
from torchrl.collectors._runner import _main_async_collector
from torchrl.collectors._single import Collector
from torchrl.collectors.utils import _make_meta_policy, _TrajectoryPool
from torchrl.collectors.utils import _make_meta_policy_cm, _TrajectoryPool
from torchrl.collectors.weight_update import WeightUpdaterBase
from torchrl.data import ReplayBuffer
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
@@ -945,7 +947,14 @@ def _run_processes(self) -> None:
ctx = _get_mp_ctx()
# Best-effort global init (only if unset) to keep other mp users consistent.
_set_mp_start_method_if_unset(ctx.get_start_method())

if (
sys.platform == "linux"
and sys.version_info < (3, 10)
and ctx.get_start_method() == "spawn"
):
strategy = _mp_sharing_strategy_for_spawn()
if strategy is not None:
mp.set_sharing_strategy(strategy)
queue_out = ctx.Queue(self._queue_len) # sends data from proc to main
self.procs = []
self._traj_pool = _TrajectoryPool(ctx=ctx, lock=True)
@@ -1004,9 +1013,11 @@ def _run_processes(self) -> None:
policy_to_send = None
cm = contextlib.nullcontext()
elif policy is not None:
# Send policy with meta-device parameters (empty structure) - schemes apply weights
# Send a stateless policy down to workers: schemes apply weights.
policy_to_send = policy
cm = _make_meta_policy(policy)
cm = _make_meta_policy_cm(
policy, mp_start_method=ctx.get_start_method()
)
else:
policy_to_send = None
cm = contextlib.nullcontext()
@@ -1037,7 +1048,6 @@ def _run_processes(self) -> None:
with cm:
kwargs = {
"policy_factory": policy_factory[i],
"pipe_parent": pipe_parent,
"pipe_child": pipe_child,
"queue_out": queue_out,
"create_env_fn": env_fun,
@@ -1128,6 +1138,29 @@ def _run_processes(self) -> None:
) from err
else:
raise err
except ValueError as err:
if "bad value(s) in fds_to_keep" in str(err):
# This error occurs on old Python versions (e.g., 3.9) with old PyTorch (e.g., 2.3)
# when using the spawn multiprocessing start method. The spawn implementation tries to
# preserve file descriptors across exec, but some descriptors may be invalid/closed.
# This is a compatibility issue with old Python multiprocessing implementations.
python_version = (
f"{sys.version_info.major}.{sys.version_info.minor}"
)
raise RuntimeError(
f"Failed to start collector worker process due to file descriptor issues "
f"with spawn multiprocessing on Python {python_version}.\n\n"
f"This is a known compatibility issue with old Python/PyTorch stacks. "
f"Consider upgrading to Python >= 3.10 and PyTorch >= 2.5, or use the 'fork' "
f"multiprocessing start method on Unix systems.\n\n"
f"Workarounds:\n"
f"- Upgrade Python to >= 3.10 and PyTorch to >= 2.5\n"
f"- On Unix systems, force fork start method:\n"
f" import torch.multiprocessing as mp\n"
f" if __name__ == '__main__':\n"
f" mp.set_start_method('fork', force=True)\n\n"
f"Upstream Python issue: https://github.com/python/cpython/issues/87706"
) from err
except _pickle.PicklingError as err:
if "<lambda>" in str(err):
raise RuntimeError(
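The workaround quoted in the error message above, written out as a standalone sketch (Unix only, since the "fork" start method is unavailable on Windows):

import torch.multiprocessing as mp

if __name__ == "__main__":
    mp.set_start_method("fork", force=True)
    # build the multiprocess collector afterwards so workers are created
    # with the fork context instead of spawn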
2 changes: 0 additions & 2 deletions torchrl/collectors/_runner.py
@@ -34,7 +34,6 @@


def _main_async_collector(
pipe_parent: connection.Connection,
pipe_child: connection.Connection,
queue_out: queues.Queue,
create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], # noqa: F821
@@ -68,7 +67,6 @@ def _main_async_collector(
) -> None:
if collector_class is None:
collector_class = Collector
pipe_parent.close()
# init variables that will be cleared when closing
collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None

27 changes: 27 additions & 0 deletions torchrl/collectors/utils.py
@@ -298,6 +298,33 @@ def _make_meta_policy(policy: nn.Module):
return param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy)


@implement_for("torch", None, "2.8")
def _make_meta_policy_cm(
policy: nn.Module, *, mp_start_method: str
) -> contextlib.AbstractContextManager:
"""Return the context manager used to make a policy 'stateless' for worker pickling.

On older PyTorch versions (<2.8), pickling meta-device storages when using the
``spawn`` start method may fail (e.g., triggering ``_share_filename_: only available on CPU``).
In that case, we avoid converting parameters/buffers to meta and simply return a no-op
context manager.
"""
if mp_start_method == "spawn":
return contextlib.nullcontext()
return _make_meta_policy(policy)


@implement_for("torch", "2.8")
def _make_meta_policy_cm( # noqa: F811
policy: nn.Module, *, mp_start_method: str
) -> contextlib.AbstractContextManager:
"""Return the context manager used to make a policy 'stateless' for worker pickling.

On PyTorch >= 2.8, meta-device policy structures can be pickled reliably under ``spawn``.
"""
return _make_meta_policy(policy)


@implement_for("torch", None, "2.5.0")
def _cast( # noqa
p: nn.Parameter | torch.Tensor,
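A usage sketch mirroring the call site in _run_processes (torchrl/collectors/_multi_base.py, earlier in this diff); both helpers are private and the import paths follow this diff:

import torch.nn as nn
from torchrl._utils import _get_mp_ctx
from torchrl.collectors.utils import _make_meta_policy_cm

policy = nn.Linear(4, 2)  # stand-in for the collector's policy module
ctx = _get_mp_ctx()
cm = _make_meta_policy_cm(policy, mp_start_method=ctx.get_start_method())
with cm:
    # inside this block the policy's parameters/buffers live on "meta",
    # except under spawn on torch < 2.8 where cm is a no-op nullcontext
    pass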
16 changes: 16 additions & 0 deletions torchrl/envs/transforms/transforms.py
@@ -6872,6 +6872,22 @@ def __init__(
category=FutureWarning,
)

# Warn about shared memory limitations on older PyTorch
from packaging.version import parse as parse_version

if (
parse_version(parse_version(torch.__version__).base_version) < parse_version("2.8.0")
and shared_td is not None
):
warnings.warn(
"VecNorm with shared memory (shared_td) may not synchronize correctly "
"across processes on PyTorch < 2.8 when using the 'spawn' multiprocessing "
"start method. This is due to limitations in PyTorch's shared memory "
"implementation with the 'file_system' sharing strategy. "
"Consider upgrading to PyTorch >= 2.8 for full shared memory support.",
category=UserWarning,
)

if lock is None:
lock = mp.Lock()
if in_keys is None:
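A hedged sketch of the version check used above: parsing both sides avoids a lexicographic string comparison in which, for example, "2.10.0" would sort below "2.8.0":

import torch
from packaging.version import parse as parse_version

torch_lt_2_8 = parse_version(
    parse_version(torch.__version__).base_version
) < parse_version("2.8.0")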