diff --git a/.github/unittest/linux/scripts/environment.yml b/.github/unittest/linux/scripts/environment.yml deleted file mode 100644 index 3283867e9bc..00000000000 --- a/.github/unittest/linux/scripts/environment.yml +++ /dev/null @@ -1,37 +0,0 @@ -channels: - - pytorch - - defaults -dependencies: - - pip - - protobuf - - pip: - - hypothesis - - future - - cloudpickle - - pygame - - moviepy<2.0.0 - - tqdm - - pytest - - pytest-cov - - pytest-mock - - pytest-instafail - - pytest-rerunfailures - - pytest-timeout - - pytest-asyncio - - expecttest - - pybind11[global] - - pyyaml - - scipy - - hydra-core - - tensorboard - - imageio==2.26.0 - - wandb - - dm_control - - mujoco<3.3.6 - - mlflow - - av - - coverage - - ray - - transformers - - ninja - - timm diff --git a/.github/unittest/linux/scripts/post_process.sh b/.github/unittest/linux/scripts/post_process.sh index e97bf2a7b1b..df82332cd84 100755 --- a/.github/unittest/linux/scripts/post_process.sh +++ b/.github/unittest/linux/scripts/post_process.sh @@ -1,6 +1,3 @@ #!/usr/bin/env bash set -e - -eval "$(./conda/bin/conda shell.bash hook)" -conda activate ./env diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index 94bb8a98e09..056b0adcfad 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -6,30 +6,33 @@ set -v # =============================================================================== # # ================================ Init ========================================= # - if [[ $OSTYPE != 'darwin'* ]]; then - apt-get update && apt-get upgrade -y - apt-get install -y vim git wget cmake + # Prevent interactive prompts (notably tzdata) in CI. + export DEBIAN_FRONTEND=noninteractive + export TZ="${TZ:-Etc/UTC}" + ln -snf "/usr/share/zoneinfo/${TZ}" /etc/localtime || true + echo "${TZ}" > /etc/timezone || true + + apt-get update + apt-get install -y --no-install-recommends tzdata + dpkg-reconfigure -f noninteractive tzdata || true - # Enable universe repository - # apt-get install -y software-properties-common - # add-apt-repository universe - # apt-get update + apt-get upgrade -y + apt-get install -y vim git wget cmake curl python3-dev - # apt-get install -y libsdl2-dev libsdl2-2.0-0 + # SDL2 and freetype needed for building pygame from source (Python 3.14+) + apt-get install -y libsdl2-dev libsdl2-2.0-0 libsdl2-mixer-dev libsdl2-image-dev libsdl2-ttf-dev + apt-get install -y libfreetype6-dev pkg-config - apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev - apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 xvfb + apt-get install -y libglfw3 libosmesa6 libglew-dev + apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2 xvfb if [ "${CU_VERSION:-}" == cpu ] ; then - # solves version `GLIBCXX_3.4.29' not found for tensorboard -# apt-get install -y gcc-4.9 apt-get upgrade -y libstdc++6 apt-get dist-upgrade -y else apt-get install -y g++ gcc fi - fi this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" @@ -45,38 +48,32 @@ fi # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' root_dir="$(git rev-parse --show-toplevel)" -conda_dir="${root_dir}/conda" -env_dir="${root_dir}/env" -lib_dir="${env_dir}/lib" +env_dir="${root_dir}/venv" cd "${root_dir}" -case "$(uname -s)" in - Darwin*) os=MacOSX;; - *) os=Linux -esac - -# 1. Install conda at ./conda -if [ ! 
-d "${conda_dir}" ]; then - printf "* Installing conda\n" - wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" - bash ./miniconda.sh -b -f -p "${conda_dir}" -fi -eval "$(${conda_dir}/bin/conda shell.bash hook)" - -# 2. Create test environment at ./env -printf "python: ${PYTHON_VERSION}\n" -if [ ! -d "${env_dir}" ]; then - printf "* Creating a test environment\n" - conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" -fi -conda activate "${env_dir}" - -# 3. Install Conda dependencies -printf "* Installing dependencies (except PyTorch)\n" -echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" -cat "${this_dir}/environment.yml" - +# Install uv +curl -LsSf https://astral.sh/uv/install.sh | sh +export PATH="$HOME/.local/bin:$PATH" + +# Create venv with uv +printf "* Creating venv with Python ${PYTHON_VERSION}\n" +# IMPORTANT: ensure a clean environment. +# In CI (and some local workflows), the workspace directory can be reused across runs. +# A reused venv may contain packages that violate our constraints (e.g. transformers' +# huggingface-hub upper bound), and `uv pip install` does not always guarantee +# downgrades of already-present packages unless the environment is clean. +rm -rf "${env_dir}" +uv venv --python "${PYTHON_VERSION}" "${env_dir}" +source "${env_dir}/bin/activate" +uv_pip_install() { + uv pip install --no-progress --python "${env_dir}/bin/python" "$@" +} + +# Verify CPython +python -c "import sys; assert sys.implementation.name == 'cpython', f'Expected CPython, got {sys.implementation.name}'" + +# Set environment variables if [ "${CU_VERSION:-}" == cpu ] ; then export MUJOCO_GL=glfw else @@ -84,37 +81,99 @@ else fi export SDL_VIDEODRIVER=dummy +export PYOPENGL_PLATFORM=$MUJOCO_GL +export DISPLAY=:99 +export LAZY_LEGACY_OP=False +export RL_LOGGING_LEVEL=INFO +export TOKENIZERS_PARALLELISM=true +export MAX_IDLE_COUNT=1000 +export MKL_THREADING_LAYER=GNU +export CKPT_BACKEND=torch +export BATCHED_PIPE_TIMEOUT=60 -# legacy from bash scripts: remove? 
-conda env config vars set \ - MAX_IDLE_COUNT=1000 \ - MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:99 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=INFO TOKENIZERS_PARALLELISM=true - -pip3 install pip --upgrade -pip install virtualenv +# ==================================================================================== # +# ================================ Install dependencies ============================== # + +printf "* Installing dependencies\n" + +# Install base dependencies +uv_pip_install \ + hypothesis \ + future \ + cloudpickle \ + packaging \ + pygame \ + "moviepy<2.0.0" \ + tqdm \ + pytest \ + pytest-cov \ + pytest-mock \ + pytest-instafail \ + pytest-rerunfailures \ + pytest-timeout \ + pytest-forked \ + pytest-asyncio \ + expecttest \ + "pybind11[global]>=2.13" \ + pyyaml \ + scipy \ + hydra-core \ + tensorboard \ + "imageio==2.26.0" \ + "huggingface-hub>=0.34.0,<1.0" \ + wandb \ + mlflow \ + av \ + coverage \ + transformers \ + ninja \ + timm + +# Install dm_control for Python < 3.13 +# labmaze (dm_control dependency) doesn't have Python 3.13+ wheels +if [[ "$PYTHON_VERSION" != "3.13" && "$PYTHON_VERSION" != "3.14" ]]; then + echo "installing dm_control" + uv_pip_install dm_control +fi -conda env update --file "${this_dir}/environment.yml" --prune +# Install ray for Python < 3.14 (ray doesn't support Python 3.14 yet) +if [[ "$PYTHON_VERSION" != "3.14" ]]; then + echo "installing ray" + uv_pip_install ray +fi -# Reset conda env variables -conda deactivate -conda activate "${env_dir}" +# Install mujoco for Python < 3.14 (mujoco doesn't have Python 3.14 wheels yet) +if [[ "$PYTHON_VERSION" != "3.14" ]]; then + echo "installing mujoco" + uv_pip_install "mujoco>=3.3.7" +fi +# Install gymnasium echo "installing gymnasium" -if [[ "$PYTHON_VERSION" == "3.12" ]]; then - pip3 install ale-py - pip3 install sympy - pip3 install "gymnasium[mujoco]>=1.1" mo-gymnasium[mujoco] +if [[ "$PYTHON_VERSION" == "3.14" ]]; then + # Python 3.14: no mujoco wheels available, ale_py also failing + uv_pip_install "gymnasium>=1.1" +elif [[ "$PYTHON_VERSION" == "3.12" ]]; then + uv_pip_install ale-py sympy + uv_pip_install "gymnasium[mujoco]>=1.1" "mo-gymnasium[mujoco]" else - pip3 install "gymnasium[atari,mujoco]>=1.1" mo-gymnasium[mujoco] + uv_pip_install "gymnasium[atari,mujoco]>=1.1" "mo-gymnasium[mujoco]" fi -# sanity check: remove? 
-python -c """ +# sanity check +if [[ "$PYTHON_VERSION" != "3.13" && "$PYTHON_VERSION" != "3.14" ]]; then + python -c " import dm_control from dm_control import composer from tensorboard import * from google.protobuf import descriptor as _descriptor -""" +" +else + python -c " +from tensorboard import * +from google.protobuf import descriptor as _descriptor +" +fi # ============================================================================================ # # ================================ PyTorch & TorchRL ========================================= # @@ -122,7 +181,6 @@ from google.protobuf import descriptor as _descriptor unset PYTORCH_VERSION if [ "${CU_VERSION:-}" == cpu ] ; then - version="cpu" echo "Using cpu build" else if [[ ${#CU_VERSION} -eq 4 ]]; then @@ -131,7 +189,6 @@ else CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" fi echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" - version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" fi # submodules @@ -140,15 +197,15 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [[ "$TORCH_VERSION" == "nightly" ]]; then if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U + uv_pip_install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu else - pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION -U + uv_pip_install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION fi elif [[ "$TORCH_VERSION" == "stable" ]]; then - if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U + if [ "${CU_VERSION:-}" == cpu ] ; then + uv_pip_install torch torchvision --index-url https://download.pytorch.org/whl/cpu else - pip3 install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION -U + uv_pip_install torch torchvision --index-url https://download.pytorch.org/whl/$CU_VERSION fi else printf "Failed to install pytorch" @@ -158,73 +215,84 @@ fi # smoke test python -c "import functorch" -## install snapshot -#if [[ "$TORCH_VERSION" == "nightly" ]]; then -# pip3 install git+https://github.com/pytorch/torchsnapshot -#else -# pip3 install torchsnapshot -#fi +# Help CMake find pybind11 when building tensordict from source. +# pybind11 ships a CMake package; its location can be obtained via `python -m pybind11 --cmakedir`. +pybind11_DIR="$(python -m pybind11 --cmakedir)" +export pybind11_DIR # install tensordict if [[ "$RELEASE" == 0 ]]; then - pip3 install git+https://github.com/pytorch/tensordict.git + uv_pip_install --no-build-isolation git+https://github.com/pytorch/tensordict.git else - pip3 install tensordict + uv_pip_install tensordict fi printf "* Installing torchrl\n" -python -m pip install -e . --no-build-isolation - +if [[ "$RELEASE" == 0 ]]; then + uv_pip_install -e . --no-build-isolation --no-deps +else + uv_pip_install -e . --no-build-isolation +fi if [ "${CU_VERSION:-}" != cpu ] ; then printf "* Installing VC1\n" - python -c """ -from torchrl.envs.transforms.vc1 import VC1Transform -VC1Transform.install_vc_models(auto_exit=True) -""" + # Install vc_models directly via uv. 
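+    # (pip-installing the vc_models subdirectory of eai-vc is effectively what install_vc_models() would set up.)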
+ # VC1Transform.install_vc_models() uses `setup.py develop` which expects `pip` + # to be present in the environment, but uv-created venvs do not necessarily + # ship with pip. + uv_pip_install "git+https://github.com/facebookresearch/eai-vc.git#subdirectory=vc_models" printf "* Upgrading timm\n" - pip3 install --upgrade "timm>=0.9.0" + # Keep HF Hub constrained: timm can pull a hub>=1.x which breaks transformers' + # import-time version check. + uv_pip_install --upgrade "timm>=0.9.0" "huggingface-hub>=0.34.0,<1.0" - python -c """ + python -c " import vc_models from vc_models.models.vit import model_utils print(model_utils) -""" +" fi # ==================================================================================== # # ================================ Run tests ========================================= # - export PYTORCH_TEST_WITH_SLOW='1' python -m torch.utils.collect_env -## Avoid error: "fatal: unsafe repository" -#git config --global --add safe.directory '*' -#root_dir="$(git rev-parse --show-toplevel)" - -# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found -#export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir -export MKL_THREADING_LAYER=GNU -export CKPT_BACKEND=torch -export MAX_IDLE_COUNT=100 -export BATCHED_PIPE_TIMEOUT=60 Xvfb :99 -screen 0 1024x768x24 & pytest test/smoke_test.py -v --durations 200 pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_control_pixels or test_dm_control or test_tb' + +# Track if any tests fail +EXIT_STATUS=0 + +# Run distributed tests first (GPU only) to surface errors early +if [ "${CU_VERSION:-}" != cpu ] ; then + python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_distributed.py \ + --instafail --durations 200 -vv --capture no \ + --timeout=120 --mp_fork_if_no_cuda || EXIT_STATUS=$? +fi + +# Run remaining tests (always run even if distributed tests failed) if [ "${CU_VERSION:-}" != cpu ] ; then python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \ --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \ + --ignore test/test_distributed.py \ --ignore test/llm \ - --timeout=120 --mp_fork_if_no_cuda + --timeout=120 --mp_fork_if_no_cuda || EXIT_STATUS=$? else python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \ --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \ --ignore test/test_distributed.py \ --ignore test/llm \ - --timeout=120 --mp_fork_if_no_cuda + --timeout=120 --mp_fork_if_no_cuda || EXIT_STATUS=$? +fi + +# Fail the workflow if any tests failed +if [ $EXIT_STATUS -ne 0 ]; then + echo "Some tests failed with exit status $EXIT_STATUS" fi coverage combine @@ -234,3 +302,6 @@ coverage xml -i # ================================ Post-proc ========================================= # bash ${this_dir}/post_process.sh + +# Exit with failure if any tests failed +exit $EXIT_STATUS diff --git a/.github/unittest/linux_optdeps/scripts/run_all.sh b/.github/unittest/linux_optdeps/scripts/run_all.sh index c73efc50834..108f52a8527 100755 --- a/.github/unittest/linux_optdeps/scripts/run_all.sh +++ b/.github/unittest/linux_optdeps/scripts/run_all.sh @@ -9,11 +9,21 @@ set -e if [[ $OSTYPE != 'darwin'* ]]; then - apt-get update && apt-get upgrade -y + # Prevent interactive prompts (notably tzdata) in CI. 
+ export DEBIAN_FRONTEND=noninteractive + export TZ="${TZ:-Etc/UTC}" + ln -snf "/usr/share/zoneinfo/${TZ}" /etc/localtime || true + echo "${TZ}" > /etc/timezone || true + + apt-get update + apt-get install -y --no-install-recommends tzdata + dpkg-reconfigure -f noninteractive tzdata || true + + apt-get upgrade -y apt-get install -y vim git wget cmake - apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev - apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 + apt-get install -y libglfw3 libosmesa6 libglew-dev + apt-get install -y libglvnd0 libgl1 libglx0 libglx-mesa0 libegl1 libgles2 if [ "${CU_VERSION:-}" == cpu ] ; then # solves version `GLIBCXX_3.4.29' not found for tensorboard diff --git a/.github/unittest/linux_sota/scripts/run_all.sh b/.github/unittest/linux_sota/scripts/run_all.sh index cf1a9be33a8..a81c562a446 100755 --- a/.github/unittest/linux_sota/scripts/run_all.sh +++ b/.github/unittest/linux_sota/scripts/run_all.sh @@ -7,7 +7,16 @@ set -v # ================================ Init ============================================== # -apt-get update && apt-get upgrade -y +export DEBIAN_FRONTEND=noninteractive +export TZ="${TZ:-Etc/UTC}" +ln -snf "/usr/share/zoneinfo/${TZ}" /etc/localtime || true +echo "${TZ}" > /etc/timezone || true + +apt-get update +apt-get install -y --no-install-recommends tzdata +dpkg-reconfigure -f noninteractive tzdata || true + +apt-get upgrade -y apt-get install -y vim git wget cmake apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libosmesa6-dev diff --git a/.github/workflows/nightly_build.yml b/.github/workflows/nightly_build.yml index 5e6c6a3bb91..8069b6a689d 100644 --- a/.github/workflows/nightly_build.yml +++ b/.github/workflows/nightly_build.yml @@ -42,11 +42,11 @@ jobs: matrix: os: [['linux', 'ubuntu-22.04'], ['macos', 'macos-latest']] python_version: [ - ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"], ["3.13", "cp313-cp313"], + ["3.14", "cp314-cp314"], ] cuda_support: [["", "cpu", "cpu"]] steps: @@ -88,11 +88,11 @@ jobs: matrix: os: [['linux', 'ubuntu-22.04'], ['macos', 'macos-latest']] python_version: [ - ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"], ["3.13", "cp313-cp313"], + ["3.14", "cp314-cp314"], ] cuda_support: [["", "cpu", "cpu"]] steps: @@ -162,11 +162,11 @@ jobs: matrix: os: [['linux', 'ubuntu-22.04'], ['macos', 'macos-latest']] python_version: [ - ["3.9", "cp39-cp39"], ["3.10", "cp310-cp310"], ["3.11", "cp311-cp311"], ["3.12", "cp312-cp312"], ["3.13", "cp313-cp313"], + ["3.14", "cp314-cp314"], ] cuda_support: [["", "cpu", "cpu"]] steps: @@ -204,11 +204,11 @@ jobs: strategy: matrix: python_version: [ - ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"], ["3.13", "3.13"], + ["3.14", "3.14"], ] steps: - name: Setup Python @@ -244,11 +244,11 @@ jobs: strategy: matrix: python_version: [ - ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"], ["3.13", "3.13"], + ["3.14", "3.14"], ] steps: - name: Setup Python @@ -314,11 +314,11 @@ jobs: strategy: matrix: python_version: [ - ["3.9", "3.9"], ["3.10", "3.10.3"], ["3.11", "3.11"], ["3.12", "3.12"], ["3.13", "3.13"], + ["3.14", "3.14"], ] steps: - name: Checkout torchrl diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml index e660f29ba78..9bfec8b4455 100644 --- a/.github/workflows/test-linux.yml +++ b/.github/workflows/test-linux.yml @@ -26,14 +26,14 @@ jobs: tests-cpu: strategy: matrix: - 
python_version: ["3.9", "3.10", "3.11", "3.12"] + python_version: ["3.10", "3.11", "3.12", "3.13", "3.14"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: runner: linux.12xlarge repository: pytorch/rl - docker-image: "nvidia/cuda:12.2.0-devel-ubuntu22.04" - timeout: 90 + docker-image: "nvidia/cuda:13.0.2-cudnn-devel-ubuntu24.04" + timeout: 120 script: | if [[ "${{ github.ref }}" =~ release/* ]]; then export RELEASE=1 @@ -56,17 +56,17 @@ jobs: tests-gpu: strategy: matrix: - python_version: ["3.11"] - cuda_arch_version: ["12.8"] + python_version: ["3.12"] + cuda_arch_version: ["13.0"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: runner: linux.g5.4xlarge.nvidia.gpu repository: pytorch/rl - docker-image: "nvidia/cuda:12.4.0-devel-ubuntu22.04" + docker-image: "nvidia/cuda:13.0.2-cudnn-devel-ubuntu24.04" gpu-arch-type: cuda gpu-arch-version: ${{ matrix.cuda_arch_version }} - timeout: 90 + timeout: 120 script: | # Set env vars from matrix export PYTHON_VERSION=${{ matrix.python_version }} @@ -104,7 +104,7 @@ jobs: docker-image: "nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04" gpu-arch-type: cuda gpu-arch-version: ${{ matrix.cuda_arch_version }} - timeout: 90 + timeout: 120 script: | set -euo pipefail export PYTHON_VERSION="3.9" @@ -128,17 +128,17 @@ jobs: tests-optdeps: strategy: matrix: - python_version: ["3.11"] - cuda_arch_version: ["12.8"] + python_version: ["3.12"] + cuda_arch_version: ["13.0"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: runner: linux.g5.4xlarge.nvidia.gpu repository: pytorch/rl - docker-image: "nvidia/cuda:12.4.0-devel-ubuntu22.04" + docker-image: "nvidia/cuda:13.0.2-cudnn-devel-ubuntu24.04" gpu-arch-type: cuda gpu-arch-version: ${{ matrix.cuda_arch_version }} - timeout: 90 + timeout: 120 script: | # Set env vars from matrix export PYTHON_VERSION=${{ matrix.python_version }} @@ -163,17 +163,17 @@ jobs: tests-stable-gpu: strategy: matrix: - python_version: ["3.10"] # "3.9", "3.10", "3.11" - cuda_arch_version: ["11.8"] # "11.6", "11.7" + python_version: ["3.12"] # "3.9", "3.10", "3.11" + cuda_arch_version: ["13.0"] # "11.6", "11.7" fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: runner: linux.g5.4xlarge.nvidia.gpu repository: pytorch/rl - docker-image: "nvidia/cuda:12.4.0-devel-ubuntu22.04" + docker-image: "nvidia/cuda:13.0.2-cudnn-devel-ubuntu24.04" gpu-arch-type: cuda gpu-arch-version: ${{ matrix.cuda_arch_version }} - timeout: 90 + timeout: 120 script: | # Set env vars from matrix export PYTHON_VERSION=${{ matrix.python_version }} diff --git a/pyproject.toml b/pyproject.toml index 0bc91eb32e5..7d1de1b5745 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,11 +16,11 @@ maintainers = [ ] keywords = ["reinforcement-learning", "pytorch", "rl", "machine-learning"] classifiers = [ - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Operating System :: OS Independent", "Development Status :: 4 - Beta", "Intended Audience :: Developers", @@ -48,6 +48,7 @@ tests = [ "pyyaml", "pytest-instafail", "scipy", + "psutil", "pytest-mock", "pytest-cov", "pytest-asyncio", @@ -55,6 +56,9 @@ tests = [ "pytest-rerunfailures", "pytest-error-for-skips", "pytest-timeout", + "pytest-forked", + "pytest-random-order", + 
"pytest-repeat", ] utils = [ "tensorboard", diff --git a/test/smoke_test_deps.py b/test/smoke_test_deps.py index ade1c68f4af..941107199fe 100644 --- a/test/smoke_test_deps.py +++ b/test/smoke_test_deps.py @@ -6,11 +6,16 @@ import argparse import os +import sys import tempfile import pytest +@pytest.mark.skipif( + sys.version_info >= (3, 13), + reason="dm_control not available on Python 3.13+ (labmaze lacks wheels)", +) def test_dm_control(): import dm_control # noqa: F401 import dm_env # noqa: F401 @@ -23,21 +28,29 @@ def test_dm_control(): env.reset() +@pytest.mark.skipif( + sys.version_info >= (3, 13), + reason="dm_control not available on Python 3.13+ (labmaze lacks wheels)", +) @pytest.mark.skip(reason="Not implemented yet") def test_dm_control_pixels(): - from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv # noqa + from torchrl.envs.libs.dm_control import DMControlEnv env = DMControlEnv("cheetah", "run", from_pixels=True) env.reset() +@pytest.mark.skipif( + sys.version_info >= (3, 14), + reason="gymnasium[atari] / ALE not available on Python 3.14 in CI (ale-py install failing)", +) def test_gym(): try: import gymnasium as gym except ImportError as err: ERROR = err try: - import gym # noqa: F401 + import gym as gym # noqa: F401 except ImportError as err: raise ImportError( f"gym and gymnasium load failed. Gym got error {err}." @@ -46,12 +59,30 @@ def test_gym(): from torchrl.envs.libs.gym import _has_gym, GymEnv # noqa assert _has_gym + # If gymnasium is installed without the atari extra, ALE won't be registered. + # In that case we skip rather than hard-failing the dependency smoke test. + try: + import ale_py # noqa: F401 + except Exception: # pragma: no cover + pytest.skip("ALE not available (missing ale_py); skipping Atari gym test.") if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import PONG_VERSIONED else: from _utils_internal import PONG_VERSIONED - env = GymEnv(PONG_VERSIONED()) + try: + env = GymEnv(PONG_VERSIONED()) + except Exception as err: # gymnasium.error.NamespaceNotFound and similar + namespace_not_found = err.__class__.__name__ == "NamespaceNotFound" + if hasattr(gym, "error") and hasattr(gym.error, "NamespaceNotFound"): + namespace_not_found = namespace_not_found or isinstance( + err, gym.error.NamespaceNotFound + ) + if namespace_not_found: + pytest.skip( + "ALE namespace not registered (gymnasium installed without atari extra)." + ) + raise env.reset() diff --git a/test/test_distributed.py b/test/test_distributed.py index 6f6326eaf4b..c6dc3ffde7e 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -42,16 +42,22 @@ SamplerWithoutReplacement, ) from torchrl.modules import RandomPolicy - -_has_ray = importlib.util.find_spec("ray") is not None +from torchrl.testing.dist_utils import ( + assert_no_new_python_processes, + snapshot_python_processes, +) from torchrl.testing.mocking_classes import ContinuousActionVecMockEnv, CountingEnv +_has_ray = importlib.util.find_spec("ray") is not None + TIMEOUT = 200 if sys.platform.startswith("win"): pytest.skip("skipping windows tests in windows", allow_module_level=True) +pytestmark = [pytest.mark.forked] + class CountingPolicy(TensorDictModuleBase): """A policy for counting env. 
@@ -221,6 +227,81 @@ def test_distributed_collector_sync(self, sync): proc.terminate() queue.close() + @classmethod + def _test_distributed_collector_updatepolicy_shutdown_only(cls, queue, sync): + """Small rollout + weight sync + shutdown (used for leak checks in parent process).""" + collector = None + try: + frames_per_batch = 50 + total_frames = 250 + env = CountingEnv + policy = CountingPolicy() + dcls = cls.distributed_class() + collector = dcls( + [env] * 2, + policy, + collector_class=Collector, + total_frames=total_frames, + frames_per_batch=frames_per_batch, + sync=sync, + **cls.distributed_kwargs(), + ) + first_batch = None + seen_updated = False + total = 0 + for i, data in enumerate(collector): + total += data.numel() + if i == 0: + first_batch = data + policy.weight.data.add_(1) + collector.update_policy_weights_(policy) + else: + if (data["action"] == 2).all(): + seen_updated = True + assert total == total_frames + assert first_batch is not None + assert (first_batch["action"] == 1).all(), first_batch["action"] + assert ( + seen_updated + ), "Updated weights were never observed in collected batches." + queue.put(("passed", None)) + except Exception as e: + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) + finally: + if collector is not None: + collector.shutdown() + + @pytest.mark.parametrize("sync", [False, True]) + def test_collector_shutdown_clears_python_processes(self, sync): + """Regression test: collector.shutdown() should not leak python processes.""" + queue = mp.Queue(1) + # Creating multiprocessing primitives (Queue / SemLock) may spawn Python's + # `multiprocessing.resource_tracker` helper process. That process is not owned + # by the collector and may live for the duration of the test runner, so we + # include it in the baseline. + baseline = snapshot_python_processes() + baseline_time = time.time() + + proc = mp.Process( + target=self._test_distributed_collector_updatepolicy_shutdown_only, + args=(queue, sync), + ) + proc.start() + try: + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] + finally: + proc.join(10) + if proc.is_alive(): + proc.terminate() + queue.close() + + assert_no_new_python_processes( + baseline=baseline, baseline_time=baseline_time, timeout=20.0 + ) + @classmethod def _test_distributed_collector_class(cls, queue, collector_class): try: @@ -318,14 +399,18 @@ def _test_distributed_collector_updatepolicy( if i == 0: first_batch = data if policy is not None: - policy.weight.data += 1 + # Avoid using `.data` (and avoid tracking in autograd). 
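+                    # (`.data.add_(1)` updates the weight in place without the op being tracked by autograd.)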
+ policy.weight.data.add_(1) else: + assert weights is not None weights.data += 1 torchrl_logger.info("TEST -- Calling update_policy_weights_()") collector.update_policy_weights_(weights) torchrl_logger.info("TEST -- Done calling update_policy_weights_()") elif total == total_frames - frames_per_batch: last_batch = data + assert first_batch is not None + assert last_batch is not None assert (first_batch["action"] == 1).all(), first_batch["action"] assert (last_batch["action"] == 2).all(), last_batch["action"] collector.shutdown() @@ -421,6 +506,10 @@ def _start_worker(cls): def test_distributed_collector_sync(self, *args): raise pytest.skip("skipping as only sync is supported") + @pytest.mark.parametrize("sync", [True]) + def test_collector_shutdown_clears_python_processes(self, sync): + super().test_collector_shutdown_clears_python_processes(sync) + @classmethod def _test_distributed_collector_updatepolicy( cls, @@ -456,9 +545,11 @@ def _test_distributed_collector_updatepolicy( assert data.numel() == frames_per_batch if i == 0: first_batch = data - policy.weight.data += 1 + policy.weight.data.add_(1) elif total == total_frames - frames_per_batch: last_batch = data + assert first_batch is not None + assert last_batch is not None assert (first_batch["action"] == 1).all(), first_batch["action"] if update_interval == 1: assert (last_batch["action"] == 2).all(), last_batch["action"] @@ -517,7 +608,16 @@ def start_ray(self): import ray from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG - ray.init(**DEFAULT_RAY_INIT_CONFIG) + # Ensure Ray is initialized with a runtime_env that lets workers import + # this test module (e.g. `CountingPolicy`), otherwise actor unpickling can + # fail with "No module named 'test_distributed'". + ray.shutdown() + ray_init_config = dict(DEFAULT_RAY_INIT_CONFIG) + ray_init_config["runtime_env"] = { + "working_dir": os.path.dirname(__file__), + "env_vars": {"PYTHONPATH": os.path.dirname(__file__)}, + } + ray.init(**ray_init_config) yield ray.shutdown() @@ -538,14 +638,13 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: - import ray - - ray.shutdown() # make sure ray is not running - ray_init_config = DEFAULT_RAY_INIT_CONFIG + # Ray will be auto-initialized by RayCollector if not already started. + # We need to provide runtime_env so workers can import this test module. 
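+        # (Ray uploads working_dir to each worker, and PYTHONPATH makes test_distributed importable there.)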
+ ray_init_config = dict(DEFAULT_RAY_INIT_CONFIG) ray_init_config["runtime_env"] = { "working_dir": os.path.dirname(__file__), "env_vars": {"PYTHONPATH": os.path.dirname(__file__)}, - } # for ray workers + } remote_configs = { "num_cpus": 1, "num_gpus": 0.0, @@ -579,6 +678,67 @@ def test_distributed_collector_sync(self, sync, frames_per_batch=200): finally: collector.shutdown() + @pytest.mark.parametrize("sync", [False, True]) + def test_collector_shutdown_clears_python_processes(self, sync): + """Regression test: collector.shutdown() should not leak python processes (ray).""" + kwargs = self.distributed_kwargs() + baseline = snapshot_python_processes() + baseline_time = time.time() + + frames_per_batch = 50 + total_frames = 250 + env = CountingEnv + policy = CountingPolicy() + collector = self.distributed_class()( + [env] * 2, + policy, + collector_class=Collector, + total_frames=total_frames, + frames_per_batch=frames_per_batch, + sync=sync, + **kwargs, + ) + try: + total = 0 + first_batch = None + seen_updated = False + for i, data in enumerate(collector): + total += data.numel() + if i == 0: + first_batch = data + policy.weight.data.add_(1) + collector.update_policy_weights_(policy) + else: + if (data["action"] == 2).all(): + seen_updated = True + assert total == total_frames + assert first_batch is not None + assert (first_batch["action"] == 1).all(), first_batch["action"] + assert ( + seen_updated + ), "Updated weights were never observed in collected batches." + finally: + collector.shutdown() + + def _is_ray_runtime_proc(info): + args = info.get("args") or "" + comm = info.get("comm") or "" + return ( + " ray::" in args.lower() + or "/site-packages/ray/" in args + or comm in {"raylet", "gcs_server"} + ) + + assert_no_new_python_processes( + baseline=baseline, + baseline_time=baseline_time, + timeout=30.0, + # Ray's core daemons and prestarted workers can legitimately outlive a + # collector. We only want to catch leaked *non-Ray* Python processes + # spawned by the collector itself. + ignore_info_fn=_is_ray_runtime_proc, + ) + @pytest.mark.parametrize( "collector_class", [ @@ -658,12 +818,15 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync, pfactor if i == 0: first_batch = data if policy is not None: - policy.weight.data += 1 + policy.weight.data.add_(1) else: - weights.data += 1 + assert weights is not None + weights.data.add_(1) collector.update_policy_weights_(weights) elif total == total_frames - frames_per_batch: last_batch = data + assert first_batch is not None + assert last_batch is not None assert (first_batch["action"] == 1).all(), first_batch["action"] assert (last_batch["action"] == 2).all(), last_batch["action"] assert total == total_frames @@ -730,8 +893,10 @@ def policy_constructor(): p = policy_constructor() # p(env().reset()) weights = TensorDict.from_module(p) - weights["module", "1", "module", "weight"].data.fill_(0) - weights["module", "1", "module", "bias"].data.fill_(2) + # `TensorDict.__getitem__` returns tensors; use in-place ops directly. 
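+        # (Without no_grad, an in-place fill_ on parameters that require grad raises a RuntimeError.)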
+ with torch.no_grad(): + weights["module", "1", "module", "weight"].fill_(0) + weights["module", "1", "module", "bias"].fill_(2) collector.update_policy_weights_(weights) try: for data in collector: diff --git a/test/test_env.py b/test/test_env.py index f79efcba1df..9572c591613 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -13,6 +13,8 @@ import pickle import random import re +import signal +import threading import time from collections import defaultdict from functools import partial @@ -98,6 +100,59 @@ pytest.mark.filterwarnings("ignore:unclosed file"), ] + +@pytest.fixture(autouse=False) # Turn to True to enable +def check_no_lingering_multiprocessing_resources(request): + """Fixture that checks for leftover multiprocessing resources after each test. + + This helps detect test pollution where one test leaves behind resource_sharer + threads, zombie processes, or other multiprocessing state that can cause + deadlocks in subsequent tests (especially with fork start method on Linux). + + See: https://bugs.python.org/issue30289 + """ + # Record state before test + threads_before = {t.name for t in threading.enumerate()} + # Count resource_sharer threads specifically + resource_sharer_before = sum( + 1 + for t in threading.enumerate() + if "_serve" in t.name or "resource_sharer" in t.name.lower() + ) + + yield + + # Give a brief moment for cleanup + gc.collect() + time.sleep(0.05) + + # Check for new resource_sharer threads + resource_sharer_after = sum( + 1 + for t in threading.enumerate() + if "_serve" in t.name or "resource_sharer" in t.name.lower() + ) + + # Only warn (not fail) for now - this is informational to help debug + if resource_sharer_after > resource_sharer_before: + new_threads = {t.name for t in threading.enumerate()} - threads_before + resource_sharer_threads = [ + t.name + for t in threading.enumerate() + if "_serve" in t.name or "resource_sharer" in t.name.lower() + ] + import warnings + + warnings.warn( + f"Test {request.node.name} left behind {resource_sharer_after - resource_sharer_before} " + f"resource_sharer thread(s): {resource_sharer_threads}. " + f"New threads: {new_threads}. " + "This can cause deadlocks in subsequent tests with fork start method.", + UserWarning, + stacklevel=1, + ) + + gym_version = None if _has_gym: try: @@ -3797,6 +3852,7 @@ def test_serial(self, bwad, use_buffers): r = env.rollout(N, break_when_any_done=bwad) assert r.get("non_tensor").tolist() == [list(range(N))] * 2 + # @pytest.mark.forked # Run in isolated subprocess to avoid resource_sharer pollution from other tests @pytest.mark.parametrize("bwad", [True, False]) @pytest.mark.parametrize("use_buffers", [False, True]) def test_parallel(self, bwad, use_buffers, maybe_fork_ParallelEnv): @@ -3811,6 +3867,56 @@ def test_parallel(self, bwad, use_buffers, maybe_fork_ParallelEnv): time.sleep(0.1) gc.collect() + @pytest.mark.skipif( + platform == "win32", reason="signal-based timeout not supported." + ) + def test_parallel_large_non_tensor_does_not_deadlock(self, maybe_fork_ParallelEnv): + """Regression test: large non-tensor payloads must not deadlock ParallelEnv in buffer mode. + + In shared-buffer mode, non-tensor leaves are sent over the Pipe. If the worker + blocks on `send()` (pipe buffer full) before setting its completion event, + the parent can hang forever waiting for that event. We guard against this by + using a signal alarm and a very large non-tensor payload. 
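+        The 15-second alarm used below is generous; without the fix this step would hang indefinitely.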
+ """ + + class _LargeNonTensorEnv(EnvWithMetadata): + def __init__(self, payload_size: int = 5_000_000): + super().__init__() + self._payload = b"x" * payload_size + + def _reset(self, tensordict): + data = self._saved_obs_spec.zero() + data.set_non_tensor("non_tensor", self._payload) + data.update(self.full_done_spec.zero()) + return data + + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + data = self._saved_obs_spec.zero() + data.set_non_tensor("non_tensor", self._payload) + data.update(self.full_done_spec.zero()) + data.update(self._saved_full_reward_spec.zero()) + return data + + def _alarm_handler(signum, frame): + raise TimeoutError( + "ParallelEnv deadlocked while waiting for workers with large non-tensor payloads." + ) + + old_handler = signal.signal(signal.SIGALRM, _alarm_handler) + signal.alarm(15) + env = maybe_fork_ParallelEnv(2, _LargeNonTensorEnv, use_buffers=True) + try: + td = env.reset() + td = td.set("action", torch.zeros(2, 1)) + _ = env.step(td) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + env.close(raise_if_closed=False) + del env + time.sleep(0.1) + gc.collect() + class AddString(Transform): def __init__(self): super().__init__() diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 5f2211ea798..323116c6995 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -300,6 +300,13 @@ def _run_collector( elif instruction == b"shutdown": if verbose: torchrl_logger.debug(f"RANK {rank} -- shutting down") + # Shutdown weight sync schemes first (stops background threads) + if weight_sync_schemes is not None: + for scheme in weight_sync_schemes.values(): + try: + scheme.shutdown() + except Exception: + pass try: collector.shutdown() except Exception: @@ -569,6 +576,7 @@ def __init__( backend: str = "gloo", update_after_each_batch: bool = False, max_weight_update_interval: int = -1, + update_interval: int | None = None, launcher: str = "submitit", tcp_port: int | None = None, weight_updater: WeightUpdaterBase @@ -617,6 +625,12 @@ def __init__( self._sync = sync self.update_after_each_batch = update_after_each_batch self.max_weight_update_interval = max_weight_update_interval + if update_interval is not None and update_interval < 1: + raise ValueError( + "`update_interval` must be >= 1 when provided. " + f"Got update_interval={update_interval}." + ) + self.update_interval = update_interval if self.update_after_each_batch and self.max_weight_update_interval > -1: raise RuntimeError( "Got conflicting update instructions: `update_after_each_batch` " @@ -992,6 +1006,7 @@ def _iterator_dist(self): torchrl_logger.debug("RANK 0 -- iterating...") total_frames = 0 + num_batches_yielded = 0 if not self._sync: for rank in range(1, self.num_workers + 1): torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") @@ -1009,13 +1024,48 @@ def _iterator_dist(self): if self._sync: data, total_frames = self._next_sync(total_frames) else: - data, total_frames = self._next_async(total_frames, trackers) + data, total_frames, ready_worker_idx = self._next_async( + total_frames, trackers + ) if self.split_trajs: data = split_trajectories(data) if self.postproc is not None: data = self.postproc(data) yield data + num_batches_yielded += 1 + has_more = total_frames < self.total_frames + + # Automatic weight update hook: update_interval controls how often we + # propagate weights through the registered weight sync schemes. 
+ # + # Important: for async collection, we do this *after* yielding the batch + # (so the user can mutate policy weights) but *before* letting the worker + # continue, to ensure the next batch reflects the new weights. + if ( + has_more + and self.update_interval is not None + and self._weight_sync_schemes is not None + and num_batches_yielded % self.update_interval == 0 + ): + if self._sync: + # Sync case: all workers will proceed next, update everyone. + for scheme in self._weight_sync_schemes.values(): + scheme.send() + else: + # Async case: only release the worker that just produced data. + for scheme in self._weight_sync_schemes.values(): + scheme.send(worker_ids=ready_worker_idx) + + if (not self._sync) and has_more: + # Release the worker that produced the last batch and restart its + # receive tracker *after* any weight update has been propagated. + rank = ready_worker_idx + 1 + torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") + self._store.set(f"NODE_{rank}_in", b"continue") + trackers[ready_worker_idx] = self._tensordict_out[ + ready_worker_idx + ].irecv(src=rank, return_premature=True) if self.max_weight_update_interval > -1: for j in range(self.num_workers): @@ -1070,6 +1120,7 @@ def _next_sync(self, total_frames): def _next_async(self, total_frames, trackers): data = None + ready_worker_idx = None while data is None: for i in range(self.num_workers): rank = i + 1 @@ -1086,16 +1137,15 @@ def _next_async(self, total_frames, trackers): ) self.update_policy_weights_(worker_ids=rank) total_frames += data.numel() - if total_frames < self.total_frames: - torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") - self._store.set(f"NODE_{rank}_in", b"continue") - trackers[i] = self._tensordict_out[i].irecv( - src=i + 1, return_premature=True - ) + ready_worker_idx = i for j in range(self.num_workers): self._batches_since_weight_update[j] += j != i break - return data, total_frames + if ready_worker_idx is None: + raise RuntimeError( + "Failed to find a ready worker in async collection loop." 
+ ) + return data, total_frames, ready_worker_idx def set_seed(self, seed: int, static_seed: bool = False) -> int: for i in range(self.num_workers): @@ -1117,6 +1167,11 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: raise NotImplementedError def shutdown(self, timeout: float | None = None) -> None: + # Prevent double shutdown + if getattr(self, "_shutdown", False): + return + self._shutdown = True + self._store.set("TRAINER_status", b"shutdown") for i in range(self.num_workers): rank = i + 1 @@ -1138,6 +1193,25 @@ def shutdown(self, timeout: float | None = None) -> None: self.jobs[i].result() elif self.launcher == "submitit_delayed": pass + + # Clean up weight sync schemes AFTER workers have exited + # (workers have their own scheme instances that they clean up) + if self._weight_sync_schemes is not None: + torchrl_logger.debug("shutting down weight sync schemes") + for scheme in self._weight_sync_schemes.values(): + try: + scheme.shutdown() + except Exception as e: + torchrl_logger.warning( + f"Error shutting down weight sync scheme: {e}" + ) + self._weight_sync_schemes = None + + # Destroy torch.distributed process group + if torch.distributed.is_initialized(): + torchrl_logger.debug("destroying process group") + torch.distributed.destroy_process_group() + torchrl_logger.debug("collector shut down") diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 2df0bca48e1..2f8930207f6 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -1071,6 +1071,19 @@ def shutdown( timeout=timeout if timeout is not None else 5.0 ) self.stop_remote_collectors() + + # Clean up weight sync schemes AFTER workers have exited + if getattr(self, "_weight_sync_schemes", None) is not None: + torchrl_logger.debug("shutting down weight sync schemes") + for scheme in self._weight_sync_schemes.values(): + try: + scheme.shutdown() + except Exception as e: + torchrl_logger.warning( + f"Error shutting down weight sync scheme: {e}" + ) + self._weight_sync_schemes = None + if shutdown_ray: ray.shutdown() diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index a5d0c9a7140..9cc5bbe8076 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -863,6 +863,7 @@ def shutdown(self, timeout: float | None = None) -> None: return if self._shutdown: return + torchrl_logger.debug("shutting down") for future, i in self.futures: # clear the futures @@ -876,10 +877,6 @@ def shutdown(self, timeout: float | None = None) -> None: torchrl_logger.debug("rpc shutdown") rpc.shutdown(timeout=int(IDLE_TIMEOUT)) - # Destroy torch.distributed process group - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() - if self.launcher == "mp": for job in self.jobs: job.join(int(IDLE_TIMEOUT)) @@ -890,6 +887,23 @@ def shutdown(self, timeout: float | None = None) -> None: pass else: raise NotImplementedError(f"Unknown launcher {self.launcher}") + + # Clean up weight sync schemes AFTER workers have exited + if getattr(self, "_weight_sync_schemes", None) is not None: + torchrl_logger.debug("shutting down weight sync schemes") + for scheme in self._weight_sync_schemes.values(): + try: + scheme.shutdown() + except Exception as e: + torchrl_logger.warning( + f"Error shutting down weight sync scheme: {e}" + ) + self._weight_sync_schemes = None + + # Destroy torch.distributed process group + if torch.distributed.is_initialized(): + 
torch.distributed.destroy_process_group() + self._shutdown = True diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index a63b6d33c66..33a26220451 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -686,12 +686,42 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: raise NotImplementedError def shutdown(self, timeout: float | None = None) -> None: - # Clean up weight sync schemes + # Prevent double shutdown + if getattr(self, "_shutdown", False): + return + self._shutdown = True + + # Wait for workers to exit + if hasattr(self, "jobs"): + for job in self.jobs: + if self.launcher == "mp": + if hasattr(job, "is_alive") and job.is_alive(): + job.join(timeout=timeout if timeout is not None else 10) + elif self.launcher == "submitit": + try: + job.result() + except Exception: + pass + + # Clean up weight sync schemes AFTER workers have exited if self._weight_sync_schemes is not None: + torchrl_logger.debug("shutting down weight sync schemes") for scheme in self._weight_sync_schemes.values(): - scheme.shutdown() + try: + scheme.shutdown() + except Exception as e: + torchrl_logger.warning( + f"Error shutting down weight sync scheme: {e}" + ) self._weight_sync_schemes = None + # Destroy torch.distributed process group + if torch.distributed.is_initialized(): + torchrl_logger.debug("destroying process group") + torch.distributed.destroy_process_group() + + torchrl_logger.debug("collector shut down") + class DistributedSyncDataCollector( DistributedSyncCollector, metaclass=_LegacyCollectorMeta diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 0ba2c019303..d2033b2de14 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -15,6 +15,7 @@ from copy import deepcopy from functools import wraps from multiprocessing import connection +from multiprocessing.connection import wait as connection_wait from multiprocessing.synchronize import Lock as MpLock from typing import Any from warnings import warn @@ -1566,7 +1567,11 @@ def _step_and_maybe_reset_no_buffers( if self.consolidate: try: td = tensordict.consolidate( - share_memory=True, inplace=True, num_threads=1 + # share_memory=False: avoid resource_sharer which causes + # progressive slowdown with fork on Linux + share_memory=False, + inplace=True, + num_threads=1, ) except Exception as err: raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err @@ -1802,34 +1807,78 @@ def select_and_transfer(x, y): return tensordict, tensordict_ def _wait_for_workers(self, workers_range): - workers_range_consume = set(workers_range) + """Wait for all workers to signal completion via their events. + + Uses multiprocessing.connection.wait() for efficient OS-level + waiting on multiple pipes simultaneously. + """ + timeout = self.BATCHED_PIPE_TIMEOUT t0 = time.time() - while ( - len(workers_range_consume) - and (time.time() - t0) < self.BATCHED_PIPE_TIMEOUT - ): - for i in workers_range: - if i not in workers_range_consume: - continue - worker = self._workers[i] - if worker.is_alive(): - event: mp.Event = self._events[i] - if event.is_set(): - workers_range_consume.discard(i) - event.clear() - else: - continue - else: - try: - self._shutdown_workers() - finally: - raise RuntimeError(f"Cannot proceed, worker {i} dead.") - # event.wait(self.BATCHED_PIPE_TIMEOUT) - if len(workers_range_consume): - raise RuntimeError( - f"Failed to run all workers within the {self.BATCHED_PIPE_TIMEOUT} sec time limit. 
This " - f"threshold can be increased via the BATCHED_PIPE_TIMEOUT env variable." - ) + + # In shared-memory/buffer mode, workers signal completion by setting + # their `mp_event` (they may not send anything back on the pipe). + if self._use_buffers: + pending = set(workers_range) + n_iter = 0 + while pending: + n_iter += 1 + remaining = timeout - (time.time() - t0) + if remaining <= 0: + raise RuntimeError( + f"Failed to run all workers within the {timeout} sec time limit. This " + f"threshold can be increased via the BATCHED_PIPE_TIMEOUT env variable." + ) + + # Wait in short slices so we can both harvest multiple events and + # periodically check for dead workers without blocking forever. + slice_timeout = min(0.1, remaining) + progressed = False + for wi in tuple(pending): + if self._events[wi].wait(timeout=slice_timeout): + self._events[wi].clear() + pending.remove(wi) + progressed = True + + if not progressed and (n_iter % 50) == 0: + for wi in pending: + if not self._workers[wi].is_alive(): + try: + self._shutdown_workers() + finally: + raise RuntimeError(f"Cannot proceed, worker {wi} dead.") + return + + # No-buffer mode: workers send back data on the pipe, so we can efficiently + # block on readability. + pipes_pending = {self.parent_channels[i]: i for i in workers_range} + i = 0 + while pipes_pending: + i += 1 + should_check_for_dead_workers = (i % 20) == 0 + remaining = timeout - (time.time() - t0) + if remaining <= 0: + raise RuntimeError( + f"Failed to run all workers within the {timeout} sec time limit. This " + f"threshold can be increased via the BATCHED_PIPE_TIMEOUT env variable." + ) + + # Wait for any pipes to become readable (OS-level select/poll) + ready = connection_wait(list(pipes_pending.keys()), timeout=remaining) + + if not ready and should_check_for_dead_workers: + # Timeout with no pipes ready - check for dead workers + for wi in pipes_pending.values(): + if not self._workers[wi].is_alive(): + try: + self._shutdown_workers() + finally: + raise RuntimeError(f"Cannot proceed, worker {wi} dead.") + continue + + # Clear events for ready workers (best-effort) + for pipe in ready: + wi = pipes_pending.pop(pipe) + self._events[wi].clear() def _step_no_buffers( self, tensordict: TensorDictBase @@ -1848,7 +1897,11 @@ def _step_no_buffers( if self.consolidate: try: data = tensordict.consolidate( - share_memory=True, inplace=False, num_threads=1 + # share_memory=False: avoid resource_sharer which causes + # progressive slowdown with fork on Linux + share_memory=False, + inplace=False, + num_threads=1, ) except Exception as err: raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err @@ -1867,8 +1920,6 @@ def _step_no_buffers( else: local_data = local_data.to(env_device) self.parent_channels[i].send(("step", local_data)) - # for i in range(data.shape[0]): - # self.parent_channels[i].send(("step", (data, i))) self._wait_for_workers(workers_range) @@ -2011,6 +2062,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: if self.event is not None: self.event.record() self.event.synchronize() + for i in workers_range: self.parent_channels[i].send(("step", data[i])) @@ -2076,11 +2128,13 @@ def _reset_no_buffers( needs_resetting, ) -> tuple[TensorDictBase, TensorDictBase]: if is_tensor_collection(tensordict): - # tensordict = tensordict.consolidate(share_memory=True, num_threads=1) if self.consolidate: try: tensordict = tensordict.consolidate( - share_memory=True, num_threads=1 + # share_memory=False: avoid resource_sharer which causes + # progressive slowdown with fork 
on Linux + share_memory=False, + num_threads=1, ) except Exception as err: raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err @@ -2458,12 +2512,15 @@ def look_for_cuda(tensor, has_cuda=has_cuda): event.synchronize() if _non_tensor_keys: + # Set event BEFORE sending to avoid deadlocks when the pipe buffer + # is full (the parent will start reading as soon as it observes + # the event). + mp_event.set() child_pipe.send( ("non_tensor", cur_td.select(*_non_tensor_keys, strict=False)) ) - - # Set event only after non-tensor data is sent to avoid race condition - mp_event.set() + else: + mp_event.set() del cur_td @@ -2494,14 +2551,16 @@ def look_for_cuda(tensor, has_cuda=has_cuda): # Make sure the root is updated root_shared_tensordict.update_(env._step_mdp(input)) - # Set event before sending non-tensor data so parent knows worker is done - # The recv() call itself will provide synchronization for the pipe - mp_event.set() - if _non_tensor_keys: + # Set event BEFORE sending to avoid deadlocks when the pipe buffer + # is full (the parent will start reading as soon as it observes + # the event). + mp_event.set() child_pipe.send( ("non_tensor", next_td.select(*_non_tensor_keys, strict=False)) ) + else: + mp_event.set() del next_td @@ -2536,14 +2595,16 @@ def look_for_cuda(tensor, has_cuda=has_cuda): event.record() event.synchronize() - # Set event before sending non-tensor data so parent knows worker is done - # The recv() call itself will provide synchronization for the pipe - mp_event.set() - if _non_tensor_keys: ntd = root_next_td.select(*_non_tensor_keys) ntd.set("next", td_next.select(*_non_tensor_keys)) + # Set event BEFORE sending to avoid deadlocks when the pipe buffer + # is full (the parent will start reading as soon as it observes + # the event). 
+ mp_event.set() child_pipe.send(("non_tensor", ntd)) + else: + mp_event.set() del td, root_next_td @@ -2686,8 +2747,6 @@ def _run_worker_pipe_direct( raise RuntimeError("call 'init' before resetting") # we use 'data' to pass the keys that we need to pass to reset, # because passing the entire buffer may have unwanted consequences - # data, idx, reset_kwargs = data - # data = data[idx] data, reset_kwargs = data if data is not None: data.unlock_() @@ -2703,18 +2762,19 @@ def _run_worker_pipe_direct( event.synchronize() if consolidate: try: - child_pipe.send( - cur_td.consolidate( - share_memory=True, inplace=True, num_threads=1 - ) + cur_td = cur_td.consolidate( + # share_memory=False: avoid resource_sharer which causes + # progressive slowdown with fork on Linux + share_memory=False, + inplace=True, + num_threads=1, ) except Exception as err: raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err - else: - child_pipe.send(cur_td) - # Set event after successfully sending through pipe to avoid race condition - # where event is set but pipe send fails (BrokenPipeError) + # Set event BEFORE send so parent starts reading, which unblocks send + # if pipe buffer was full (prevents deadlock) mp_event.set() + child_pipe.send(cur_td) del cur_td @@ -2722,8 +2782,6 @@ def _run_worker_pipe_direct( if not initialized: raise RuntimeError("called 'init' before step") i += 1 - # data, idx = data - # data = data[idx] next_td = env._step(data) if event is not None: event.record() @@ -2731,14 +2789,18 @@ def _run_worker_pipe_direct( if consolidate: try: next_td = next_td.consolidate( - share_memory=True, inplace=True, num_threads=1 + # share_memory=False: avoid resource_sharer which causes + # progressive slowdown with fork on Linux + share_memory=False, + inplace=True, + num_threads=1, ) except Exception as err: raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err - child_pipe.send(next_td) - # Set event after successfully sending through pipe to avoid race condition - # where event is set but pipe send fails (BrokenPipeError) + # Set event BEFORE send so parent starts reading, which unblocks send + # if pipe buffer was full (prevents deadlock) mp_event.set() + child_pipe.send(next_td) del next_td diff --git a/torchrl/testing/dist_utils.py b/torchrl/testing/dist_utils.py new file mode 100644 index 00000000000..dd3e69abb46 --- /dev/null +++ b/torchrl/testing/dist_utils.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import os +import time +from collections.abc import Callable +from typing import Any + +import psutil + +__all__ = [ + "assert_no_new_python_processes", + "is_python_process", + "snapshot_python_processes", +] + + +def is_python_process(comm: str | None, args: str | None) -> bool: + """Check if a process is a python process.""" + if comm is None: + comm = "" + comm = comm.lower() + if comm.startswith(("python", "pypy")): + return True + if not args: + return False + return "python" in args.lower() + + +def snapshot_python_processes( + root: psutil.Process | None = None, +) -> dict[tuple[int, float], dict[str, Any]]: + """Snapshot python processes belonging to the given process tree. + + Returns a dict keyed by (pid, start_time) -> info. + """ + if root is None: + root = psutil.Process(os.getpid()) + + uid = os.getuid() + + # Snapshot descendant PIDs first, then query process info via process_iter. + # This avoids race conditions where a child exits between `children()` and + # attribute access on a stale Process handle (common with Ray helpers). 
+ descendant_pids = {root.pid} + descendant_pids.update(p.pid for p in root.children(recursive=True)) + + out: dict[tuple[int, float], dict[str, Any]] = {} + for proc in psutil.process_iter( + attrs=["pid", "name", "cmdline", "create_time", "uids"], ad_value=None + ): + info = proc.info + pid = info.get("pid") + if pid is None or pid not in descendant_pids: + continue + uids = info.get("uids") + if uids is None or uids.real != uid: + continue + + name = info.get("name") or "" + cmdline = info.get("cmdline") or [] + args = " ".join(cmdline) if isinstance(cmdline, (list, tuple)) else str(cmdline) + if not is_python_process(name, args): + continue + + start_time = float(info.get("create_time") or 0.0) + key = (int(pid), start_time) + out[key] = { + "pid": int(pid), + "start_time": start_time, + "comm": name, + "args": args, + } + return out + + +def assert_no_new_python_processes( + *, + baseline: dict[tuple[int, float], dict[str, Any]], + baseline_time: float, + timeout: float = 20.0, + ignore_info_fn: Callable[[dict[str, Any]], bool] | None = None, +) -> None: + """Assert that no python process started after baseline_time remains alive. + + The check is limited to the current process tree (pytest process + descendants). + """ + if ignore_info_fn is None: + + def ignore_info_fn(_info: dict[str, Any]) -> bool: + return False + + deadline = time.time() + timeout + last_new: dict[tuple[int, float], dict[str, Any]] | None = None + while time.time() < deadline: + current = snapshot_python_processes() + new: dict[tuple[int, float], dict[str, Any]] = {} + for (pid, start_time), info in current.items(): + if pid == os.getpid(): + continue + if ignore_info_fn(info): + continue + # Guard against pid reuse: only consider processes started after the baseline. + if start_time and start_time < baseline_time - 1.0: + continue + if (pid, start_time) in baseline: + continue + new[(pid, start_time)] = info + if not new: + return + last_new = new + time.sleep(0.25) + + if last_new is None: + return + details = "\n".join( + f"- pid={v['pid']} comm={v.get('comm')} args={v.get('args')}" + for v in last_new.values() + ) + raise AssertionError( + "Leaked python processes detected after collector.shutdown().\n" + f"Processes still alive:\n{details}" + ) diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py index c9cad578c53..fd2625002c8 100644 --- a/torchrl/weight_update/_distributed.py +++ b/torchrl/weight_update/_distributed.py @@ -589,16 +589,13 @@ def _setup_connection_and_weights_on_receiver_impl( def shutdown(self) -> None: """Stop background receiver thread and clean up.""" - if self._stop_event is not None: - self._stop_event.set() - if self._background_thread is not None: - self._background_thread.join(timeout=5.0) - if self._background_thread.is_alive(): - torchrl_logger.warning( - "DistributedWeightSyncScheme: Background thread did not stop gracefully" - ) - self._background_thread = None - self._stop_event = None + # Check if already shutdown + if getattr(self, "_is_shutdown", False): + return + self._is_shutdown = True + + # Let base class handle background thread cleanup + super().shutdown() @property def model(self) -> Any | None: diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index cba335a967b..30f42544369 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -20,6 +20,12 @@ ) +def _close_mp_queue(queue: mp.Queue) -> None: + """Close a multiprocessing Queue and wait for its feeder thread to 
exit.""" + queue.close() + queue.join_thread() + + class SharedMemTransport: """Shared memory transport for in-place weight updates. @@ -886,22 +892,31 @@ def __getstate__(self): def shutdown(self) -> None: """Stop the background receiver thread and clean up.""" + # Check if already shutdown + if getattr(self, "_is_shutdown", False): + return + self._is_shutdown = True + # Signal all workers to stop - if self._instruction_queues: - for worker_idx in self._instruction_queues: - try: - self._instruction_queues[worker_idx].put("stop") - except Exception: - pass - - # Stop local background thread if running - if self._stop_event is not None: - self._stop_event.set() - if self._background_thread is not None: - self._background_thread.join(timeout=5.0) - if self._background_thread.is_alive(): - torchrl_logger.warning( - "SharedMemWeightSyncScheme: Background thread did not stop gracefully" - ) - self._background_thread = None - self._stop_event = None + instruction_queues = getattr(self, "_instruction_queues", None) + if instruction_queues: + for _, queue in instruction_queues.items(): + queue.put("stop") + + # Let base class handle background thread cleanup + super().shutdown() + + # Close all multiprocessing queues created by the scheme. + queues_to_close = [] + for name in ("_weight_init_queues", "_instruction_queues", "_ack_queues"): + mapping = getattr(self, name, None) + if not mapping: + continue + queues_to_close.extend(mapping.values()) + setattr(self, name, {}) + + unique = {} + for q in queues_to_close: + unique[id(q)] = q + for q in unique.values(): + _close_mp_queue(q) diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 75ab16563b4..b381a4db55b 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -1231,3 +1231,30 @@ def __getstate__(self): def __setstate__(self, state): """Restore the scheme from pickling.""" self.__dict__.update(state) + + def __del__(self): + """Clean up resources when the scheme is garbage collected.""" + try: + self.shutdown() + except Exception: + # Silently ignore any errors during garbage collection cleanup + pass + + def shutdown(self) -> None: + """Shutdown the scheme and release resources. + + This method stops any background threads and cleans up connections. + It is safe to call multiple times. Subclasses should override this + method to add custom cleanup logic, but should call super().shutdown() + to ensure base cleanup is performed. + """ + # Stop background receiver thread if running + if getattr(self, "_stop_event", None) is not None: + self._stop_event.set() + if getattr(self, "_background_thread", None) is not None: + try: + self._background_thread.join(timeout=5.0) + except Exception: + pass + self._background_thread = None + self._stop_event = None