diff --git a/test/conftest.py b/test/conftest.py
index 62f8010e74d..86bf469fa63 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -74,6 +74,11 @@ def set_warnings() -> None:
         category=UserWarning,
         message=r"Skipping device Apple Paravirtual device",
     )
+    warnings.filterwarnings(
+        "ignore",
+        category=UserWarning,
+        message=r"A lambda function was passed to ParallelEnv",
+    )
     warnings.filterwarnings(
         "ignore",
         category=DeprecationWarning,
diff --git a/test/test_envs.py b/test/test_envs.py
index 704b389e133..2d3cf4c76ae 100644
--- a/test/test_envs.py
+++ b/test/test_envs.py
@@ -9,7 +9,7 @@
 import functools
 import gc
 import importlib
-import os.path
+import os
 import pickle
 import random
 import re
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 8c45affb0c2..8283022b151 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -37,6 +37,7 @@
     _check_for_faulty_process,
     _make_ordinal_device,
     logger as torchrl_logger,
+    rl_warnings,
     VERBOSE,
 )
 from torchrl.data.tensor_specs import Composite, NonTensor
@@ -152,18 +153,38 @@ def __call__(cls, *args, **kwargs):
         # Wrap lambda functions with EnvCreator so they can be pickled for
         # multiprocessing with the spawn start method. Lambda functions cannot
         # be serialized with standard pickle, but EnvCreator uses cloudpickle.
+        auto_wrap_envs = kwargs.pop("auto_wrap_envs", True)
+
+        def _warn_lambda():
+            if rl_warnings():
+                warnings.warn(
+                    "A lambda function was passed to ParallelEnv and will be wrapped "
+                    "in an EnvCreator. This causes the environment to be instantiated "
+                    "in the main process to extract metadata. Consider using "
+                    "functools.partial instead, which is natively serializable and "
+                    "avoids this overhead. To suppress this warning, set the "
+                    "RL_WARNINGS=0 environment variable.",
+                    category=UserWarning,
+                    stacklevel=4,
+                )
+
         def _wrap_lambdas(create_env_fn):
             if callable(create_env_fn) and _is_unpicklable_lambda(create_env_fn):
+                _warn_lambda()
                 return EnvCreator(create_env_fn)
             if isinstance(create_env_fn, Sequence):
                 # Reuse EnvCreator for identical function objects to preserve
                 # _single_task detection (e.g., when [lambda_fn] * 3 is passed)
                 wrapped = {}
                 result = []
+                warned = False
                 for fn in create_env_fn:
                     if _is_unpicklable_lambda(fn):
                         fn_id = id(fn)
                         if fn_id not in wrapped:
+                            if not warned:
+                                _warn_lambda()
+                                warned = True
                             wrapped[fn_id] = EnvCreator(fn)
                         result.append(wrapped[fn_id])
                     else:
@@ -171,10 +192,11 @@ def _wrap_lambdas(create_env_fn):
                 return result
             return create_env_fn
 
-        if "create_env_fn" in kwargs:
-            kwargs["create_env_fn"] = _wrap_lambdas(kwargs["create_env_fn"])
-        elif len(args) >= 2:
-            args = (args[0], _wrap_lambdas(args[1])) + args[2:]
+        if auto_wrap_envs:
+            if "create_env_fn" in kwargs:
+                kwargs["create_env_fn"] = _wrap_lambdas(kwargs["create_env_fn"])
+            elif len(args) >= 2:
+                args = (args[0], _wrap_lambdas(args[1])) + args[2:]
 
         return super().__call__(*args, **kwargs)
 
@@ -241,6 +263,20 @@ class BatchedEnvBase(EnvBase):
         daemon (bool, optional): whether the processes should be daemonized.
             This is only applicable to parallel environments such as
            :class:`~torchrl.envs.ParallelEnv`. Defaults to ``False``.
+        auto_wrap_envs (bool, optional): if ``True`` (default), lambda functions passed as
+            ``create_env_fn`` will be automatically wrapped in an :class:`~torchrl.envs.EnvCreator`
+            to enable pickling for multiprocessing with the ``spawn`` start method.
+            This wrapping causes the environment to be instantiated once in the main process
+            (to extract metadata) before workers are started.
+            If this is undesirable, set ``auto_wrap_envs=False``; in that case, ensure your callable
+            is serializable (e.g., use :func:`functools.partial` instead of lambdas).
+            This parameter only affects :class:`~torchrl.envs.ParallelEnv`.
+            Defaults to ``True``.
+
+        .. note::
+            For :class:`~torchrl.envs.ParallelEnv`, it is recommended to use :func:`functools.partial`
+            instead of lambda functions when possible, as ``partial`` objects are natively serializable
+            and avoid the overhead of :class:`~torchrl.envs.EnvCreator` wrapping.
 
     .. note:: One can pass keyword arguments to each sub-environments using the following
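
A minimal usage sketch of the behavior this diff adds (illustrative only: the "Pendulum-v1" task id and the GymEnv wrapper are assumptions, not part of the change; `auto_wrap_envs` is the keyword argument introduced above):

```python
import functools

from torchrl.envs import GymEnv, ParallelEnv

# Recommended: functools.partial is picklable with standard pickle, so no
# EnvCreator wrapping (and no extra main-process instantiation) is needed.
make_env = functools.partial(GymEnv, "Pendulum-v1")
env = ParallelEnv(2, make_env)

# A lambda still works: it is auto-wrapped in an EnvCreator and a UserWarning
# is emitted (silence it globally with RL_WARNINGS=0).
env_from_lambda = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))

# Opt out of the automatic wrapping added here; the callable must then be
# picklable under the chosen multiprocessing start method (e.g. spawn).
env_no_wrap = ParallelEnv(2, make_env, auto_wrap_envs=False)
```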