diff --git a/test/test_collectors.py b/test/test_collectors.py
index 1981ac5a552..7b3604a6615 100644
--- a/test/test_collectors.py
+++ b/test/test_collectors.py
@@ -616,6 +616,11 @@ def make_env():
         reason="Nested spawned multiprocessed is currently failing in python 3.11. "
         "See https://github.com/python/cpython/pull/108568 for info and fix.",
     )
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.8.0"),
+        reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
+        "when using spawn multiprocessing start method with file_system sharing strategy.",
+    )
     @pytest.mark.skipif(not _has_gym, reason="test designed with GymEnv")
     @pytest.mark.parametrize("static_seed", [True, False])
     def test_collector_vecnorm_envcreator(self, static_seed):
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 73029e9f2c4..63070c96220 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -9673,6 +9673,11 @@ def _test_vecnorm_subproc_auto(
     def rename_t(self):
         return RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")])
 
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.8.0"),
+        reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
+        "when using spawn multiprocessing start method.",
+    )
     @retry(AssertionError, tries=10, delay=0)
     @pytest.mark.parametrize("nprc", [2, 5])
     def test_vecnorm_parallel_auto(self, nprc):
@@ -9785,6 +9790,11 @@ def _run_parallelenv(parallel_env, queue_in, queue_out):
         reason="Nested spawned multiprocessed is currently failing in python 3.11. "
         "See https://github.com/python/cpython/pull/108568 for info and fix.",
     )
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.8.0"),
+        reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
+        "when using spawn multiprocessing start method.",
+    )
     def test_parallelenv_vecnorm(self):
         if _has_gym:
             make_env = EnvCreator(
@@ -10051,6 +10061,11 @@ def _test_vecnorm_subproc_auto(
     def rename_t(self):
         return RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")])
 
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.8.0"),
+        reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
+        "when using spawn multiprocessing start method.",
+    )
     @retry(AssertionError, tries=10, delay=0)
     @pytest.mark.parametrize("nprc", [2, 5])
     def test_vecnorm_parallel_auto(self, nprc):
@@ -10170,6 +10185,11 @@ def _run_parallelenv(parallel_env, queue_in, queue_out):
         reason="Nested spawned multiprocessed is currently failing in python 3.11. "
         "See https://github.com/python/cpython/pull/108568 for info and fix.",
     )
+    @pytest.mark.skipif(
+        TORCH_VERSION < version.parse("2.8.0"),
+        reason="VecNorm shared memory synchronization requires PyTorch >= 2.8 "
+        "when using spawn multiprocessing start method.",
+    )
     def test_parallelenv_vecnorm(self):
         if _has_gym:
             make_env = EnvCreator(
diff --git a/torchrl/__init__.py b/torchrl/__init__.py
index b7fc2aa74f6..a7372f9ad69 100644
--- a/torchrl/__init__.py
+++ b/torchrl/__init__.py
@@ -3,15 +3,27 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import os
+import warnings
 import weakref
 from warnings import warn
 
 import torch
 
-from tensordict import set_lazy_legacy
+# Silence noisy dependency warning triggered at import time on older torch stacks.
+# (Emitted by tensordict when registering pytree nodes.)
+warnings.filterwarnings(
+    "ignore",
+    category=UserWarning,
+    message=r"torch\.utils\._pytree\._register_pytree_node is deprecated\.",
+)
+
+from tensordict import set_lazy_legacy  # noqa: E402
 
-from torch import multiprocessing as mp
-from torch.distributions.transforms import _InverseTransform, ComposeTransform
+from torch import multiprocessing as mp  # noqa: E402
+from torch.distributions.transforms import (  # noqa: E402
+    _InverseTransform,
+    ComposeTransform,
+)
 
 torch._C._log_api_usage_once("torchrl")
@@ -61,8 +73,7 @@
 logger = logger
 
 
-# TorchRL's multiprocessing default:
-# We only force "spawn" on newer PyTorch versions (see `_get_default_mp_start_method`).
+# TorchRL's multiprocessing default.
 _preferred_start_method = _get_default_mp_start_method()
 if _preferred_start_method == "spawn":
     try:
diff --git a/torchrl/_utils.py b/torchrl/_utils.py
index fe2ee81c853..64e925330c6 100644
--- a/torchrl/_utils.py
+++ b/torchrl/_utils.py
@@ -36,7 +36,6 @@
 from torch._dynamo import is_compiling
 
 
-@implement_for("torch", "2.5.0")
 def _get_default_mp_start_method() -> str:
     """Returns TorchRL's preferred multiprocessing start method for this torch version.
 
@@ -46,20 +45,6 @@ def _get_default_mp_start_method() -> str:
     return "spawn"
 
 
-@implement_for("torch", None, "2.5.0")
-def _get_default_mp_start_method() -> str:  # noqa: F811
-    """Returns TorchRL's preferred multiprocessing start method for this torch version.
-
-    On older PyTorch versions we prefer ``"fork"`` when available to avoid failures
-    when spawning workers with non-CPU storages that must be pickled at process start.
-    """
-    try:
-        mp.get_context("fork")
-    except ValueError:
-        return "spawn"
-    return "fork"
-
-
 def _get_mp_ctx(start_method: str | None = None):
     """Return a multiprocessing context with TorchRL's preferred start method.
 
@@ -108,6 +93,19 @@ def _set_mp_start_method_if_unset(start_method: str | None = None) -> str | None:
     return current
 
 
+@implement_for("torch", None, "2.8")
+def _mp_sharing_strategy_for_spawn() -> str | None:
+    # On older torch stacks, pickling Process objects for "spawn" can end up
+    # passing file descriptors for shared storages; using "file_system" reduces
+    # FD passing and avoids spawn-time failures on some old Python versions.
+    return "file_system"
+
+
+@implement_for("torch", "2.8")
+def _mp_sharing_strategy_for_spawn() -> str | None:  # noqa: F811
+    return None
+
+
 def strtobool(val: Any) -> bool:
     """Convert a string representation of truth to a boolean.
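For context, the version-gated helper above is consumed by the collector change in the next file. Below is a minimal sketch (not part of the patch) of the intended call pattern, assuming `_mp_sharing_strategy_for_spawn` is importable from `torchrl._utils` as added above; the guard mirrors the one inserted into `_run_processes`, and `_maybe_apply_spawn_sharing_strategy` is a hypothetical wrapper used only for illustration.

import sys

import torch.multiprocessing as mp

from torchrl._utils import _mp_sharing_strategy_for_spawn


def _maybe_apply_spawn_sharing_strategy(start_method: str) -> None:
    # Hypothetical illustration: only old Linux/Python stacks that spawn
    # workers need the override; on torch >= 2.8 the helper returns None
    # and the default sharing strategy is left untouched.
    if (
        sys.platform == "linux"
        and sys.version_info < (3, 10)
        and start_method == "spawn"
    ):
        strategy = _mp_sharing_strategy_for_spawn()
        if strategy is not None:
            mp.set_sharing_strategy(strategy)

Gating the helper with `implement_for` keeps the `file_system` override entirely away from modern torch stacks, where the default strategy works with spawn.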
diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py
index ae803f194b5..dba4ab34269 100644
--- a/torchrl/collectors/_multi_base.py
+++ b/torchrl/collectors/_multi_base.py
@@ -4,6 +4,7 @@
 
 import abc
 import contextlib
+import sys
 import warnings
 from collections import OrderedDict
 from collections.abc import Callable, Mapping, Sequence
@@ -20,6 +21,7 @@
     _check_for_faulty_process,
     _get_mp_ctx,
     _make_process_no_warn_cls,
+    _mp_sharing_strategy_for_spawn,
     _set_mp_start_method_if_unset,
     RL_WARNINGS,
 )
@@ -33,7 +35,7 @@
 )
 from torchrl.collectors._runner import _main_async_collector
 from torchrl.collectors._single import Collector
-from torchrl.collectors.utils import _make_meta_policy, _TrajectoryPool
+from torchrl.collectors.utils import _make_meta_policy_cm, _TrajectoryPool
 from torchrl.collectors.weight_update import WeightUpdaterBase
 from torchrl.data import ReplayBuffer
 from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
@@ -945,7 +947,14 @@ def _run_processes(self) -> None:
         ctx = _get_mp_ctx()
         # Best-effort global init (only if unset) to keep other mp users consistent.
         _set_mp_start_method_if_unset(ctx.get_start_method())
-
+        if (
+            sys.platform == "linux"
+            and sys.version_info < (3, 10)
+            and ctx.get_start_method() == "spawn"
+        ):
+            strategy = _mp_sharing_strategy_for_spawn()
+            if strategy is not None:
+                mp.set_sharing_strategy(strategy)
         queue_out = ctx.Queue(self._queue_len)  # sends data from proc to main
         self.procs = []
         self._traj_pool = _TrajectoryPool(ctx=ctx, lock=True)
@@ -1004,9 +1013,11 @@ def _run_processes(self) -> None:
                 policy_to_send = None
                 cm = contextlib.nullcontext()
             elif policy is not None:
-                # Send policy with meta-device parameters (empty structure) - schemes apply weights
+                # Send a stateless policy down to workers: schemes apply weights.
                 policy_to_send = policy
-                cm = _make_meta_policy(policy)
+                cm = _make_meta_policy_cm(
+                    policy, mp_start_method=ctx.get_start_method()
+                )
             else:
                 policy_to_send = None
                 cm = contextlib.nullcontext()
@@ -1037,7 +1048,6 @@ def _run_processes(self) -> None:
             with cm:
                 kwargs = {
                     "policy_factory": policy_factory[i],
-                    "pipe_parent": pipe_parent,
                     "pipe_child": pipe_child,
                     "queue_out": queue_out,
                     "create_env_fn": env_fun,
@@ -1128,6 +1138,29 @@ def _run_processes(self) -> None:
                     ) from err
                 else:
                     raise err
+            except ValueError as err:
+                if "bad value(s) in fds_to_keep" in str(err):
+                    # This error occurs on old Python versions (e.g., 3.9) with old PyTorch (e.g., 2.3)
+                    # when using the spawn multiprocessing start method. The spawn implementation tries to
+                    # preserve file descriptors across exec, but some descriptors may be invalid/closed.
+                    # This is a compatibility issue with old Python multiprocessing implementations.
+                    python_version = (
+                        f"{sys.version_info.major}.{sys.version_info.minor}"
+                    )
+                    raise RuntimeError(
+                        f"Failed to start collector worker process due to file descriptor issues "
+                        f"with spawn multiprocessing on Python {python_version}.\n\n"
+                        f"This is a known compatibility issue with old Python/PyTorch stacks. "
+                        f"Consider upgrading to Python >= 3.10 and PyTorch >= 2.5, or use the 'fork' "
+                        f"multiprocessing start method on Unix systems.\n\n"
+                        f"Workarounds:\n"
+                        f"- Upgrade Python to >= 3.10 and PyTorch to >= 2.5\n"
+                        f"- On Unix systems, force fork start method:\n"
+                        f"  import torch.multiprocessing as mp\n"
+                        f"  if __name__ == '__main__':\n"
+                        f"      mp.set_start_method('fork', force=True)\n\n"
+                        f"Upstream Python issue: https://github.com/python/cpython/issues/87706"
+                    ) from err
             except _pickle.PicklingError as err:
                 if "<lambda>" in str(err):
                     raise RuntimeError(
diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py
index 0f6d1ec57c9..8b571f6ce6d 100644
--- a/torchrl/collectors/_runner.py
+++ b/torchrl/collectors/_runner.py
@@ -34,7 +34,6 @@
 
 
 def _main_async_collector(
-    pipe_parent: connection.Connection,
     pipe_child: connection.Connection,
     queue_out: queues.Queue,
     create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase],  # noqa: F821
@@ -68,7 +67,6 @@ def _main_async_collector(
 ) -> None:
     if collector_class is None:
         collector_class = Collector
-    pipe_parent.close()
 
     # init variables that will be cleared when closing
     collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None
diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py
index ef6aa60aad2..0b45a610889 100644
--- a/torchrl/collectors/utils.py
+++ b/torchrl/collectors/utils.py
@@ -298,6 +298,33 @@ def _make_meta_policy(policy: nn.Module):
     return param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy)
 
 
+@implement_for("torch", None, "2.8")
+def _make_meta_policy_cm(
+    policy: nn.Module, *, mp_start_method: str
+) -> contextlib.AbstractContextManager:
+    """Return the context manager used to make a policy 'stateless' for worker pickling.
+
+    On older PyTorch versions (<2.8), pickling meta-device storages when using the
+    ``spawn`` start method may fail (e.g., triggering ``_share_filename_: only available on CPU``).
+    In that case, we avoid converting parameters/buffers to meta and simply return a no-op
+    context manager.
+    """
+    if mp_start_method == "spawn":
+        return contextlib.nullcontext()
+    return _make_meta_policy(policy)
+
+
+@implement_for("torch", "2.8")
+def _make_meta_policy_cm(  # noqa: F811
+    policy: nn.Module, *, mp_start_method: str
+) -> contextlib.AbstractContextManager:
+    """Return the context manager used to make a policy 'stateless' for worker pickling.
+
+    On PyTorch >= 2.8, meta-device policy structures can be pickled reliably under ``spawn``.
+    """
+    return _make_meta_policy(policy)
+
+
 @implement_for("torch", None, "2.5.0")
 def _cast(  # noqa
     p: nn.Parameter | torch.Tensor,
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 3854336ad7d..92b5317939d 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -6872,6 +6872,23 @@ def __init__(
                 category=FutureWarning,
             )
 
+        # Warn about shared memory limitations on older PyTorch.
+        from packaging.version import parse as parse_version
+
+        if (
+            parse_version(parse_version(torch.__version__).base_version)
+            < parse_version("2.8.0")
+            and shared_td is not None
+        ):
+            warnings.warn(
+                "VecNorm with shared memory (shared_td) may not synchronize correctly "
+                "across processes on PyTorch < 2.8 when using the 'spawn' multiprocessing "
+                "start method. This is due to limitations in PyTorch's shared memory "
+                "implementation with the 'file_system' sharing strategy. "
+                "Consider upgrading to PyTorch >= 2.8 for full shared memory support.",
+                category=UserWarning,
+            )
+
         if lock is None:
             lock = mp.Lock()
         if in_keys is None: