diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index ece1acc9b62..bcc85bac498 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -42,7 +42,8 @@ from torchrl.data.tensor_specs import Composite, NonTensor from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, EnvMetaData -from torchrl.envs.env_creator import get_env_metadata + +from torchrl.envs.env_creator import EnvCreator, get_env_metadata # legacy from torchrl.envs.libs.envpool import ( # noqa: F401 @@ -151,31 +152,20 @@ def __call__(cls, *args, **kwargs): # Wrap lambda functions with EnvCreator so they can be pickled for # multiprocessing with the spawn start method. Lambda functions cannot # be serialized with standard pickle, but EnvCreator uses cloudpickle. - from torchrl.envs.env_creator import EnvCreator - - create_env_fn = kwargs.get("create_env_fn") - if create_env_fn is None and args: - # create_env_fn is the second positional argument (after num_workers) - if len(args) >= 2: - create_env_fn = args[1] - if callable(create_env_fn): - if _is_unpicklable_lambda(create_env_fn): - args = (args[0], EnvCreator(create_env_fn)) + args[2:] - elif isinstance(create_env_fn, Sequence): - wrapped = [ - EnvCreator(fn) if _is_unpicklable_lambda(fn) else fn - for fn in create_env_fn - ] - args = (args[0], wrapped) + args[2:] - elif create_env_fn is not None: - if callable(create_env_fn): - if _is_unpicklable_lambda(create_env_fn): - kwargs["create_env_fn"] = EnvCreator(create_env_fn) - elif isinstance(create_env_fn, Sequence): - kwargs["create_env_fn"] = [ + def _wrap_lambdas(create_env_fn): + if callable(create_env_fn) and _is_unpicklable_lambda(create_env_fn): + return EnvCreator(create_env_fn) + if isinstance(create_env_fn, Sequence): + return [ EnvCreator(fn) if _is_unpicklable_lambda(fn) else fn for fn in create_env_fn ] + return create_env_fn + + if "create_env_fn" in kwargs: + kwargs["create_env_fn"] = _wrap_lambdas(kwargs["create_env_fn"]) + elif len(args) >= 2: + args = (args[0], _wrap_lambdas(args[1])) + args[2:] return super().__call__(*args, **kwargs)