10 changes: 5 additions & 5 deletions .github/unittest/linux/scripts/run_all.sh
@@ -25,7 +25,7 @@ if [[ $OSTYPE != 'darwin'* ]]; then
apt-get install -y libfreetype6-dev pkg-config

apt-get install -y libglfw3 libosmesa6 libglew-dev
apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2 xvfb
apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2 xvfb ffmpeg

if [ "${CU_VERSION:-}" == cpu ] ; then
apt-get upgrade -y libstdc++6
@@ -205,15 +205,15 @@ git submodule sync && git submodule update --init --recursive
printf "Installing PyTorch with %s\n" "${CU_VERSION}"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
uv_pip_install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
uv_pip_install --upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
else
uv_pip_install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
uv_pip_install --upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
uv_pip_install torch torchvision --index-url https://download.pytorch.org/whl/cpu
uv_pip_install --upgrade torch torchvision --index-url https://download.pytorch.org/whl/cpu
else
uv_pip_install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION
uv_pip_install --upgrade torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION
fi
else
printf "Failed to install pytorch"
4 changes: 2 additions & 2 deletions .github/unittest/linux/scripts/run_setup_test.sh
@@ -66,9 +66,9 @@ else
fi

if [[ "$TORCH_VERSION" == "nightly" ]]; then
uv_pip_install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
uv_pip_install --upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
else
uv_pip_install torch torchvision --index-url https://download.pytorch.org/whl/cpu
uv_pip_install --upgrade torch torchvision --index-url https://download.pytorch.org/whl/cpu
fi

# tensordict is a hard dependency of torchrl; install it explicitly since we test
4 changes: 2 additions & 2 deletions .github/unittest/linux_libs/scripts_minari/install.sh
@@ -33,9 +33,9 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
uv pip install torch --index-url https://download.pytorch.org/whl/cpu
uv pip install torch --index-url https://download.pytorch.org/whl/cpu -U
else
uv pip install torch --index-url https://download.pytorch.org/whl/cu128
uv pip install torch --index-url https://download.pytorch.org/whl/cu128 -U
fi
else
printf "Failed to install pytorch"
2 changes: 1 addition & 1 deletion .github/unittest/linux_libs/scripts_vd4rl/install.sh
@@ -37,7 +37,7 @@ elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
else
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128 -U
fi
else
printf "Failed to install pytorch"
4 changes: 2 additions & 2 deletions .github/unittest/linux_libs/scripts_vmas/install.sh
@@ -34,9 +34,9 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U
else
pip3 install torch --index-url https://download.pytorch.org/whl/cu128
pip3 install torch --index-url https://download.pytorch.org/whl/cu128 -U
fi
else
printf "Failed to install pytorch"
2 changes: 1 addition & 1 deletion .github/unittest/linux_optdeps/scripts/run_all.sh
@@ -23,7 +23,7 @@ if [[ $OSTYPE != 'darwin'* ]]; then
apt-get install -y vim git wget cmake

apt-get install -y libglfw3 libosmesa6 libglew-dev
apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2
apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2 xvfb ffmpeg

if [ "${CU_VERSION:-}" == cpu ] ; then
# solves version `GLIBCXX_3.4.29' not found for tensorboard
8 changes: 4 additions & 4 deletions .github/unittest/linux_sota/scripts/run_all.sh
@@ -115,15 +115,15 @@ git submodule sync && git submodule update --init --recursive
printf "Installing PyTorch with %s\n" "${CU_VERSION}"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
uv pip install --pre torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/nightly/cpu
uv pip install --upgrade --pre torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/nightly/cpu
else
uv pip install --pre torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
uv pip install --upgrade --pre torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if [ "${CU_VERSION:-}" == cpu ] ; then
uv pip install torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/cpu
uv pip install --upgrade torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/cpu
else
uv pip install torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/$CU_VERSION
uv pip install --upgrade torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/$CU_VERSION
fi
else
printf "Failed to install pytorch"
6 changes: 3 additions & 3 deletions .github/unittest/windows_optdepts/scripts/unittest.sh
@@ -71,15 +71,15 @@ python -m pip install "numpy<2.0"
printf "Installing PyTorch with %s\n" "${cudatoolkit}"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
if $torch_cuda ; then
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 -U
else
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
fi
elif [[ "$TORCH_VERSION" == "stable" ]]; then
if $torch_cuda ; then
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
python -m pip install torch --index-url https://download.pytorch.org/whl/cu118 -U
else
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu -U
fi
else
printf "Failed to install pytorch"
9 changes: 0 additions & 9 deletions test/_utils_internal.py
@@ -22,7 +22,6 @@
from torch import nn, vmap

from torchrl._utils import implement_for, logger, RL_WARNINGS, seed_generator
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs import MultiThreadedEnv, ObservationNorm
from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
from torchrl.envs.libs.envpool import _has_envpool
@@ -594,14 +593,6 @@ def check_rollout_consistency_multikey_env(td: TensorDict, max_steps: int):
assert (td["next", "nested_2", "reward"][~action_is_count] == 0).all()


def decorate_thread_sub_func(func, num_threads):
def new_func(*args, **kwargs):
assert torch.get_num_threads() == num_threads
return func(*args, **kwargs)

return CloudpickleWrapper(new_func)


class LSTMNet(nn.Module):
"""An embedder for an LSTM preceded by an MLP.

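The helper deleted here is not dropped from the test suite: test_collectors.py below now imports `decorate_thread_sub_func` from `torchrl.testing.mp_helpers`. That module's body is not part of this diff; a minimal sketch of what it presumably provides, mirroring the removed definition, would be:

```python
# Sketch only - mirrors the helper removed from test/_utils_internal.py; the
# actual contents of torchrl.testing.mp_helpers are not shown in this diff.
import torch
from torchrl.data.utils import CloudpickleWrapper


def decorate_thread_sub_func(func, num_threads):
    """Wrap `func` so the worker asserts its torch thread count before running."""

    def new_func(*args, **kwargs):
        assert torch.get_num_threads() == num_threads
        return func(*args, **kwargs)

    # CloudpickleWrapper lets the closure cross process boundaries.
    return CloudpickleWrapper(new_func)
```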
98 changes: 49 additions & 49 deletions test/test_collectors.py
@@ -89,6 +89,7 @@
SafeModule,
)
from torchrl.testing.modules import BiasModule, NonSerializableBiasModule
from torchrl.testing.mp_helpers import decorate_thread_sub_func
from torchrl.weight_update import (
MultiProcessWeightSyncScheme,
SharedMemWeightSyncScheme,
@@ -99,7 +100,6 @@
from pytorch.rl.test._utils_internal import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
generate_seeds,
get_available_devices,
get_default_devices,
Expand All @@ -112,7 +112,6 @@
from _utils_internal import (
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
generate_seeds,
get_available_devices,
get_default_devices,
@@ -1843,54 +1842,55 @@ def test_multi_sync_data_collector_ordering(
use_buffers=True,
)

# Collect one batch
for batch in collector:
# Verify that each environment's observations match its env_id
# batch has shape [num_envs, frames_per_env]
# In the pre-emption case, we have that envs with odd ids are order of magnitude slower.
# These should be skipped by pre-emption (since they are the 50% slowest)

# Recover rectangular shape of batch to uniform checks
if cat_results != "stack":
if not with_preempt:
batch = batch.reshape(num_envs, n_steps)
else:
traj_ids = batch["collector", "traj_ids"]
traj_ids[traj_ids == 0] = 99 # avoid using traj_ids = 0
# Split trajectories to recover correct shape
# thanks to having a single trajectory per env
# Pads with zeros!
batch = split_trajectories(
batch, trajectory_key=("collector", "traj_ids")
try:
# Collect one batch
for batch in collector:
# Verify that each environment's observations match its env_id
# batch has shape [num_envs, frames_per_env]
# In the pre-emption case, we have that envs with odd ids are order of magnitude slower.
# These should be skipped by pre-emption (since they are the 50% slowest)

# Recover rectangular shape of batch to uniform checks
if cat_results != "stack":
if not with_preempt:
batch = batch.reshape(num_envs, n_steps)
else:
traj_ids = batch["collector", "traj_ids"]
traj_ids[traj_ids == 0] = 99 # avoid using traj_ids = 0
# Split trajectories to recover correct shape
# thanks to having a single trajectory per env
# Pads with zeros!
batch = split_trajectories(
batch, trajectory_key=("collector", "traj_ids")
)
# Use -1 for padding to uniform with other preemption
is_padded = batch["collector", "traj_ids"] == 0
batch[is_padded] = -1

#
for env_idx in range(num_envs):
if with_preempt and env_idx % 2 == 1:
# This is a slow env, should have been preempted after first step
assert (batch["collector", "traj_ids"][env_idx, 1:] == -1).all()
continue
# This is a fast env, no preemption happened
assert (batch["collector", "traj_ids"][env_idx] != -1).all()

env_data = batch[env_idx]
observations = env_data["observation"]
# All observations from this environment should equal its env_id
expected_id = float(env_idx)
actual_ids = observations.flatten().unique()

assert len(actual_ids) == 1, (
f"Env {env_idx} should only produce observations with value {expected_id}, "
f"but got {actual_ids.tolist()}"
)
# Use -1 for padding to uniform with other preemption
is_padded = batch["collector", "traj_ids"] == 0
batch[is_padded] = -1

#
for env_idx in range(num_envs):
if with_preempt and env_idx % 2 == 1:
# This is a slow env, should have been preempted after first step
assert (batch["collector", "traj_ids"][env_idx, 1:] == -1).all()
continue
# This is a fast env, no preemption happened
assert (batch["collector", "traj_ids"][env_idx] != -1).all()

env_data = batch[env_idx]
observations = env_data["observation"]
# All observations from this environment should equal its env_id
expected_id = float(env_idx)
actual_ids = observations.flatten().unique()

assert len(actual_ids) == 1, (
f"Env {env_idx} should only produce observations with value {expected_id}, "
f"but got {actual_ids.tolist()}"
)
assert (
actual_ids[0].item() == expected_id
), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}"

collector.shutdown()
assert (
actual_ids[0].item() == expected_id
), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}"
finally:
collector.shutdown()


class TestCollectorDevices:
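The substance of this change is the `try`/`finally` wrapper: the assertions are unchanged, but `collector.shutdown()` now runs even when one of them fails, so a failing test no longer leaks collector worker processes. A distilled, runnable version of the pattern (`DummyCollector` is illustrative, standing in for the multi-process collector used in the test):

```python
# Minimal stand-in for the collector used in the test; only iteration and
# shutdown behaviour matter for the pattern.
class DummyCollector:
    def __init__(self):
        self.closed = False

    def __iter__(self):
        yield {"traj_ids": [0, 1]}

    def shutdown(self):
        self.closed = True


collector = DummyCollector()
try:
    for batch in collector:
        assert batch["traj_ids"][0] == 0  # any check in the loop may raise
finally:
    # Runs on success and on assertion failure alike, so worker processes
    # (in the real test) are always torn down.
    collector.shutdown()

assert collector.closed
```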
22 changes: 9 additions & 13 deletions test/test_envs.py
@@ -179,7 +179,6 @@ def check_no_lingering_multiprocessing_resources(request):
_make_envs,
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
@@ -191,7 +190,6 @@ def check_no_lingering_multiprocessing_resources(request):
_make_envs,
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
@@ -3405,25 +3403,23 @@ class TestLibThreading:
)
def test_num_threads(self):
gc.collect()
from torchrl.envs import batched_envs

_run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem
batched_envs._run_worker_pipe_shared_mem = decorate_thread_sub_func(
batched_envs._run_worker_pipe_shared_mem, num_threads=3
)
num_threads = torch.get_num_threads()
try:
env = ParallelEnv(
2, ContinuousActionVecMockEnv, num_sub_threads=3, num_threads=7
)
# Wrap the env factory to check thread count inside the subprocess.
# The env is created AFTER torch.set_num_threads() is called in the worker.
def make_env():
assert (
torch.get_num_threads() == 3
), f"Expected 3 threads, got {torch.get_num_threads()}"
return ContinuousActionVecMockEnv()

env = ParallelEnv(2, make_env, num_sub_threads=3, num_threads=7)
# We could test that the number of threads isn't changed until we start the procs.
# Even though it's unlikely that we have 7 threads, we still disable this for safety
# assert torch.get_num_threads() != 7
env.rollout(3)
assert torch.get_num_threads() == 7
finally:
# reset vals
batched_envs._run_worker_pipe_shared_mem = _run_worker_pipe_shared_mem_save
torch.set_num_threads(num_threads)

@pytest.mark.skipif(
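The rewritten test no longer monkeypatches `batched_envs._run_worker_pipe_shared_mem`; instead the env factory itself asserts the thread count, and, as the new comment notes, the factory runs in the worker after `torch.set_num_threads()` has been applied there. A torchrl-free sketch of that idea (the `worker` function below is illustrative, not torchrl's actual worker loop):

```python
import torch
import torch.multiprocessing as mp


def make_env():
    # Runs inside the worker process, after the worker has configured itself.
    assert torch.get_num_threads() == 3, torch.get_num_threads()


def worker(num_sub_threads, factory):
    # What the real worker does before building the env, per the comment above.
    torch.set_num_threads(num_sub_threads)
    factory()


if __name__ == "__main__":
    p = mp.get_context("spawn").Process(target=worker, args=(3, make_env))
    p.start()
    p.join()
    assert p.exitcode == 0
```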
2 changes: 2 additions & 0 deletions test/test_objectives.py
@@ -200,6 +200,8 @@
pytest.mark.filterwarnings(
"ignore:The PyTorch API of nested tensors is in prototype"
),
pytest.mark.filterwarnings("ignore:unclosed event loop:ResourceWarning"),
pytest.mark.filterwarnings("ignore:unclosed.*socket:ResourceWarning"),
]


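For context on the two new markers: pytest `filterwarnings` strings follow the standard `action:message_regex:category` format, and the added filters land in what appears to be the module-level `pytestmark` list, which applies them to every test in the file. A small self-contained example of the same mechanism:

```python
import warnings

import pytest

# Module-wide filters, same mechanism as in test_objectives.py.
pytestmark = [
    pytest.mark.filterwarnings("ignore:unclosed event loop:ResourceWarning"),
    pytest.mark.filterwarnings("ignore:unclosed.*socket:ResourceWarning"),
]


def test_noisy_teardown():
    # Matches the first filter (the message part is a regex matched at the
    # start of the warning text), so it does not appear in pytest's summary.
    warnings.warn("unclosed event loop <_UnixSelectorEventLoop>", ResourceWarning)
```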
1 change: 1 addition & 0 deletions test/test_transforms.py
@@ -10074,6 +10074,7 @@ def _test_vecnorm_subproc_auto(
def rename_t(self):
return RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")])

@retry(AssertionError, tries=10, delay=0)
@pytest.mark.parametrize("nprc", [2, 5])
def test_vecnorm_parallel_auto(self, nprc):
queues = []
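The `retry` decorator applied above is imported elsewhere in test_transforms.py and its definition is not part of this diff. A minimal sketch with the same `(exception, tries, delay)` call signature, for readers unfamiliar with the pattern:

```python
# Illustrative only - the real helper used by the test may differ in details.
import functools
import time


def retry(exc_type, tries=3, delay=0.0):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            for attempt in range(tries):
                try:
                    return fn(*args, **kwargs)
                except exc_type:
                    if attempt == tries - 1:
                        raise  # out of retries: surface the failure
                    time.sleep(delay)

        return wrapper

    return decorator
```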
23 changes: 21 additions & 2 deletions torchrl/collectors/_single.py
@@ -46,6 +46,20 @@
from torchrl.weight_update.utils import _resolve_model


def _cuda_sync_if_initialized():
"""Synchronize CUDA only if it has been initialized.

This is a safe alternative to calling `torch.cuda.synchronize()` directly.
In forked subprocesses on machines with CUDA, calling `synchronize()` will
fail with "Cannot re-initialize CUDA in forked subprocess" if CUDA was
initialized in the parent process before fork. By checking
`is_initialized()` first, we skip the sync in such cases since no CUDA
operations have occurred in this process.
"""
if torch.cuda.is_initialized():
torch.cuda.synchronize()


@accept_remote_rref_udf_invocation
class Collector(BaseCollector):
"""Generic data collector for RL problems. Requires an environment constructor and a policy.
@@ -518,9 +532,14 @@ def _setup_devices(
def _get_sync_fn(self, device: torch.device | None) -> Callable:
"""Get the appropriate synchronization function for a device."""
if device is not None and device.type != "cuda":
# Cuda handles sync
# When destination is not CUDA, we may need to sync to wait for
# async GPU→CPU transfers to complete before proceeding.
if torch.cuda.is_available():
return torch.cuda.synchronize
# Return a safe wrapper that only syncs if CUDA was actually
# initialized. This avoids "Cannot re-initialize CUDA in forked
# subprocess" errors when using fork start method on GPU machines
# with CPU-only collectors.
return _cuda_sync_if_initialized
elif torch.backends.mps.is_available() and hasattr(torch, "mps"):
return torch.mps.synchronize
elif hasattr(torch, "npu") and torch.npu.is_available():
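The new `_cuda_sync_if_initialized` guard is easiest to see in a fork scenario. A small Linux-only sketch (the `worker` here is illustrative; it stands in for a CPU-only collector worker forked on a CUDA machine):

```python
import torch
import torch.multiprocessing as mp


def _cuda_sync_if_initialized():
    # Same guard as in torchrl/collectors/_single.py: only sync if this
    # process has actually initialized CUDA.
    if torch.cuda.is_initialized():
        torch.cuda.synchronize()


def worker():
    # If the parent initialized CUDA before the fork, this child cannot
    # re-initialize it; the guard turns the sync into a no-op instead of
    # raising "Cannot re-initialize CUDA in forked subprocess".
    _cuda_sync_if_initialized()


if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.cuda.init()  # parent touches CUDA before forking
    p = mp.get_context("fork").Process(target=worker)
    p.start()
    p.join()
    assert p.exitcode == 0
```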