diff --git a/test/test_collectors.py b/test/test_collectors.py index 86a5aa25848..7ef4d98e0a1 100644 --- a/test/test_collectors.py +++ b/test/test_collectors.py @@ -1842,54 +1842,55 @@ def test_multi_sync_data_collector_ordering( use_buffers=True, ) - # Collect one batch - for batch in collector: - # Verify that each environment's observations match its env_id - # batch has shape [num_envs, frames_per_env] - # In the pre-emption case, we have that envs with odd ids are order of magnitude slower. - # These should be skipped by pre-emption (since they are the 50% slowest) - - # Recover rectangular shape of batch to uniform checks - if cat_results != "stack": - if not with_preempt: - batch = batch.reshape(num_envs, n_steps) - else: - traj_ids = batch["collector", "traj_ids"] - traj_ids[traj_ids == 0] = 99 # avoid using traj_ids = 0 - # Split trajectories to recover correct shape - # thanks to having a single trajectory per env - # Pads with zeros! - batch = split_trajectories( - batch, trajectory_key=("collector", "traj_ids") + try: + # Collect one batch + for batch in collector: + # Verify that each environment's observations match its env_id + # batch has shape [num_envs, frames_per_env] + # In the pre-emption case, we have that envs with odd ids are order of magnitude slower. + # These should be skipped by pre-emption (since they are the 50% slowest) + + # Recover rectangular shape of batch to uniform checks + if cat_results != "stack": + if not with_preempt: + batch = batch.reshape(num_envs, n_steps) + else: + traj_ids = batch["collector", "traj_ids"] + traj_ids[traj_ids == 0] = 99 # avoid using traj_ids = 0 + # Split trajectories to recover correct shape + # thanks to having a single trajectory per env + # Pads with zeros! 
+ batch = split_trajectories( + batch, trajectory_key=("collector", "traj_ids") + ) + # Use -1 for padding to uniform with other preemption + is_padded = batch["collector", "traj_ids"] == 0 + batch[is_padded] = -1 + + # + for env_idx in range(num_envs): + if with_preempt and env_idx % 2 == 1: + # This is a slow env, should have been preempted after first step + assert (batch["collector", "traj_ids"][env_idx, 1:] == -1).all() + continue + # This is a fast env, no preemption happened + assert (batch["collector", "traj_ids"][env_idx] != -1).all() + + env_data = batch[env_idx] + observations = env_data["observation"] + # All observations from this environment should equal its env_id + expected_id = float(env_idx) + actual_ids = observations.flatten().unique() + + assert len(actual_ids) == 1, ( + f"Env {env_idx} should only produce observations with value {expected_id}, " + f"but got {actual_ids.tolist()}" ) - # Use -1 for padding to uniform with other preemption - is_padded = batch["collector", "traj_ids"] == 0 - batch[is_padded] = -1 - - # - for env_idx in range(num_envs): - if with_preempt and env_idx % 2 == 1: - # This is a slow env, should have been preempted after first step - assert (batch["collector", "traj_ids"][env_idx, 1:] == -1).all() - continue - # This is a fast env, no preemption happened - assert (batch["collector", "traj_ids"][env_idx] != -1).all() - - env_data = batch[env_idx] - observations = env_data["observation"] - # All observations from this environment should equal its env_id - expected_id = float(env_idx) - actual_ids = observations.flatten().unique() - - assert len(actual_ids) == 1, ( - f"Env {env_idx} should only produce observations with value {expected_id}, " - f"but got {actual_ids.tolist()}" - ) - assert ( - actual_ids[0].item() == expected_id - ), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}" - - collector.shutdown() + assert ( + actual_ids[0].item() == expected_id + ), f"Environment 
{env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}" +    finally: +        collector.shutdown() class TestCollectorDevices: diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 0e495be07d7..ece1acc9b62 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -123,6 +123,18 @@ def new_fun(self, *args, **kwargs): return new_fun +def _is_unpicklable_lambda(fn: Callable) -> bool: +    """Check if a callable is a lambda function that needs cloudpickle wrapping. + +    Lambda functions cannot be pickled with standard pickle, so they need to be +    wrapped with EnvCreator (which uses CloudpickleWrapper) for multiprocessing. +    functools.partial objects are picklable, so they don't need wrapping. +    """ +    if isinstance(fn, functools.partial): +        return False +    return callable(fn) and getattr(fn, "__name__", None) == "<lambda>" + + class _PEnvMeta(_EnvPostInit): def __call__(cls, *args, **kwargs): serial_for_single = kwargs.pop("serial_for_single", False) @@ -135,6 +147,36 @@ def __call__(cls, *args, **kwargs): if num_workers == 1: # We still use a serial to keep the shape unchanged return SerialEnv(*args, **kwargs) + +        # Wrap lambda functions with EnvCreator so they can be pickled for +        # multiprocessing with the spawn start method. Lambda functions cannot +        # be serialized with standard pickle, but EnvCreator uses cloudpickle. 
+ from torchrl.envs.env_creator import EnvCreator + + create_env_fn = kwargs.get("create_env_fn") + if create_env_fn is None and args: + # create_env_fn is the second positional argument (after num_workers) + if len(args) >= 2: + create_env_fn = args[1] + if callable(create_env_fn): + if _is_unpicklable_lambda(create_env_fn): + args = (args[0], EnvCreator(create_env_fn)) + args[2:] + elif isinstance(create_env_fn, Sequence): + wrapped = [ + EnvCreator(fn) if _is_unpicklable_lambda(fn) else fn + for fn in create_env_fn + ] + args = (args[0], wrapped) + args[2:] + elif create_env_fn is not None: + if callable(create_env_fn): + if _is_unpicklable_lambda(create_env_fn): + kwargs["create_env_fn"] = EnvCreator(create_env_fn) + elif isinstance(create_env_fn, Sequence): + kwargs["create_env_fn"] = [ + EnvCreator(fn) if _is_unpicklable_lambda(fn) else fn + for fn in create_env_fn + ] + return super().__call__(*args, **kwargs)