diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index e51cd26f551..8c45affb0c2 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -156,10 +156,19 @@ def _wrap_lambdas(create_env_fn): if callable(create_env_fn) and _is_unpicklable_lambda(create_env_fn): return EnvCreator(create_env_fn) if isinstance(create_env_fn, Sequence): - return [ - EnvCreator(fn) if _is_unpicklable_lambda(fn) else fn - for fn in create_env_fn - ] + # Reuse EnvCreator for identical function objects to preserve + # _single_task detection (e.g., when [lambda_fn] * 3 is passed) + wrapped = {} + result = [] + for fn in create_env_fn: + if _is_unpicklable_lambda(fn): + fn_id = id(fn) + if fn_id not in wrapped: + wrapped[fn_id] = EnvCreator(fn) + result.append(wrapped[fn_id]) + else: + result.append(fn) + return result return create_env_fn if "create_env_fn" in kwargs: