test/test_collectors.py (48 additions, 47 deletions)
@@ -1842,54 +1842,55 @@ def test_multi_sync_data_collector_ordering(
         use_buffers=True,
     )
 
-    # Collect one batch
-    for batch in collector:
-        # Verify that each environment's observations match its env_id
-        # batch has shape [num_envs, frames_per_env]
-        # In the pre-emption case, envs with odd ids are an order of magnitude slower.
-        # These should be skipped by pre-emption (since they are the 50% slowest)
-
-        # Recover the rectangular shape of the batch so the checks below are uniform
-        if cat_results != "stack":
-            if not with_preempt:
-                batch = batch.reshape(num_envs, n_steps)
-            else:
-                traj_ids = batch["collector", "traj_ids"]
-                traj_ids[traj_ids == 0] = 99  # avoid using traj_ids = 0
-                # Split trajectories to recover the correct shape,
-                # thanks to having a single trajectory per env.
-                # Pads with zeros!
-                batch = split_trajectories(
-                    batch, trajectory_key=("collector", "traj_ids")
-                )
-                # Use -1 for padding to match the other preemption case
-                is_padded = batch["collector", "traj_ids"] == 0
-                batch[is_padded] = -1
-
-        # Check each environment's output
-        for env_idx in range(num_envs):
-            if with_preempt and env_idx % 2 == 1:
-                # This is a slow env, should have been preempted after the first step
-                assert (batch["collector", "traj_ids"][env_idx, 1:] == -1).all()
-                continue
-            # This is a fast env, no preemption happened
-            assert (batch["collector", "traj_ids"][env_idx] != -1).all()
-
-            env_data = batch[env_idx]
-            observations = env_data["observation"]
-            # All observations from this environment should equal its env_id
-            expected_id = float(env_idx)
-            actual_ids = observations.flatten().unique()
-
-            assert len(actual_ids) == 1, (
-                f"Env {env_idx} should only produce observations with value {expected_id}, "
-                f"but got {actual_ids.tolist()}"
-            )
-            assert (
-                actual_ids[0].item() == expected_id
-            ), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}"
-
-    collector.shutdown()
+    try:
+        # Collect one batch
+        for batch in collector:
+            # Verify that each environment's observations match its env_id
+            # batch has shape [num_envs, frames_per_env]
+            # In the pre-emption case, envs with odd ids are an order of magnitude slower.
+            # These should be skipped by pre-emption (since they are the 50% slowest)
+
+            # Recover the rectangular shape of the batch so the checks below are uniform
+            if cat_results != "stack":
+                if not with_preempt:
+                    batch = batch.reshape(num_envs, n_steps)
+                else:
+                    traj_ids = batch["collector", "traj_ids"]
+                    traj_ids[traj_ids == 0] = 99  # avoid using traj_ids = 0
+                    # Split trajectories to recover the correct shape,
+                    # thanks to having a single trajectory per env.
+                    # Pads with zeros!
+                    batch = split_trajectories(
+                        batch, trajectory_key=("collector", "traj_ids")
+                    )
+                    # Use -1 for padding to match the other preemption case
+                    is_padded = batch["collector", "traj_ids"] == 0
+                    batch[is_padded] = -1
+
+            # Check each environment's output
+            for env_idx in range(num_envs):
+                if with_preempt and env_idx % 2 == 1:
+                    # This is a slow env, should have been preempted after the first step
+                    assert (batch["collector", "traj_ids"][env_idx, 1:] == -1).all()
+                    continue
+                # This is a fast env, no preemption happened
+                assert (batch["collector", "traj_ids"][env_idx] != -1).all()
+
+                env_data = batch[env_idx]
+                observations = env_data["observation"]
+                # All observations from this environment should equal its env_id
+                expected_id = float(env_idx)
+                actual_ids = observations.flatten().unique()
+
+                assert len(actual_ids) == 1, (
+                    f"Env {env_idx} should only produce observations with value {expected_id}, "
+                    f"but got {actual_ids.tolist()}"
+                )
+                assert (
+                    actual_ids[0].item() == expected_id
+                ), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}"
+    finally:
+        collector.shutdown()
 
 
 class TestCollectorDevices:
torchrl/envs/batched_envs.py (42 additions, 0 deletions)
@@ -123,6 +123,18 @@ def new_fun(self, *args, **kwargs):
     return new_fun
 
 
+def _is_unpicklable_lambda(fn: Callable) -> bool:
+    """Check if a callable is a lambda function that needs cloudpickle wrapping.
+
+    Lambda functions cannot be pickled with standard pickle, so they need to be
+    wrapped with EnvCreator (which uses CloudpickleWrapper) for multiprocessing.
+    functools.partial objects are picklable, so they don't need wrapping.
+    """
+    if isinstance(fn, functools.partial):
+        return False
+    return callable(fn) and getattr(fn, "__name__", None) == "<lambda>"
+
+
 class _PEnvMeta(_EnvPostInit):
     def __call__(cls, *args, **kwargs):
         serial_for_single = kwargs.pop("serial_for_single", False)
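
A quick standalone sketch (not part of this diff) of the pickling behaviour `_is_unpicklable_lambda` relies on: the stdlib pickler stores a function by its qualified name, so a lambda named `<lambda>` cannot be round-tripped, while a `functools.partial` over a picklable callable pickles fine. That is why only bare lambdas need the `EnvCreator`/cloudpickle wrapping.

```python
import functools
import pickle

make_env = lambda: "Pendulum-v1"  # noqa: E731 -- deliberately a lambda

try:
    pickle.dumps(make_env)
except pickle.PicklingError:
    # Standard pickle serializes functions by qualified name; "<lambda>" cannot
    # be looked up again at load time, so lambdas always fail here.
    print("lambda is not picklable -> wrap it with EnvCreator")

# A functools.partial around a picklable callable pickles by reference and is fine as-is.
make_env_partial = functools.partial(len, "Pendulum-v1")
assert pickle.loads(pickle.dumps(make_env_partial))() == len("Pendulum-v1")

# Same check as the helper added above (standalone copy for illustration):
def is_unpicklable_lambda(fn) -> bool:
    if isinstance(fn, functools.partial):
        return False
    return callable(fn) and getattr(fn, "__name__", None) == "<lambda>"

assert is_unpicklable_lambda(make_env)
assert not is_unpicklable_lambda(make_env_partial)
```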
@@ -135,6 +147,36 @@ def __call__(cls, *args, **kwargs):
             if num_workers == 1:
                 # We still use a serial to keep the shape unchanged
                 return SerialEnv(*args, **kwargs)
+
+        # Wrap lambda functions with EnvCreator so they can be pickled for
+        # multiprocessing with the spawn start method. Lambda functions cannot
+        # be serialized with standard pickle, but EnvCreator uses cloudpickle.
+        from torchrl.envs.env_creator import EnvCreator
+
+        create_env_fn = kwargs.get("create_env_fn")
+        if create_env_fn is None and args:
+            # create_env_fn is the second positional argument (after num_workers)
+            if len(args) >= 2:
+                create_env_fn = args[1]
+                if callable(create_env_fn):
+                    if _is_unpicklable_lambda(create_env_fn):
+                        args = (args[0], EnvCreator(create_env_fn)) + args[2:]
+                elif isinstance(create_env_fn, Sequence):
+                    wrapped = [
+                        EnvCreator(fn) if _is_unpicklable_lambda(fn) else fn
+                        for fn in create_env_fn
+                    ]
+                    args = (args[0], wrapped) + args[2:]
+        elif create_env_fn is not None:
+            if callable(create_env_fn):
+                if _is_unpicklable_lambda(create_env_fn):
+                    kwargs["create_env_fn"] = EnvCreator(create_env_fn)
+            elif isinstance(create_env_fn, Sequence):
+                kwargs["create_env_fn"] = [
+                    EnvCreator(fn) if _is_unpicklable_lambda(fn) else fn
+                    for fn in create_env_fn
+                ]
+
         return super().__call__(*args, **kwargs)
 
 
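
A hedged sketch of the user-facing effect of the metaclass hook (assuming the usual `ParallelEnv`/`GymEnv` entry points and the `mp_start_method` keyword; none of this code is in the diff): a bare lambda passed as `create_env_fn` is now wrapped in `EnvCreator` automatically, so it can be shipped to spawned workers just like an explicitly wrapped factory.

```python
# Hypothetical usage sketch -- not part of this PR.
from torchrl.envs import GymEnv, ParallelEnv
from torchrl.envs.env_creator import EnvCreator

if __name__ == "__main__":
    # Previously required: wrap the lambda by hand so it survives pickling
    # when workers are started with the spawn method.
    env = ParallelEnv(2, EnvCreator(lambda: GymEnv("Pendulum-v1")), mp_start_method="spawn")
    env.close()

    # With the metaclass hook above, the bare lambda is wrapped automatically.
    env = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"), mp_start_method="spawn")
    env.close()
```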