Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 13 additions & 23 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
from torchrl.data.tensor_specs import Composite, NonTensor
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData
from torchrl.envs.env_creator import get_env_metadata

from torchrl.envs.env_creator import EnvCreator, get_env_metadata

# legacy
from torchrl.envs.libs.envpool import ( # noqa: F401
Expand Down Expand Up @@ -151,31 +152,20 @@ def __call__(cls, *args, **kwargs):
# Wrap lambda functions with EnvCreator so they can be pickled for
# multiprocessing with the spawn start method. Lambda functions cannot
# be serialized with standard pickle, but EnvCreator uses cloudpickle.
from torchrl.envs.env_creator import EnvCreator

create_env_fn = kwargs.get("create_env_fn")
if create_env_fn is None and args:
# create_env_fn is the second positional argument (after num_workers)
if len(args) >= 2:
create_env_fn = args[1]
if callable(create_env_fn):
if _is_unpicklable_lambda(create_env_fn):
args = (args[0], EnvCreator(create_env_fn)) + args[2:]
elif isinstance(create_env_fn, Sequence):
wrapped = [
EnvCreator(fn) if _is_unpicklable_lambda(fn) else fn
for fn in create_env_fn
]
args = (args[0], wrapped) + args[2:]
elif create_env_fn is not None:
if callable(create_env_fn):
if _is_unpicklable_lambda(create_env_fn):
kwargs["create_env_fn"] = EnvCreator(create_env_fn)
elif isinstance(create_env_fn, Sequence):
kwargs["create_env_fn"] = [
def _wrap_lambdas(create_env_fn):
if callable(create_env_fn) and _is_unpicklable_lambda(create_env_fn):
return EnvCreator(create_env_fn)
if isinstance(create_env_fn, Sequence):
return [
EnvCreator(fn) if _is_unpicklable_lambda(fn) else fn
for fn in create_env_fn
]
return create_env_fn

if "create_env_fn" in kwargs:
kwargs["create_env_fn"] = _wrap_lambdas(kwargs["create_env_fn"])
elif len(args) >= 2:
args = (args[0], _wrap_lambdas(args[1])) + args[2:]

return super().__call__(*args, **kwargs)

Expand Down
Loading