diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index e125bf927ee..174e1b84deb 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -25,7 +25,7 @@ if [[ $OSTYPE != 'darwin'* ]]; then apt-get install -y libfreetype6-dev pkg-config apt-get install -y libglfw3 libosmesa6 libglew-dev - apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2 xvfb + apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2 xvfb ffmpeg if [ "${CU_VERSION:-}" == cpu ] ; then apt-get upgrade -y libstdc++6 @@ -205,15 +205,15 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - uv_pip_install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu + uv_pip_install --upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu else - uv_pip_install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION + uv_pip_install --upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - uv_pip_install torch torchvision --index-url https://download.pytorch.org/whl/cpu + uv_pip_install --upgrade torch torchvision --index-url https://download.pytorch.org/whl/cpu else - uv_pip_install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION + uv_pip_install --upgrade torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION fi else printf "Failed to install pytorch" diff --git a/.github/unittest/linux/scripts/run_setup_test.sh b/.github/unittest/linux/scripts/run_setup_test.sh index e95ed547a5a..a243b3a8d8c 100644 --- a/.github/unittest/linux/scripts/run_setup_test.sh +++ b/.github/unittest/linux/scripts/run_setup_test.sh @@ -66,9 +66,9 @@ else fi if [[ "$TORCH_VERSION" == "nightly" ]]; then - uv_pip_install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu + uv_pip_install --upgrade --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu else - uv_pip_install torch torchvision --index-url https://download.pytorch.org/whl/cpu + uv_pip_install --upgrade torch torchvision --index-url https://download.pytorch.org/whl/cpu fi # tensordict is a hard dependency of torchrl; install it explicitly since we test diff --git a/.github/unittest/linux_libs/scripts_minari/install.sh b/.github/unittest/linux_libs/scripts_minari/install.sh index d9bbee6276f..8febf63dcdc 100755 --- a/.github/unittest/linux_libs/scripts_minari/install.sh +++ b/.github/unittest/linux_libs/scripts_minari/install.sh @@ -33,9 +33,9 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - uv pip install torch --index-url https://download.pytorch.org/whl/cpu + uv pip install torch --index-url https://download.pytorch.org/whl/cpu -U else - uv pip install torch --index-url https://download.pytorch.org/whl/cu128 + uv pip install torch --index-url https://download.pytorch.org/whl/cu128 -U fi else printf "Failed to install pytorch" diff --git a/.github/unittest/linux_libs/scripts_vd4rl/install.sh b/.github/unittest/linux_libs/scripts_vd4rl/install.sh index 293cf07b8bf..eb1e1c60ba1 100755 --- a/.github/unittest/linux_libs/scripts_vd4rl/install.sh +++ b/.github/unittest/linux_libs/scripts_vd4rl/install.sh @@ -37,7 +37,7 @@ elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U else - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128 + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128 -U fi else printf "Failed to install pytorch" diff --git a/.github/unittest/linux_libs/scripts_vmas/install.sh b/.github/unittest/linux_libs/scripts_vmas/install.sh index 8a71f05120c..1b717a16e41 100755 --- a/.github/unittest/linux_libs/scripts_vmas/install.sh +++ b/.github/unittest/linux_libs/scripts_vmas/install.sh @@ -34,9 +34,9 @@ if [[ "$TORCH_VERSION" == "nightly" ]]; then fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch --index-url https://download.pytorch.org/whl/cpu + pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U else - pip3 install torch --index-url https://download.pytorch.org/whl/cu128 + pip3 install torch --index-url https://download.pytorch.org/whl/cu128 -U fi else printf "Failed to install pytorch" diff --git a/.github/unittest/linux_optdeps/scripts/run_all.sh b/.github/unittest/linux_optdeps/scripts/run_all.sh index 108f52a8527..eaeb4330a9c 100755 --- a/.github/unittest/linux_optdeps/scripts/run_all.sh +++ b/.github/unittest/linux_optdeps/scripts/run_all.sh @@ -23,7 +23,7 @@ if [[ $OSTYPE != 'darwin'* ]]; then apt-get install -y vim git wget cmake apt-get install -y libglfw3 libosmesa6 libglew-dev - apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2 + apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2 xvfb ffmpeg if [ "${CU_VERSION:-}" == cpu ] ; then # solves version `GLIBCXX_3.4.29' not found for tensorboard diff --git a/.github/unittest/linux_sota/scripts/run_all.sh b/.github/unittest/linux_sota/scripts/run_all.sh index c899dc5f693..20f7bcc4d27 100755 --- a/.github/unittest/linux_sota/scripts/run_all.sh +++ b/.github/unittest/linux_sota/scripts/run_all.sh @@ -115,15 +115,15 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - uv pip install --pre torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/nightly/cpu + uv pip install --upgrade --pre torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/nightly/cpu else - uv pip install --pre torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION + uv pip install --upgrade --pre torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - uv pip install torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/cpu + uv pip install --upgrade torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/cpu else - uv pip install torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/$CU_VERSION + uv pip install --upgrade torch torchvision "numpy==1.26.4" --index-url https://download.pytorch.org/whl/$CU_VERSION fi else printf "Failed to install pytorch" diff --git a/.github/unittest/windows_optdepts/scripts/unittest.sh b/.github/unittest/windows_optdepts/scripts/unittest.sh index 7241d7a53d2..4d4a6f5726f 100755 --- a/.github/unittest/windows_optdepts/scripts/unittest.sh +++ b/.github/unittest/windows_optdepts/scripts/unittest.sh @@ -71,15 +71,15 @@ python -m pip install "numpy<2.0" printf "Installing PyTorch with %s\n" "${cudatoolkit}" if [[ "$TORCH_VERSION" == "nightly" ]]; then if $torch_cuda ; then - python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 + python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 -U else python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U fi elif [[ "$TORCH_VERSION" == "stable" ]]; then if $torch_cuda ; then - python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 + python -m pip install torch --index-url https://download.pytorch.org/whl/cu118 -U else - python -m pip install torch --index-url https://download.pytorch.org/whl/cpu + python -m pip install torch --index-url https://download.pytorch.org/whl/cpu -U fi else printf "Failed to install pytorch" diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 5efa0592068..19deb33546f 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -22,7 +22,6 @@ from torch import nn, vmap from torchrl._utils import implement_for, logger, RL_WARNINGS, seed_generator -from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import MultiThreadedEnv, ObservationNorm from torchrl.envs.batched_envs import ParallelEnv, SerialEnv from torchrl.envs.libs.envpool import _has_envpool @@ -594,14 +593,6 @@ def check_rollout_consistency_multikey_env(td: TensorDict, max_steps: int): assert (td["next", "nested_2", "reward"][~action_is_count] == 0).all() -def decorate_thread_sub_func(func, num_threads): - def new_func(*args, **kwargs): - assert torch.get_num_threads() == num_threads - return func(*args, **kwargs) - - return CloudpickleWrapper(new_func) - - class LSTMNet(nn.Module): """An embedder for an LSTM preceded by an MLP. diff --git a/test/test_collectors.py b/test/test_collectors.py index 1c52fb37615..7ef4d98e0a1 100644 --- a/test/test_collectors.py +++ b/test/test_collectors.py @@ -89,6 +89,7 @@ SafeModule, ) from torchrl.testing.modules import BiasModule, NonSerializableBiasModule +from torchrl.testing.mp_helpers import decorate_thread_sub_func from torchrl.weight_update import ( MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme, @@ -99,7 +100,6 @@ from pytorch.rl.test._utils_internal import ( CARTPOLE_VERSIONED, check_rollout_consistency_multikey_env, - decorate_thread_sub_func, generate_seeds, get_available_devices, get_default_devices, @@ -112,7 +112,6 @@ from _utils_internal import ( CARTPOLE_VERSIONED, check_rollout_consistency_multikey_env, - decorate_thread_sub_func, generate_seeds, get_available_devices, get_default_devices, @@ -1843,54 +1842,55 @@ def test_multi_sync_data_collector_ordering( use_buffers=True, ) - # Collect one batch - for batch in collector: - # Verify that each environment's observations match its env_id - # batch has shape [num_envs, frames_per_env] - # In the pre-emption case, we have that envs with odd ids are order of magnitude slower. - # These should be skipped by pre-emption (since they are the 50% slowest) - - # Recover rectangular shape of batch to uniform checks - if cat_results != "stack": - if not with_preempt: - batch = batch.reshape(num_envs, n_steps) - else: - traj_ids = batch["collector", "traj_ids"] - traj_ids[traj_ids == 0] = 99 # avoid using traj_ids = 0 - # Split trajectories to recover correct shape - # thanks to having a single trajectory per env - # Pads with zeros! - batch = split_trajectories( - batch, trajectory_key=("collector", "traj_ids") + try: + # Collect one batch + for batch in collector: + # Verify that each environment's observations match its env_id + # batch has shape [num_envs, frames_per_env] + # In the pre-emption case, we have that envs with odd ids are order of magnitude slower. + # These should be skipped by pre-emption (since they are the 50% slowest) + + # Recover rectangular shape of batch to uniform checks + if cat_results != "stack": + if not with_preempt: + batch = batch.reshape(num_envs, n_steps) + else: + traj_ids = batch["collector", "traj_ids"] + traj_ids[traj_ids == 0] = 99 # avoid using traj_ids = 0 + # Split trajectories to recover correct shape + # thanks to having a single trajectory per env + # Pads with zeros! + batch = split_trajectories( + batch, trajectory_key=("collector", "traj_ids") + ) + # Use -1 for padding to uniform with other preemption + is_padded = batch["collector", "traj_ids"] == 0 + batch[is_padded] = -1 + + # + for env_idx in range(num_envs): + if with_preempt and env_idx % 2 == 1: + # This is a slow env, should have been preempted after first step + assert (batch["collector", "traj_ids"][env_idx, 1:] == -1).all() + continue + # This is a fast env, no preemption happened + assert (batch["collector", "traj_ids"][env_idx] != -1).all() + + env_data = batch[env_idx] + observations = env_data["observation"] + # All observations from this environment should equal its env_id + expected_id = float(env_idx) + actual_ids = observations.flatten().unique() + + assert len(actual_ids) == 1, ( + f"Env {env_idx} should only produce observations with value {expected_id}, " + f"but got {actual_ids.tolist()}" ) - # Use -1 for padding to uniform with other preemption - is_padded = batch["collector", "traj_ids"] == 0 - batch[is_padded] = -1 - - # - for env_idx in range(num_envs): - if with_preempt and env_idx % 2 == 1: - # This is a slow env, should have been preempted after first step - assert (batch["collector", "traj_ids"][env_idx, 1:] == -1).all() - continue - # This is a fast env, no preemption happened - assert (batch["collector", "traj_ids"][env_idx] != -1).all() - - env_data = batch[env_idx] - observations = env_data["observation"] - # All observations from this environment should equal its env_id - expected_id = float(env_idx) - actual_ids = observations.flatten().unique() - - assert len(actual_ids) == 1, ( - f"Env {env_idx} should only produce observations with value {expected_id}, " - f"but got {actual_ids.tolist()}" - ) - assert ( - actual_ids[0].item() == expected_id - ), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}" - - collector.shutdown() + assert ( + actual_ids[0].item() == expected_id + ), f"Environment {env_idx} should produce observation {expected_id}, but got {actual_ids[0].item()}" + finally: + collector.shutdown() class TestCollectorDevices: diff --git a/test/test_envs.py b/test/test_envs.py index 7c9a619e7d5..92b3f67d11f 100644 --- a/test/test_envs.py +++ b/test/test_envs.py @@ -179,7 +179,6 @@ def check_no_lingering_multiprocessing_resources(request): _make_envs, CARTPOLE_VERSIONED, check_rollout_consistency_multikey_env, - decorate_thread_sub_func, get_default_devices, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, @@ -191,7 +190,6 @@ def check_no_lingering_multiprocessing_resources(request): _make_envs, CARTPOLE_VERSIONED, check_rollout_consistency_multikey_env, - decorate_thread_sub_func, get_default_devices, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, @@ -3405,25 +3403,23 @@ class TestLibThreading: ) def test_num_threads(self): gc.collect() - from torchrl.envs import batched_envs - - _run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem - batched_envs._run_worker_pipe_shared_mem = decorate_thread_sub_func( - batched_envs._run_worker_pipe_shared_mem, num_threads=3 - ) num_threads = torch.get_num_threads() try: - env = ParallelEnv( - 2, ContinuousActionVecMockEnv, num_sub_threads=3, num_threads=7 - ) + # Wrap the env factory to check thread count inside the subprocess. + # The env is created AFTER torch.set_num_threads() is called in the worker. + def make_env(): + assert ( + torch.get_num_threads() == 3 + ), f"Expected 3 threads, got {torch.get_num_threads()}" + return ContinuousActionVecMockEnv() + + env = ParallelEnv(2, make_env, num_sub_threads=3, num_threads=7) # We could test that the number of threads isn't changed until we start the procs. # Even though it's unlikely that we have 7 threads, we still disable this for safety # assert torch.get_num_threads() != 7 env.rollout(3) assert torch.get_num_threads() == 7 finally: - # reset vals - batched_envs._run_worker_pipe_shared_mem = _run_worker_pipe_shared_mem_save torch.set_num_threads(num_threads) @pytest.mark.skipif( diff --git a/test/test_objectives.py b/test/test_objectives.py index b957216bf26..d1b354b4c0b 100644 --- a/test/test_objectives.py +++ b/test/test_objectives.py @@ -200,6 +200,8 @@ pytest.mark.filterwarnings( "ignore:The PyTorch API of nested tensors is in prototype" ), + pytest.mark.filterwarnings("ignore:unclosed event loop:ResourceWarning"), + pytest.mark.filterwarnings("ignore:unclosed.*socket:ResourceWarning"), ] diff --git a/test/test_transforms.py b/test/test_transforms.py index 0769e8be8eb..32030687ea3 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -10074,6 +10074,7 @@ def _test_vecnorm_subproc_auto( def rename_t(self): return RenameTransform(in_keys=["observation"], out_keys=[("some", "obs")]) + @retry(AssertionError, tries=10, delay=0) @pytest.mark.parametrize("nprc", [2, 5]) def test_vecnorm_parallel_auto(self, nprc): queues = [] diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index 8d18bafa8bf..dbb9043f67a 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -46,6 +46,20 @@ from torchrl.weight_update.utils import _resolve_model +def _cuda_sync_if_initialized(): + """Synchronize CUDA only if it has been initialized. + + This is a safe alternative to calling `torch.cuda.synchronize()` directly. + In forked subprocesses on machines with CUDA, calling `synchronize()` will + fail with "Cannot re-initialize CUDA in forked subprocess" if CUDA was + initialized in the parent process before fork. By checking + `is_initialized()` first, we skip the sync in such cases since no CUDA + operations have occurred in this process. + """ + if torch.cuda.is_initialized(): + torch.cuda.synchronize() + + @accept_remote_rref_udf_invocation class Collector(BaseCollector): """Generic data collector for RL problems. Requires an environment constructor and a policy. @@ -518,9 +532,14 @@ def _setup_devices( def _get_sync_fn(self, device: torch.device | None) -> Callable: """Get the appropriate synchronization function for a device.""" if device is not None and device.type != "cuda": - # Cuda handles sync + # When destination is not CUDA, we may need to sync to wait for + # async GPU→CPU transfers to complete before proceeding. if torch.cuda.is_available(): - return torch.cuda.synchronize + # Return a safe wrapper that only syncs if CUDA was actually + # initialized. This avoids "Cannot re-initialize CUDA in forked + # subprocess" errors when using fork start method on GPU machines + # with CPU-only collectors. + return _cuda_sync_if_initialized elif torch.backends.mps.is_available() and hasattr(torch, "mps"): return torch.mps.synchronize elif hasattr(torch, "npu") and torch.npu.is_available(): diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 8b1b3cd7cd0..3524ed2e15c 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1478,7 +1478,6 @@ def look_for_cuda(tensor, has_cuda=has_cuda): env_fun = self.create_env_fn[idx] if not isinstance(env_fun, (EnvCreator, CloudpickleWrapper)): env_fun = CloudpickleWrapper(env_fun) - import torchrl kwargs[idx].update( { @@ -1489,7 +1488,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): "has_lazy_inputs": self.has_lazy_inputs, "num_threads": num_sub_threads, "non_blocking": self.non_blocking, - "filter_warnings": torchrl.filter_warnings_subprocess, + "filter_warnings": self._filter_warnings_subprocess, } ) if self._use_buffers: @@ -1525,6 +1524,11 @@ def look_for_cuda(tensor, has_cuda=has_cuda): self.is_closed = False self.set_spec_lock_() + def _filter_warnings_subprocess(self) -> bool: + from torchrl import filter_warnings_subprocess + + return filter_warnings_subprocess + @_check_start def state_dict(self) -> OrderedDict: state_dict = OrderedDict() diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index cfe2d36101c..caa2d759a8b 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -8,6 +8,7 @@ from collections import OrderedDict from collections.abc import Callable from multiprocessing.sharedctypes import Synchronized +from multiprocessing.synchronize import Lock, RLock import torch from tensordict import TensorDictBase @@ -141,8 +142,12 @@ def meta_data(self, value: EnvMetaData): @staticmethod def _is_mp_value(val): - - return isinstance(val, (Synchronized,)) and hasattr(val, "_obj") + if isinstance(val, (Synchronized,)) and hasattr(val, "_obj"): + return True + # Also check for lock types which need to be shared across processes + if isinstance(val, (Lock, RLock)): + return True + return False @classmethod def _find_mp_values(cls, env_or_transform, values, prefix=()): diff --git a/torchrl/testing/mp_helpers.py b/torchrl/testing/mp_helpers.py new file mode 100644 index 00000000000..5a3ef90233c --- /dev/null +++ b/torchrl/testing/mp_helpers.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +import torch + +from torchrl.data.utils import CloudpickleWrapper + + +def decorate_thread_sub_func(func, num_threads): + """Decorate a function to assert that the number of threads is correct.""" + def new_func(*args, **kwargs): + assert torch.get_num_threads() == num_threads + return func(*args, **kwargs) + + return CloudpickleWrapper(new_func)