Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions .github/unittest/linux/scripts/run_setup_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env bash

set -euxo pipefail

if [[ $OSTYPE != 'darwin'* ]]; then
export DEBIAN_FRONTEND=noninteractive
export TZ="${TZ:-Etc/UTC}"
ln -snf "/usr/share/zoneinfo/${TZ}" /etc/localtime || true
echo "${TZ}" > /etc/timezone || true

apt-get update
apt-get install -y --no-install-recommends tzdata
dpkg-reconfigure -f noninteractive tzdata || true

apt-get upgrade -y
apt-get install -y git wget cmake curl python3-dev g++ gcc
fi

# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"
env_dir="${root_dir}/venv-setup-test"

cd "${root_dir}"

# Install uv (used for --no-deps install path parity with CI)
curl -LsSf https://astral.sh/uv/install.sh | sh
export PATH="$HOME/.local/bin:$PATH"

rm -rf "${env_dir}"
uv venv --python "${PYTHON_VERSION}" "${env_dir}"
source "${env_dir}/bin/activate"

uv_pip_install() {
uv pip install --no-progress --python "${env_dir}/bin/python" "$@"
}

python -c "import sys; print(sys.version)"

# Ensure `python -m pip` exists (uv-created venvs may not include pip).
python -m ensurepip --upgrade

# Minimal runtime/build deps + pytest only.
uv_pip_install \
pytest \
setuptools \
wheel \
packaging \
cloudpickle \
pyvers \
numpy \
ninja \
"pybind11[global]>=2.13"

ref_name="${GITHUB_REF_NAME:-}"
if [[ -z "${ref_name}" && -n "${GITHUB_REF:-}" ]]; then
ref_name="${GITHUB_REF#refs/heads/}"
fi

