1 change: 1 addition & 0 deletions .github/unittest/linux/scripts/run_all.sh
@@ -118,6 +118,7 @@ uv_pip_install \
"pybind11[global]>=2.13" \
pyyaml \
scipy \
psutil \
hydra-core \
tensorboard \
"imageio==2.26.0" \
1 change: 1 addition & 0 deletions .github/unittest/linux_distributed/scripts/environment.yml
@@ -21,6 +21,7 @@ dependencies:
- pybind11[global]
- pyyaml
- scipy
- psutil
- hydra-core
- tensorboard
- imageio==2.26.0
@@ -12,6 +12,7 @@ expecttest
pybind11[global]
pyyaml
scipy
psutil
hydra-core
minari[gcs,hdf5,hf,create]
gymnasium>=1.2.0
19 changes: 10 additions & 9 deletions .github/unittest/linux_olddeps/scripts_gym_0_13/install.sh
@@ -39,24 +39,25 @@ printf "Installing PyTorch with %s\n" "${CU_VERSION}"
if [ "${CU_VERSION:-}" == cpu ] ; then
conda install pytorch==2.1 torchvision==0.16 cpuonly -c pytorch -y
else
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
python -m pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
# conda install pytorch==2.1 torchvision==0.16 pytorch-cuda=11.8 "numpy<2.0" -c pytorch -c nvidia -y
fi

# Solving circular import: https://stackoverflow.com/questions/75501048/how-to-fix-attributeerror-partially-initialized-module-charset-normalizer-has
#pip install -U charset-normalizer

# install tensordict
if [[ "$RELEASE" == 0 ]]; then
conda install anaconda::cmake -y
python -m pip install "pybind11[global]"
python -m pip install git+https://github.com/pytorch/tensordict.git
else
python -m pip install tensordict
fi
#
# NOTE:
# - The olddeps CI job runs on older Python/torch stacks.
# - Installing from tensordict `main` (git+https) is brittle as `main` may drop
# support for older Python versions at any time, which can lead to "tensordict
# not installed" failures in downstream smoke tests.
# - Use the same (pinned) range as TorchRL itself to keep this job stable.
python -m pip install "${TORCHRL_TENSORDICT_SPEC:-tensordict>=0.10.0,<0.11.0}"

# smoke test
python -c "import tensordict"
python -c "import tensordict; print(f'tensordict: {tensordict.__version__}')"

printf "* Installing torchrl\n"
python -m pip install -e . --no-build-isolation
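Note on the install line above: `${TORCHRL_TENSORDICT_SPEC:-tensordict>=0.10.0,<0.11.0}` is standard shell default expansion, so the script installs whatever spec the `TORCHRL_TENSORDICT_SPEC` variable carries and falls back to the pinned range when the variable is unset or empty. The `test-linux.yml` change further down exports exactly this variable, which is how the Python 3.9 olddeps job controls the tensordict version without editing this script.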
1 change: 1 addition & 0 deletions .github/unittest/linux_optdeps/scripts/environment.yml
@@ -17,5 +17,6 @@ dependencies:
- pybind11[global]
- pyyaml
- scipy
- psutil
- coverage
- ray
1 change: 1 addition & 0 deletions .github/unittest/linux_sota/scripts/environment.yml
@@ -20,6 +20,7 @@ dependencies:
- pybind11[global]
- pyyaml
- scipy
- psutil
- hydra-core
- imageio==2.26.0
- dm_control
1 change: 1 addition & 0 deletions .github/unittest/linux_sota/scripts/run_all.sh
@@ -84,6 +84,7 @@ uv pip install \
pybind11 \
pyyaml \
scipy \
psutil \
hydra-core \
"imageio==2.26.0" \
dm_control \
1 change: 1 addition & 0 deletions .github/unittest/windows_optdepts/scripts/environment.yml
@@ -15,4 +15,5 @@ dependencies:
- expecttest
- pyyaml
- scipy
- psutil
- coverage
3 changes: 3 additions & 0 deletions .github/workflows/test-linux.yml
@@ -169,6 +169,9 @@ jobs:
set -euo pipefail
export PYTHON_VERSION="3.9"
export CU_VERSION="cu118"
# Olddeps runs on Python 3.9: pin tensordict to a Python-3.9-compatible range.
# (Avoid installing tensordict from git main, which may drop older Python support.)
export TORCHRL_TENSORDICT_SPEC="tensordict>=0.10.0,<0.11.0"
export TAR_OPTIONS="--no-same-owner"
if [[ "${{ github.ref }}" =~ release/* ]]; then
export RELEASE=1
29 changes: 29 additions & 0 deletions Makefile
@@ -0,0 +1,29 @@
# TorchRL Development Makefile

.PHONY: clean build rebuild develop test

# Clean all build artifacts (use when switching Python/PyTorch versions)
clean:
	rm -rf build/
	rm -rf dist/
	rm -rf *.egg-info/
	rm -rf torchrl/*.egg-info/
	rm -f torchrl/_torchrl*.so
	rm -f torchrl/version.py
	find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
	find . -type f -name "*.pyc" -delete 2>/dev/null || true

# Build C++ extensions in-place
build:
	python setup.py build_ext --inplace

# Full clean + build
rebuild: clean build

# Development install (editable)
develop: rebuild
	pip install -e . --no-build-isolation

# Run tests
test:
	python -m pytest test/ -v --timeout 60
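Because `develop` depends on `rebuild`, which chains `clean` and `build`, `make develop` always rebuilds the C++ extensions from scratch before the editable install; `make build` alone recompiles in place without cleaning, and `make clean` is the escape hatch when switching Python or PyTorch versions leaves stale artifacts behind.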
2 changes: 1 addition & 1 deletion benchmarks/test_non_tensor_env_benchmark.py
@@ -38,7 +38,7 @@ def test_non_tensor_env_rollout_speed(
):
    """Benchmarks a single rollout, after a warmup rollout, for non-tensor stacking envs.

    Mirrors `test/test_env.py::TestNonTensorEnv`'s option matrix (single/serial/parallel,
    Mirrors `test/test_envs.py::TestNonTensorEnv`'s option matrix (single/serial/parallel,
    break_when_any_done, use_buffers).
    """
    with set_capture_non_tensor_stack(False):
68 changes: 68 additions & 0 deletions setup.py
@@ -1,20 +1,85 @@
from __future__ import annotations

import contextlib
import glob
import importlib.util
import json
import logging
import os
import re
import subprocess
import sys
from pathlib import Path

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

logger = logging.getLogger(__name__)

ROOT_DIR = Path(__file__).parent.resolve()
_RELEASE_BRANCH_RE = re.compile(r"^release/v(?P<release_id>.+)$")
_BUILD_INFO_FILE = ROOT_DIR / "build" / ".torchrl_build_info.json"


def _check_and_clean_stale_builds():
    """Check if existing build was made with a different PyTorch version and clean if so.

    This prevents ABI incompatibility issues when switching between PyTorch versions.
    """
    current_torch_version = torch.__version__
    current_python_version = f"{sys.version_info.major}.{sys.version_info.minor}"

    if _BUILD_INFO_FILE.exists():
        try:
            with open(_BUILD_INFO_FILE) as f:
                build_info = json.load(f)
            old_torch = build_info.get("torch_version")
            old_python = build_info.get("python_version")

            if (
                old_torch != current_torch_version
                or old_python != current_python_version
            ):
                logger.warning(
                    f"Detected PyTorch/Python version change: "
                    f"PyTorch {old_torch} -> {current_torch_version}, "
                    f"Python {old_python} -> {current_python_version}. "
                    f"Cleaning stale build artifacts..."
                )
                # Clean stale .so files for current Python version
                so_pattern = (
                    ROOT_DIR
                    / "torchrl"
                    / f"_torchrl.cpython-{sys.version_info.major}{sys.version_info.minor}*.so"
                )
                for so_file in glob.glob(str(so_pattern)):
                    logger.warning(f"Removing stale: {so_file}")
                    os.remove(so_file)
                # Clean build directory
                build_dir = ROOT_DIR / "build"
                if build_dir.exists():
                    import shutil

                    for item in build_dir.iterdir():
                        if item.name.startswith("temp.") or item.name.startswith(
                            "lib."
                        ):
                            logger.warning(f"Removing stale build dir: {item}")
                            shutil.rmtree(item)
        except (json.JSONDecodeError, OSError) as e:
            logger.warning(f"Could not read build info: {e}")

    # Write current build info
    _BUILD_INFO_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(_BUILD_INFO_FILE, "w") as f:
        json.dump(
            {
                "torch_version": current_torch_version,
                "python_version": current_python_version,
            },
            f,
        )


def get_extensions():
@@ -162,6 +227,9 @@ def set_version():

def main():
    """Main setup function for building TorchRL with C++ extensions."""
    # Check for stale builds from different PyTorch/Python versions
    _check_and_clean_stale_builds()

    with set_version():
        pretend_version = os.environ.get("SETUPTOOLS_SCM_PRETEND_VERSION")
        _has_setuptools_scm = importlib.util.find_spec("setuptools_scm") is not None
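The mechanism above boils down to a JSON marker file compared against the live interpreter and torch build. A minimal sketch of the round-trip, runnable on its own (the version strings in the comment are illustrative; the path mirrors `_BUILD_INFO_FILE`):

```python
import json
import sys
from pathlib import Path

import torch

info_file = Path("build/.torchrl_build_info.json")

# What setup.py records after a build, e.g.
# {"torch_version": "2.5.1+cu121", "python_version": "3.11"}
current = {
    "torch_version": torch.__version__,
    "python_version": f"{sys.version_info.major}.{sys.version_info.minor}",
}

if info_file.exists():
    previous = json.loads(info_file.read_text())
    # Any difference means cached extensions were compiled against another
    # ABI and must be cleaned before the next build.
    print("stale build:", previous != current)

info_file.parent.mkdir(parents=True, exist_ok=True)
info_file.write_text(json.dumps(current))
```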
5 changes: 4 additions & 1 deletion test/llm/libs/test_mlgym.py
@@ -5,6 +5,7 @@
from __future__ import annotations

import argparse
import importlib.util

from functools import partial

@@ -16,7 +17,9 @@
from torchrl.envs.llm import make_mlgym
from torchrl.modules.llm import TransformersWrapper

pytest.importorskip("mlgym")
pytestmark = pytest.mark.skipif(
    not importlib.util.find_spec("mlgym"), reason="mlgym not available"
)


class TestMLGYM:
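The behavioral difference between the two guards: `pytest.importorskip("mlgym")` imports the package during collection and skips the whole module when that import fails, whereas the `skipif` marker with `importlib.util.find_spec` only probes whether the package can be located, never executing its import-time code, and reports each test as skipped individually.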
File renamed without changes.
10 changes: 9 additions & 1 deletion test/llm/test_envs.py → test/llm/test_llm_envs.py
@@ -29,7 +29,6 @@
)

from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
from transformers import AutoTokenizer

_has_ray = importlib.util.find_spec("ray") is not None
_has_transformers = importlib.util.find_spec("transformers") is not None
@@ -43,6 +42,11 @@
    and (importlib.util.find_spec("immutabledict") is not None)
)

pytestmark = pytest.mark.skipif(
    not (_has_datasets & _has_transformers & _has_vllm & _has_ray),
    reason="requires datasets, transformers, vllm, and ray",
)


@pytest.fixture(scope="module", autouse=True)
def set_seed():
@@ -75,6 +79,8 @@ def set_list_to_stack_for_test():
class TestChatEnv:
    @pytest.fixture
    def tokenizer(self):
        from transformers import AutoTokenizer

        return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")

    @pytest.mark.parametrize("input_mode", ["text", "tokens", "history"])
@@ -789,6 +795,7 @@ def delayed_calculator(cls, operation: str, a: float, b: float) -> dict:
    @classmethod
    def make_env(cls):
        from torchrl.envs.llm.transforms.tools import SimpleToolTransform
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
        env = ChatEnv(
@@ -871,6 +878,7 @@ def test_async_mcp_tools(self):
    def test_mcp_python_execution(self):
        """Test actual MCP Python execution with mcp-run-python server."""
        from torchrl.envs.llm.transforms import MCPToolTransform
        from transformers import AutoTokenizer

        # Setup environment for MCP (Deno needs to be in PATH)
        environ = os.environ.copy()
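Taken together, these edits implement a deferred-import pattern: availability is probed once with `importlib.util.find_spec`, a module-level `pytestmark` skips everything when a dependency is missing, and the heavyweight `transformers` import moves inside the fixtures that actually need it. A self-contained sketch of the same structure (the model name is copied from the diff; the final test is illustrative):

```python
import importlib.util

import pytest

_has_transformers = importlib.util.find_spec("transformers") is not None

# Collection only needs importlib; nothing heavy is imported yet.
pytestmark = pytest.mark.skipif(
    not _has_transformers, reason="requires transformers"
)


@pytest.fixture
def tokenizer():
    # Deferred import: runs only when a test actually uses the fixture,
    # i.e. only when transformers is installed.
    from transformers import AutoTokenizer

    return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")


def test_pad_token(tokenizer):
    # Illustrative assertion only.
    assert tokenizer.eos_token is not None
```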
File renamed without changes.
File renamed without changes.
File renamed without changes.
21 changes: 13 additions & 8 deletions test/llm/test_wrapper.py
@@ -8,10 +8,10 @@
import gc
import importlib.util
import threading

import time
from concurrent.futures import ThreadPoolExecutor, wait
from functools import partial
from typing import Any, TYPE_CHECKING

import pytest
import torch
@@ -31,18 +31,21 @@
)
from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper
from torchrl.modules.llm.policies.vllm_wrapper import vLLMWrapper
from transformers import AutoTokenizer


_has_transformers = importlib.util.find_spec("transformers") is not None
_has_vllm = importlib.util.find_spec("vllm") is not None
_has_datasets = importlib.util.find_spec("datasets") is not None
_has_ray = importlib.util.find_spec("ray") is not None
# _has_datasets = importlib.util.find_spec("datasets") is not None

TransformersWrapperMaxTokens = partial(
    TransformersWrapper, generate_kwargs={"max_new_tokens": 10, "do_sample": True}
)

if TYPE_CHECKING:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from vllm import LLM


@pytest.fixture(scope="function", autouse=True)
def set_seed():
@@ -62,9 +65,7 @@ def set_seed():


@pytest.fixture(scope="module")
def vllm_instance() -> tuple[
    vllm.LLM, transformers.AutoTokenizer  # noqa # type: ignore
]:  # noqa # type: ignore
def vllm_instance() -> tuple[LLM, AutoTokenizer]:  # noqa # type: ignore
    """Create vLLM model and tokenizer for testing."""
    if not _has_vllm:
        pytest.skip("vllm not available")
@@ -83,6 +84,8 @@
        max_model_len=32768,
        gpu_memory_utilization=0.3,  # Limit to 30% GPU memory to avoid OOM with multiple engines
    )
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
@@ -92,7 +95,7 @@

@pytest.fixture(scope="module")
def async_vllm_instance() -> tuple[
    Any, transformers.AutoTokenizer  # noqa # type: ignore
    Any, AutoTokenizer  # noqa # type: ignore
]:  # noqa # type: ignore
    """Create async vLLM engine and tokenizer for testing."""
    if not _has_vllm:
@@ -114,6 +117,8 @@
        max_num_batched_tokens=32768,
        gpu_memory_utilization=0.3,  # Limit to 30% GPU memory to avoid OOM with multiple engines
    )
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
    tokenizer.pad_token = tokenizer.eos_token
    return async_engine, tokenizer
@@ -123,7 +128,7 @@

@pytest.fixture(scope="module")
def transformers_instance() -> tuple[
    transformers.AutoModelForCausalLM, transformers.AutoTokenizer  # noqa # type: ignore
    AutoModelForCausalLM, AutoTokenizer  # noqa # type: ignore
]:  # noqa # type: ignore
    """Create transformers model and tokenizer for testing."""
    if not _has_transformers:
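The `TYPE_CHECKING` block is the companion trick for annotations: assuming the module enables postponed evaluation of annotations (`from __future__ import annotations`, as the sibling test files shown here do), names imported under `TYPE_CHECKING` exist for the type checker but never at runtime, so signatures can mention `LLM` and `AutoTokenizer` without forcing either library to be installed. A minimal sketch of the pattern (the `load` helper is illustrative):

```python
from __future__ import annotations  # annotations become lazily evaluated strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Executed by static type checkers only; skipped at runtime, so neither
    # transformers nor vllm needs to be installed to import this module.
    from transformers import AutoTokenizer
    from vllm import LLM


def load(model_name: str) -> tuple[LLM, AutoTokenizer]:
    # The real imports happen here, only when the function is called.
    from transformers import AutoTokenizer
    from vllm import LLM

    model = LLM(model=model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
```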