Skip to content

Commit fb2a027

Browse files
committed
[Feature] auto_wrap_envs in PEnv
ghstack-source-id: 8dceb5d Pull-Request: #3284
1 parent acdd493 commit fb2a027

File tree

3 files changed

+46
-5
lines changed

3 files changed

+46
-5
lines changed

test/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,11 @@ def set_warnings() -> None:
7474
category=UserWarning,
7575
message=r"Skipping device Apple Paravirtual device",
7676
)
77+
warnings.filterwarnings(
78+
"ignore",
79+
category=UserWarning,
80+
message=r"A lambda function was passed to ParallelEnv",
81+
)
7782
warnings.filterwarnings(
7883
"ignore",
7984
category=DeprecationWarning,

test/test_envs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import functools
1010
import gc
1111
import importlib
12-
import os.path
12+
import os
1313
import pickle
1414
import random
1515
import re

torchrl/envs/batched_envs.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
_check_for_faulty_process,
3838
_make_ordinal_device,
3939
logger as torchrl_logger,
40+
rl_warnings,
4041
VERBOSE,
4142
)
4243
from torchrl.data.tensor_specs import Composite, NonTensor
@@ -152,29 +153,50 @@ def __call__(cls, *args, **kwargs):
152153
# Wrap lambda functions with EnvCreator so they can be pickled for
153154
# multiprocessing with the spawn start method. Lambda functions cannot
154155
# be serialized with standard pickle, but EnvCreator uses cloudpickle.
156+
auto_wrap_envs = kwargs.pop("auto_wrap_envs", True)
157+
158+
def _warn_lambda():
159+
if rl_warnings():
160+
warnings.warn(
161+
"A lambda function was passed to ParallelEnv and will be wrapped "
162+
"in an EnvCreator. This causes the environment to be instantiated "
163+
"in the main process to extract metadata. Consider using "
164+
"functools.partial instead, which is natively serializable and "
165+
"avoids this overhead. To suppress this warning, set the "
166+
"RL_WARNINGS=0 environment variable.",
167+
category=UserWarning,
168+
stacklevel=4,
169+
)
170+
155171
def _wrap_lambdas(create_env_fn):
156172
if callable(create_env_fn) and _is_unpicklable_lambda(create_env_fn):
173+
_warn_lambda()
157174
return EnvCreator(create_env_fn)
158175
if isinstance(create_env_fn, Sequence):
159176
# Reuse EnvCreator for identical function objects to preserve
160177
# _single_task detection (e.g., when [lambda_fn] * 3 is passed)
161178
wrapped = {}
162179
result = []
180+
warned = False
163181
for fn in create_env_fn:
164182
if _is_unpicklable_lambda(fn):
165183
fn_id = id(fn)
166184
if fn_id not in wrapped:
185+
if not warned:
186+
_warn_lambda()
187+
warned = True
167188
wrapped[fn_id] = EnvCreator(fn)
168189
result.append(wrapped[fn_id])
169190
else:
170191
result.append(fn)
171192
return result
172193
return create_env_fn
173194

174-
if "create_env_fn" in kwargs:
175-
kwargs["create_env_fn"] = _wrap_lambdas(kwargs["create_env_fn"])
176-
elif len(args) >= 2:
177-
args = (args[0], _wrap_lambdas(args[1])) + args[2:]
195+
if auto_wrap_envs:
196+
if "create_env_fn" in kwargs:
197+
kwargs["create_env_fn"] = _wrap_lambdas(kwargs["create_env_fn"])
198+
elif len(args) >= 2:
199+
args = (args[0], _wrap_lambdas(args[1])) + args[2:]
178200

179201
return super().__call__(*args, **kwargs)
180202

@@ -241,6 +263,20 @@ class BatchedEnvBase(EnvBase):
241263
daemon (bool, optional): whether the processes should be daemonized.
242264
This is only applicable to parallel environments such as :class:`~torchrl.envs.ParallelEnv`.
243265
Defaults to ``False``.
266+
auto_wrap_envs (bool, optional): if ``True`` (default), lambda functions passed as
267+
``create_env_fn`` will be automatically wrapped in an :class:`~torchrl.envs.EnvCreator`
268+
to enable pickling for multiprocessing with the ``spawn`` start method.
269+
This wrapping causes the environment to be instantiated once in the main process
270+
(to extract metadata) before workers are started.
271+
If this is undesirable, set ``auto_wrap_envs=False``. Otherwise, ensure your callable is
272+
serializable (e.g., use :func:`functools.partial` instead of lambdas).
273+
This parameter only affects :class:`~torchrl.envs.ParallelEnv`.
274+
Defaults to ``True``.
275+
276+
.. note::
277+
For :class:`~torchrl.envs.ParallelEnv`, it is recommended to use :func:`functools.partial`
278+
instead of lambda functions when possible, as ``partial`` objects are natively serializable
279+
and avoid the overhead of :class:`~torchrl.envs.EnvCreator` wrapping.
244280
245281
.. note::
246282
One can pass keyword arguments to each sub-environments using the following

0 commit comments

Comments
 (0)