if [[ "${ref_name}" == release/* ]]; then
export RELEASE=1
export TORCH_VERSION=stable
else
export RELEASE=0
export TORCH_VERSION=nightly
fi

if [[ "$TORCH_VERSION" == "nightly" ]]; then
uv_pip_install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
else
uv_pip_install torch torchvision --index-url https://download.pytorch.org/whl/cpu
fi

# tensordict is a hard dependency of torchrl; install it explicitly since we test
# `pip/uv install --no-deps` for torchrl itself.
if [[ "$RELEASE" == 0 ]]; then
uv_pip_install --no-build-isolation --no-deps git+https://github.com/pytorch/tensordict.git
else
uv_pip_install tensordict
fi

pytest -q test/test_setup.py -vv
2 changes: 2 additions & 0 deletions .github/unittest/linux_sota/scripts/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ uv pip install \
hypothesis \
future \
cloudpickle \
pyvers \
packaging \
pygame \
"moviepy<2.0.0" \
tqdm \
Expand Down
7 changes: 6 additions & 1 deletion .github/workflows/test-linux-sota.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ jobs:
export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }}
export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}"

if [[ "${{ github.ref }}" =~ release/* ]]; then
ref_name="${GITHUB_REF_NAME:-}"
if [[ -z "${ref_name}" && -n "${GITHUB_REF:-}" ]]; then
ref_name="${GITHUB_REF#refs/heads/}"
fi

if [[ "${ref_name}" == release/* ]]; then
export RELEASE=1
export TORCH_VERSION=stable
else
Expand Down
19 changes: 19 additions & 0 deletions .github/workflows/test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,25 @@ permissions:
contents: read

jobs:
test-setup-minimal:
strategy:
matrix:
python_version: ["3.9", "3.14"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
runner: linux.4xlarge
repository: pytorch/rl
docker-image: "nvidia/cuda:13.0.2-cudnn-devel-ubuntu24.04"
timeout: 90
script: |
set -euo pipefail
export PYTHON_VERSION=${{ matrix.python_version }}
export CU_VERSION="cpu"
echo "PYTHON_VERSION: $PYTHON_VERSION"
echo "CU_VERSION: $CU_VERSION"
bash .github/unittest/linux/scripts/run_setup_test.sh

tests-cpu:
strategy:
matrix:
Expand Down
3 changes: 2 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
include torchrl/version.py
include torchrl/version.py
include version.txt
20 changes: 17 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
[build-system]
requires = ["setuptools", "wheel", "torch", "ninja", "numpy", "pybind11[global]", "cmake"]
requires = [
"setuptools",
"wheel",
"setuptools_scm",
"torch",
"ninja",
"numpy",
"pybind11[global]",
"cmake",
]
build-backend = "setuptools.build_meta"

[project]
Expand Down Expand Up @@ -130,8 +139,13 @@ linkedin = "https://www.linkedin.com/company/torchrl"
discord = "https://discord.gg/cZs26Qq3Dd"
benchmark = "https://docs.pytorch.org/rl/dev/bench/"

[tool.setuptools.dynamic]
version = {file = "version.txt"}
[tool.setuptools_scm]
# Use SETUPTOOLS_SCM_PRETEND_VERSION=M.Major.Minor to set the version for stable releases.
version_scheme = "post-release"
# Local scheme is handled by setup.py (appends +g<sha> unless on release/v<version> branch)
local_scheme = "no-local-version"
version_file = "torchrl/_version.py"
fallback_version = "0.10.0"

[tool.setuptools.packages.find]
exclude = [
Expand Down
98 changes: 86 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
import contextlib
import glob
import importlib.util
import logging
import os
import re
import subprocess
import sys
from pathlib import Path

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

logger = logging.getLogger(__name__)

ROOT_DIR = Path(__file__).parent.resolve()
_RELEASE_BRANCH_RE = re.compile(r"^release/v(?P<release_id>.+)$")


def get_extensions():
"""Build C++ extensions with platform-specific compiler flags.
Expand Down Expand Up @@ -91,20 +99,86 @@ def get_extensions():
return ext_modules


def _git_output(args) -> str | None:
try:
return (
subprocess.check_output(["git", *args], cwd=str(ROOT_DIR))
.decode("utf-8")
.strip()
)
except Exception:
return None


def _branch_name() -> str | None:
for key in (
"GITHUB_REF_NAME",
"GIT_BRANCH",
"BRANCH_NAME",
"CI_COMMIT_REF_NAME",
):
val = os.environ.get(key)
if val:
return val
branch = _git_output(["rev-parse", "--abbrev-ref", "HEAD"])
if not branch or branch == "HEAD":
return None
return branch


def _short_sha() -> str | None:
return _git_output(["rev-parse", "--short", "HEAD"])


def _version_with_local_sha(base_version: str) -> str:
# Do not append local version on the matching release branch.
branch = _branch_name()
if branch:
m = _RELEASE_BRANCH_RE.match(branch)
if m and m.group("release_id").strip() == base_version.strip():
return base_version
sha = _short_sha()
if not sha:
return base_version
return f"{base_version}+g{sha}"


@contextlib.contextmanager
def set_version():
# Prefer explicit build version if provided by build tooling.
if "SETUPTOOLS_SCM_PRETEND_VERSION" not in os.environ:
override = os.environ.get("TORCHRL_BUILD_VERSION")
if override:
os.environ["SETUPTOOLS_SCM_PRETEND_VERSION"] = override.strip()
else:
base_version = (ROOT_DIR / "version.txt").read_text().strip()
full_version = _version_with_local_sha(base_version)
os.environ["SETUPTOOLS_SCM_PRETEND_VERSION"] = full_version
yield
del os.environ["SETUPTOOLS_SCM_PRETEND_VERSION"]
return
yield


def main():
"""Main setup function for building TorchRL with C++ extensions."""
setup_kwargs = {
"ext_modules": get_extensions(),
"cmdclass": {"build_ext": BuildExtension.with_options()},
"packages": ["torchrl"],
"package_data": {
"torchrl": ["version.py"],
},
"include_package_data": True,
"zip_safe": False,
}

setup(**setup_kwargs)
with set_version():
pretend_version = os.environ.get("SETUPTOOLS_SCM_PRETEND_VERSION")
_has_setuptools_scm = importlib.util.find_spec("setuptools_scm") is not None

setup_kwargs = {
"ext_modules": get_extensions(),
"cmdclass": {"build_ext": BuildExtension.with_options()},
"zip_safe": False,
**(
{"setup_requires": ["setuptools_scm"], "use_scm_version": True}
if _has_setuptools_scm
# pretend_version already includes +g<sha> (computed in set_version)
else {"version": pretend_version}
),
}

setup(**setup_kwargs)


if __name__ == "__main__":
Expand Down
Loading
Loading