Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def set_warnings() -> None:
category=UserWarning,
message=r"Skipping device Apple Paravirtual device",
)
warnings.filterwarnings(
"ignore",
category=UserWarning,
message=r"A lambda function was passed to ParallelEnv",
)
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
Expand Down
2 changes: 1 addition & 1 deletion test/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import functools
import gc
import importlib
import os.path
import os
import pickle
import random
import re
Expand Down
44 changes: 40 additions & 4 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
_check_for_faulty_process,
_make_ordinal_device,
logger as torchrl_logger,
rl_warnings,
VERBOSE,
)
from torchrl.data.tensor_specs import Composite, NonTensor
Expand Down Expand Up @@ -152,29 +153,50 @@ def __call__(cls, *args, **kwargs):
# Wrap lambda functions with EnvCreator so they can be pickled for
# multiprocessing with the spawn start method. Lambda functions cannot
# be serialized with standard pickle, but EnvCreator uses cloudpickle.
auto_wrap_envs = kwargs.pop("auto_wrap_envs", True)

def _warn_lambda():
if rl_warnings():
warnings.warn(
"A lambda function was passed to ParallelEnv and will be wrapped "
"in an EnvCreator. This causes the environment to be instantiated "
"in the main process to extract metadata. Consider using "
"functools.partial instead, which is natively serializable and "
"avoids this overhead. To suppress this warning, set the "
"RL_WARNINGS=0 environment variable.",
category=UserWarning,
stacklevel=4,
)

def _wrap_lambdas(create_env_fn):
if callable(create_env_fn) and _is_unpicklable_lambda(create_env_fn):
_warn_lambda()
return EnvCreator(create_env_fn)
if isinstance(create_env_fn, Sequence):
# Reuse EnvCreator for identical function objects to preserve
# _single_task detection (e.g., when [lambda_fn] * 3 is passed)
wrapped = {}
result = []
warned = False
for fn in create_env_fn:
if _is_unpicklable_lambda(fn):
fn_id = id(fn)
if fn_id not in wrapped:
if not warned:
_warn_lambda()
warned = True
wrapped[fn_id] = EnvCreator(fn)
result.append(wrapped[fn_id])
else:
result.append(fn)
return result
return create_env_fn

if "create_env_fn" in kwargs:
kwargs["create_env_fn"] = _wrap_lambdas(kwargs["create_env_fn"])
elif len(args) >= 2:
args = (args[0], _wrap_lambdas(args[1])) + args[2:]
if auto_wrap_envs:
if "create_env_fn" in kwargs:
kwargs["create_env_fn"] = _wrap_lambdas(kwargs["create_env_fn"])
elif len(args) >= 2:
args = (args[0], _wrap_lambdas(args[1])) + args[2:]

return super().__call__(*args, **kwargs)

Expand Down Expand Up @@ -241,6 +263,20 @@ class BatchedEnvBase(EnvBase):
daemon (bool, optional): whether the processes should be daemonized.
This is only applicable to parallel environments such as :class:`~torchrl.envs.ParallelEnv`.
Defaults to ``False``.
auto_wrap_envs (bool, optional): if ``True`` (default), lambda functions passed as
``create_env_fn`` will be automatically wrapped in an :class:`~torchrl.envs.EnvCreator`
to enable pickling for multiprocessing with the ``spawn`` start method.
This wrapping causes the environment to be instantiated once in the main process
(to extract metadata) before workers are started.
If this is undesirable, set ``auto_wrap_envs=False``. Otherwise, ensure your callable is
serializable (e.g., use :func:`functools.partial` instead of lambdas).
This parameter only affects :class:`~torchrl.envs.ParallelEnv`.
Defaults to ``True``.

.. note::
For :class:`~torchrl.envs.ParallelEnv`, it is recommended to use :func:`functools.partial`
instead of lambda functions when possible, as ``partial`` objects are natively serializable
and avoid the overhead of :class:`~torchrl.envs.EnvCreator` wrapping.

.. note::
One can pass keyword arguments to each sub-environments using the following
Expand Down
Loading