diff --git a/.github/unittest/linux/scripts/run_setup_test.sh b/.github/unittest/linux/scripts/run_setup_test.sh new file mode 100644 index 00000000000..e95ed547a5a --- /dev/null +++ b/.github/unittest/linux/scripts/run_setup_test.sh @@ -0,0 +1,82 @@ +#!/usr/bin/env bash + +set -euxo pipefail + +if [[ $OSTYPE != 'darwin'* ]]; then + export DEBIAN_FRONTEND=noninteractive + export TZ="${TZ:-Etc/UTC}" + ln -snf "/usr/share/zoneinfo/${TZ}" /etc/localtime || true + echo "${TZ}" > /etc/timezone || true + + apt-get update + apt-get install -y --no-install-recommends tzdata + dpkg-reconfigure -f noninteractive tzdata || true + + apt-get upgrade -y + apt-get install -y git wget cmake curl python3-dev g++ gcc +fi + +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/venv-setup-test" + +cd "${root_dir}" + +# Install uv (used for --no-deps install path parity with CI) +curl -LsSf https://astral.sh/uv/install.sh | sh +export PATH="$HOME/.local/bin:$PATH" + +rm -rf "${env_dir}" +uv venv --python "${PYTHON_VERSION}" "${env_dir}" +source "${env_dir}/bin/activate" + +uv_pip_install() { + uv pip install --no-progress --python "${env_dir}/bin/python" "$@" +} + +python -c "import sys; print(sys.version)" + +# Ensure `python -m pip` exists (uv-created venvs may not include pip). +python -m ensurepip --upgrade + +# Minimal runtime/build deps + pytest only. +uv_pip_install \ + pytest \ + setuptools \ + wheel \ + packaging \ + cloudpickle \ + pyvers \ + numpy \ + ninja \ + "pybind11[global]>=2.13" + +ref_name="${GITHUB_REF_NAME:-}" +if [[ -z "${ref_name}" && -n "${GITHUB_REF:-}" ]]; then + ref_name="${GITHUB_REF#refs/heads/}" +fi + +if [[ "${ref_name}" == release/* ]]; then + export RELEASE=1 + export TORCH_VERSION=stable +else + export RELEASE=0 + export TORCH_VERSION=nightly +fi + +if [[ "$TORCH_VERSION" == "nightly" ]]; then + uv_pip_install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu +else + uv_pip_install torch torchvision --index-url https://download.pytorch.org/whl/cpu +fi + +# tensordict is a hard dependency of torchrl; install it explicitly since we test +# `pip/uv install --no-deps` for torchrl itself. +if [[ "$RELEASE" == 0 ]]; then + uv_pip_install --no-build-isolation --no-deps git+https://github.com/pytorch/tensordict.git +else + uv_pip_install tensordict +fi + +pytest -q test/test_setup.py -vv diff --git a/.github/unittest/linux_sota/scripts/run_all.sh b/.github/unittest/linux_sota/scripts/run_all.sh index 226fe49f6a8..4c8147ecbd1 100755 --- a/.github/unittest/linux_sota/scripts/run_all.sh +++ b/.github/unittest/linux_sota/scripts/run_all.sh @@ -70,6 +70,8 @@ uv pip install \ hypothesis \ future \ cloudpickle \ + pyvers \ + packaging \ pygame \ "moviepy<2.0.0" \ tqdm \ diff --git a/.github/workflows/test-linux-sota.yml b/.github/workflows/test-linux-sota.yml index 7032631c9a6..82c0d4f95a3 100644 --- a/.github/workflows/test-linux-sota.yml +++ b/.github/workflows/test-linux-sota.yml @@ -43,7 +43,12 @@ jobs: export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }} export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}" - if [[ "${{ github.ref }}" =~ release/* ]]; then + ref_name="${GITHUB_REF_NAME:-}" + if [[ -z "${ref_name}" && -n "${GITHUB_REF:-}" ]]; then + ref_name="${GITHUB_REF#refs/heads/}" + fi + + if [[ "${ref_name}" == release/* ]]; then export RELEASE=1 export TORCH_VERSION=stable else diff --git a/.github/workflows/test-linux.yml b/.github/workflows/test-linux.yml index 14aaadb580f..6dbf9087e2e 100644 --- a/.github/workflows/test-linux.yml +++ b/.github/workflows/test-linux.yml @@ -23,6 +23,25 @@ permissions: contents: read jobs: + test-setup-minimal: + strategy: + matrix: + python_version: ["3.9", "3.14"] + fail-fast: false + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + runner: linux.4xlarge + repository: pytorch/rl + docker-image: "nvidia/cuda:13.0.2-cudnn-devel-ubuntu24.04" + timeout: 90 + script: | + set -euo pipefail + export PYTHON_VERSION=${{ matrix.python_version }} + export CU_VERSION="cpu" + echo "PYTHON_VERSION: $PYTHON_VERSION" + echo "CU_VERSION: $CU_VERSION" + bash .github/unittest/linux/scripts/run_setup_test.sh + tests-cpu: strategy: matrix: diff --git a/MANIFEST.in b/MANIFEST.in index 34d5a6cf2b5..16555347d6a 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,2 @@ -include torchrl/version.py +include torchrl/version.py +include version.txt diff --git a/pyproject.toml b/pyproject.toml index 7d1de1b5745..32e18e00b50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,14 @@ [build-system] -requires = ["setuptools", "wheel", "torch", "ninja", "numpy", "pybind11[global]", "cmake"] +requires = [ + "setuptools", + "wheel", + "setuptools_scm", + "torch", + "ninja", + "numpy", + "pybind11[global]", + "cmake", +] build-backend = "setuptools.build_meta" [project] @@ -130,8 +139,13 @@ linkedin = "https://www.linkedin.com/company/torchrl" discord = "https://discord.gg/cZs26Qq3Dd" benchmark = "https://docs.pytorch.org/rl/dev/bench/" -[tool.setuptools.dynamic] -version = {file = "version.txt"} +[tool.setuptools_scm] +# Use SETUPTOOLS_SCM_PRETEND_VERSION=M.Major.Minor to set the version for stable releases. +version_scheme = "post-release" +# Local scheme is handled by setup.py (appends +g unless on release/v branch) +local_scheme = "no-local-version" +version_file = "torchrl/_version.py" +fallback_version = "0.10.0" [tool.setuptools.packages.find] exclude = [ diff --git a/setup.py b/setup.py index 21bd6f9f3d9..b63a367edba 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,21 @@ +import contextlib import glob +import importlib.util import logging import os +import re +import subprocess import sys +from pathlib import Path from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CppExtension logger = logging.getLogger(__name__) +ROOT_DIR = Path(__file__).parent.resolve() +_RELEASE_BRANCH_RE = re.compile(r"^release/v(?P.+)$") + def get_extensions(): """Build C++ extensions with platform-specific compiler flags. @@ -91,20 +99,86 @@ def get_extensions(): return ext_modules +def _git_output(args) -> str | None: + try: + return ( + subprocess.check_output(["git", *args], cwd=str(ROOT_DIR)) + .decode("utf-8") + .strip() + ) + except Exception: + return None + + +def _branch_name() -> str | None: + for key in ( + "GITHUB_REF_NAME", + "GIT_BRANCH", + "BRANCH_NAME", + "CI_COMMIT_REF_NAME", + ): + val = os.environ.get(key) + if val: + return val + branch = _git_output(["rev-parse", "--abbrev-ref", "HEAD"]) + if not branch or branch == "HEAD": + return None + return branch + + +def _short_sha() -> str | None: + return _git_output(["rev-parse", "--short", "HEAD"]) + + +def _version_with_local_sha(base_version: str) -> str: + # Do not append local version on the matching release branch. + branch = _branch_name() + if branch: + m = _RELEASE_BRANCH_RE.match(branch) + if m and m.group("release_id").strip() == base_version.strip(): + return base_version + sha = _short_sha() + if not sha: + return base_version + return f"{base_version}+g{sha}" + + +@contextlib.contextmanager +def set_version(): + # Prefer explicit build version if provided by build tooling. + if "SETUPTOOLS_SCM_PRETEND_VERSION" not in os.environ: + override = os.environ.get("TORCHRL_BUILD_VERSION") + if override: + os.environ["SETUPTOOLS_SCM_PRETEND_VERSION"] = override.strip() + else: + base_version = (ROOT_DIR / "version.txt").read_text().strip() + full_version = _version_with_local_sha(base_version) + os.environ["SETUPTOOLS_SCM_PRETEND_VERSION"] = full_version + yield + del os.environ["SETUPTOOLS_SCM_PRETEND_VERSION"] + return + yield + + def main(): """Main setup function for building TorchRL with C++ extensions.""" - setup_kwargs = { - "ext_modules": get_extensions(), - "cmdclass": {"build_ext": BuildExtension.with_options()}, - "packages": ["torchrl"], - "package_data": { - "torchrl": ["version.py"], - }, - "include_package_data": True, - "zip_safe": False, - } - - setup(**setup_kwargs) + with set_version(): + pretend_version = os.environ.get("SETUPTOOLS_SCM_PRETEND_VERSION") + _has_setuptools_scm = importlib.util.find_spec("setuptools_scm") is not None + + setup_kwargs = { + "ext_modules": get_extensions(), + "cmdclass": {"build_ext": BuildExtension.with_options()}, + "zip_safe": False, + **( + {"setup_requires": ["setuptools_scm"], "use_scm_version": True} + if _has_setuptools_scm + # pretend_version already includes +g (computed in set_version) + else {"version": pretend_version} + ), + } + + setup(**setup_kwargs) if __name__ == "__main__": diff --git a/test/test_setup.py b/test/test_setup.py new file mode 100644 index 00000000000..e395b6682c6 --- /dev/null +++ b/test/test_setup.py @@ -0,0 +1,151 @@ +import json +import os +import shutil +import subprocess +import sys +from pathlib import Path + +import pytest + +_ROOT = Path(__file__).resolve().parents[1] + + +def _run(cmd, *, cwd, env=None, timeout=60 * 60) -> str: + proc = subprocess.run( + cmd, + cwd=str(cwd), + env=env, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=timeout, + check=False, + ) + if proc.returncode != 0: + raise RuntimeError( + "Command failed.\n" + f"cwd: {cwd}\n" + f"cmd: {cmd}\n" + f"exit_code: {proc.returncode}\n" + f"output:\n{proc.stdout}" + ) + return proc.stdout + + +def _pip_uninstall(pkg: str) -> None: + subprocess.run( + [sys.executable, "-m", "pip", "uninstall", "-y", pkg], + cwd=str(_ROOT), + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + check=False, + ) + + +def _install_cmd_prefix(installer: str) -> list[str]: + if installer == "uv": + return ["uv", "pip", "install", "--python", sys.executable] + return [sys.executable, "-m", "pip", "install"] + + +def _git(args) -> str: + return _run(["git", *args], cwd=_ROOT, timeout=60).strip() + + +def _expected_dist_version(base_version: str) -> str: + # Match the same release-branch rule as setup.py. + branch = None + for key in ("GITHUB_REF_NAME", "GIT_BRANCH", "BRANCH_NAME", "CI_COMMIT_REF_NAME"): + val = os.environ.get(key) + if val: + branch = val + break + if branch is None: + b = _git(["rev-parse", "--abbrev-ref", "HEAD"]) + branch = None if b == "HEAD" else b + if branch is not None and branch.startswith("refs/heads/"): + branch = branch[len("refs/heads/") :] + + if branch is not None and ( + branch == f"release/v{base_version}" + or branch.endswith(f"/release/v{base_version}") + ): + return base_version + + return f"{base_version}+g{_git(['rev-parse', '--short', 'HEAD'])}" + + +@pytest.mark.parametrize("editable", [True, False], ids=["editable", "wheel"]) +def test_install_no_deps_has_nonzero_version(editable: bool, tmp_path: Path): + # Requires git checkout. + if not (_ROOT / ".git").exists(): + pytest.skip("not a git checkout") + + base_version = (_ROOT / "version.txt").read_text().strip() + if not base_version: + raise RuntimeError("Empty version.txt") + + expected = _expected_dist_version(base_version) + + installer = "uv" if shutil.which("uv") is not None else "pip" + + # Ensure we cover the historical failure mode where version becomes 0.0.0 when + # build requirements aren't present (e.g. --no-build-isolation). This dedicated + # CI job intentionally runs with --no-deps installs. + _pip_uninstall("setuptools_scm") + + # Ensure clean re-install for each case. + _pip_uninstall("torchrl") + + cmd = _install_cmd_prefix(installer) + cmd.append("--no-deps") + cmd.append("--no-build-isolation") + if editable: + cmd.extend(["-e", "."]) + else: + cmd.append(".") + + _run(cmd, cwd=_ROOT, timeout=60 * 60) + + probe_dir = tmp_path / "probe" + probe_dir.mkdir(parents=True, exist_ok=True) + + code = r""" +import importlib.metadata as md +import json + +out = {} +out["dist_version"] = md.version("torchrl") +try: + import torchrl + out["pkg_version"] = getattr(torchrl, "__version__", None) + out["pkg_file"] = getattr(torchrl, "__file__", None) +except Exception as err: + out["import_error"] = repr(err) + +print(json.dumps(out)) +""" + out = _run([sys.executable, "-c", code], cwd=probe_dir, timeout=5 * 60) + info = json.loads(out.strip()) + + dist_version = str(info["dist_version"]).strip() + assert dist_version != "0.0.0" + assert dist_version == expected + + pkg_version = info.get("pkg_version") + pkg_file = info.get("pkg_file") + if pkg_version is not None and pkg_file is not None: + pkg_version = str(pkg_version).strip() + assert pkg_version != "0.0.0" + assert pkg_version == expected + + pkg_path = Path(pkg_file).resolve() + if editable: + assert str(pkg_path).startswith(str(_ROOT.resolve())) + else: + assert "site-packages" in str(pkg_path) + else: + # If some hard dependency is missing, import can fail. The packaging version + # should still be correct. + assert "dist_version" in info diff --git a/torchrl/__init__.py b/torchrl/__init__.py index 2cc1222b49e..b7fc2aa74f6 100644 --- a/torchrl/__init__.py +++ b/torchrl/__init__.py @@ -23,14 +23,22 @@ from ._extension import _init_extension # noqa: E402 - +__version__ = None # type: ignore try: - from .version import __version__ -except ImportError: + try: + from importlib.metadata import version as _dist_version + except ImportError: # pragma: no cover + from importlib_metadata import version as _dist_version # type: ignore + + __version__ = _dist_version("torchrl") +except Exception: try: from ._version import __version__ - except ImportError: - __version__ = "0.0.0+unknown" + except Exception: + try: + from .version import __version__ + except Exception: + __version__ = None # type: ignore try: from torch.compiler import is_dynamo_compiling diff --git a/torchrl/_extension.py b/torchrl/_extension.py index f95ae9b8a88..e1f359ae271 100644 --- a/torchrl/_extension.py +++ b/torchrl/_extension.py @@ -9,10 +9,20 @@ from packaging.version import parse +__version__ = None # type: ignore try: - from .version import __version__, pytorch_version + try: + from importlib.metadata import version as _dist_version + except ImportError: # pragma: no cover + from importlib_metadata import version as _dist_version # type: ignore + + __version__ = _dist_version("torchrl") +except Exception: + __version__ = None # type: ignore + +try: + from .version import pytorch_version except ImportError: - __version__ = None pytorch_version = "unknown" diff --git a/torchrl/modules/llm/__init__.py b/torchrl/modules/llm/__init__.py index 0ee9f37375d..bd5d37067eb 100644 --- a/torchrl/modules/llm/__init__.py +++ b/torchrl/modules/llm/__init__.py @@ -2,27 +2,25 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""LLM utilities for TorchRL. + +Note: + This package contains optional integrations (e.g. vLLM) that may rely on native + extensions. To keep `import torchrl` / `import torchrl.envs` lightweight and + robust, we **avoid importing optional backends at module import time** and + instead only import those backends on demand. +""" + from __future__ import annotations -from .backends import ( - AsyncVLLM, - make_async_vllm_engine, - make_vllm_worker, - stateless_init_process_group, - stateless_init_process_group_async, -) +from typing import Any -from .policies import ( - ChatHistory, - LLMWrapperBase, - LogProbs, - Masks, +from .policies.common import ChatHistory, LLMWrapperBase, LogProbs, Masks, Text, Tokens +from .policies.transformers_wrapper import ( RemoteTransformersWrapper, - Text, - Tokens, TransformersWrapper, - vLLMWrapper, ) +from .policies.vllm_wrapper import vLLMWrapper __all__ = [ # Data structures @@ -46,3 +44,19 @@ "make_vllm_worker", "stateless_init_process_group", ] + + +def __getattr__(name: str) -> Any: # noqa: ANN401 + # Keep backends optional and on-demand to avoid importing vLLM native extensions + # as a side-effect of importing torchrl. + if name in { + "AsyncVLLM", + "make_async_vllm_engine", + "make_vllm_worker", + "stateless_init_process_group", + "stateless_init_process_group_async", + }: + from . import backends # local import is intentional / required + + return getattr(backends, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/torchrl/modules/llm/backends/__init__.py b/torchrl/modules/llm/backends/__init__.py index 2c8226824fa..206f12c2f61 100644 --- a/torchrl/modules/llm/backends/__init__.py +++ b/torchrl/modules/llm/backends/__init__.py @@ -2,25 +2,15 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""LLM backends. + +These backends can be optional and may rely on native extensions. We avoid +importing them at module import time and lazily load on attribute access. +""" + from __future__ import annotations -# Import everything from the vllm subfolder for backwards compatibility -from .vllm import ( - # Asynchronous vLLM - _AsyncLLMEngine, - _AsyncvLLMWorker, - AsyncVLLM, - # Synchronous vLLM - LocalLLMWrapper, - make_async_vllm_engine, - make_vllm_worker, - RayLLMWorker, - # Base classes and interfaces - RLvLLMEngine, - # Utilities - stateless_init_process_group, - stateless_init_process_group_async, -) +from typing import Any __all__ = [ # Base classes @@ -38,3 +28,38 @@ "stateless_init_process_group", "stateless_init_process_group_async", ] + +_LAZY_ATTRS: dict[str, tuple[str, str]] = { + # Base classes and interfaces + "RLvLLMEngine": ("torchrl.modules.llm.backends.vllm", "RLvLLMEngine"), + # Sync vLLM + "make_vllm_worker": ("torchrl.modules.llm.backends.vllm", "make_vllm_worker"), + "RayLLMWorker": ("torchrl.modules.llm.backends.vllm", "RayLLMWorker"), + "LocalLLMWrapper": ("torchrl.modules.llm.backends.vllm", "LocalLLMWrapper"), + # Async vLLM + "_AsyncvLLMWorker": ("torchrl.modules.llm.backends.vllm", "_AsyncvLLMWorker"), + "_AsyncLLMEngine": ("torchrl.modules.llm.backends.vllm", "_AsyncLLMEngine"), + "AsyncVLLM": ("torchrl.modules.llm.backends.vllm", "AsyncVLLM"), + "make_async_vllm_engine": ( + "torchrl.modules.llm.backends.vllm", + "make_async_vllm_engine", + ), + # Utilities + "stateless_init_process_group": ( + "torchrl.modules.llm.backends.vllm", + "stateless_init_process_group", + ), + "stateless_init_process_group_async": ( + "torchrl.modules.llm.backends.vllm", + "stateless_init_process_group_async", + ), +} + + +def __getattr__(name: str) -> Any: # noqa: ANN401 + target = _LAZY_ATTRS.get(name) + if target is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + module_name, attr_name = target + module = __import__(module_name, fromlist=[attr_name]) + return getattr(module, attr_name) diff --git a/torchrl/modules/llm/backends/vllm/__init__.py b/torchrl/modules/llm/backends/vllm/__init__.py index 85508483896..7b86eb0a84f 100644 --- a/torchrl/modules/llm/backends/vllm/__init__.py +++ b/torchrl/modules/llm/backends/vllm/__init__.py @@ -27,22 +27,7 @@ from __future__ import annotations -# Base classes and interfaces -from .base import RLvLLMEngine - -# Asynchronous vLLM -from .vllm_async import ( - _AsyncLLMEngine, - _AsyncvLLMWorker, - AsyncVLLM, - make_async_vllm_engine, -) - -# Synchronous vLLM -from .vllm_sync import LocalLLMWrapper, make_vllm_worker, RayLLMWorker - -# Shared utilities -from .vllm_utils import stateless_init_process_group, stateless_init_process_group_async +from typing import Any __all__ = [ # Base classes and interfaces @@ -60,3 +45,50 @@ "stateless_init_process_group", "stateless_init_process_group_async", ] + +_LAZY_ATTRS: dict[str, tuple[str, str]] = { + # Base + "RLvLLMEngine": ("torchrl.modules.llm.backends.vllm.base", "RLvLLMEngine"), + # Sync + "make_vllm_worker": ( + "torchrl.modules.llm.backends.vllm.vllm_sync", + "make_vllm_worker", + ), + "RayLLMWorker": ("torchrl.modules.llm.backends.vllm.vllm_sync", "RayLLMWorker"), + "LocalLLMWrapper": ( + "torchrl.modules.llm.backends.vllm.vllm_sync", + "LocalLLMWrapper", + ), + # Async + "_AsyncLLMEngine": ( + "torchrl.modules.llm.backends.vllm.vllm_async", + "_AsyncLLMEngine", + ), + "_AsyncvLLMWorker": ( + "torchrl.modules.llm.backends.vllm.vllm_async", + "_AsyncvLLMWorker", + ), + "AsyncVLLM": ("torchrl.modules.llm.backends.vllm.vllm_async", "AsyncVLLM"), + "make_async_vllm_engine": ( + "torchrl.modules.llm.backends.vllm.vllm_async", + "make_async_vllm_engine", + ), + # Utils + "stateless_init_process_group": ( + "torchrl.modules.llm.backends.vllm.vllm_utils", + "stateless_init_process_group", + ), + "stateless_init_process_group_async": ( + "torchrl.modules.llm.backends.vllm.vllm_utils", + "stateless_init_process_group_async", + ), +} + + +def __getattr__(name: str) -> Any: # noqa: ANN401 + target = _LAZY_ATTRS.get(name) + if target is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + module_name, attr_name = target + module = __import__(module_name, fromlist=[attr_name]) + return getattr(module, attr_name) diff --git a/torchrl/modules/llm/policies/__init__.py b/torchrl/modules/llm/policies/__init__.py index a29c370578d..9d25739a25a 100644 --- a/torchrl/modules/llm/policies/__init__.py +++ b/torchrl/modules/llm/policies/__init__.py @@ -2,12 +2,17 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""LLM policy wrappers. + +This subpackage includes optional wrappers that may rely on native extensions +(e.g. vLLM). To avoid importing optional dependencies at module import time, +we avoid importing those dependencies at module import time. +""" from __future__ import annotations from .common import ChatHistory, LLMWrapperBase, LogProbs, Masks, Text, Tokens from .transformers_wrapper import RemoteTransformersWrapper, TransformersWrapper - from .vllm_wrapper import vLLMWrapper __all__ = [ diff --git a/torchrl/modules/llm/policies/vllm_wrapper.py b/torchrl/modules/llm/policies/vllm_wrapper.py index 577029f0325..c77b5257ae4 100644 --- a/torchrl/modules/llm/policies/vllm_wrapper.py +++ b/torchrl/modules/llm/policies/vllm_wrapper.py @@ -5,9 +5,11 @@ from __future__ import annotations import collections + +import importlib.util import threading import warnings -from typing import Any, Literal +from typing import Any, Literal, TYPE_CHECKING import torch from tensordict import ( @@ -25,7 +27,6 @@ from torch.nn.utils.rnn import pad_sequence from torchrl.envs.utils import _classproperty -from torchrl.modules.llm.backends.vllm import AsyncVLLM from torchrl.modules.llm.policies.common import ( _batching, _extract_responses_from_full_histories, @@ -38,18 +39,40 @@ ) from torchrl.modules.utils.utils import _unpad_tensors -# Type imports -try: - import transformers - import vllm - from vllm.outputs import RequestOutput - from vllm.sampling_params import SamplingParams -except ImportError: - vllm = None - transformers = None + +_HAS_VLLM = importlib.util.find_spec("vllm") is not None +_HAS_TRANSFORMERS = importlib.util.find_spec("transformers") is not None + +if TYPE_CHECKING: + from vllm.outputs import RequestOutput # type: ignore[import-not-found] + from vllm.sampling_params import SamplingParams # type: ignore[import-not-found] +else: SamplingParams = Any # type: ignore RequestOutput = Any # type: ignore + +def _require_transformers() -> None: + if not _HAS_TRANSFORMERS: + raise ImportError( + "transformers is required for vLLMWrapper. Please install it with `pip install transformers`." + ) + + +def _require_vllm(): + """Import vLLM lazily. + + We intentionally avoid importing vLLM at module import time because importing vLLM can + load native extensions that may hard-crash the interpreter on some platforms. + """ + if not _HAS_VLLM: + raise ImportError( + "vllm is required for vLLMWrapper. Please install it with `pip install vllm`." + ) + import vllm as _vllm # local import is intentional / required + + return _vllm + + # Import async vLLM engines @@ -321,19 +344,26 @@ def __init__( else: self._batching_lock = None - if vllm is None: - raise ImportError("vllm is required for vLLMWrapper") - if transformers is None: - raise ImportError("transformers is required for vLLMWrapper") + _require_transformers() # Detect and initialize model if isinstance(model, str): + # Import lazily to avoid importing vLLM backends unless actually needed. + from torchrl.modules.llm.backends.vllm import ( # local import is intentional / required + AsyncVLLM, + ) + model = AsyncVLLM.from_pretrained(model) # Validate model type - if isinstance(model, AsyncVLLM): + model_type = type(model) + model_module = getattr(model_type, "__module__", "") + model_name = getattr(model_type, "__name__", "") + if model_name == "AsyncVLLM" and model_module.startswith( + "torchrl.modules.llm.backends.vllm" + ): self._model_type = "async_vllm" - elif vllm is not None and isinstance(model, vllm.LLM): + elif model_name == "LLM" and model_module.startswith("vllm"): self._model_type = "sync_vllm" elif hasattr(model, "generate") and hasattr(model, "remote"): # Ray actor with generate method @@ -347,8 +377,10 @@ def __init__( from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(tokenizer) - - from vllm import SamplingParams + # Import vLLM lazily: only needed if we are going to interact with vLLM types. + # (This keeps importing this module safe even if vLLM hard-crashes on import.) + if self._model_type in ("sync_vllm",): + _require_vllm() # Validate input_mode if input_mode not in ["history", "text", "tokens"]: @@ -1918,12 +1950,11 @@ def _to_list( @_classproperty def CompletionOutput_tc(cls): - if vllm is None: - raise ImportError("vllm is required for CompletionOutput_tc") + _vllm = _require_vllm() if hasattr(cls, "_CompletionOutput_tc"): return cls._CompletionOutput_tc - CompletionOutput_tc = from_dataclass(vllm.outputs.CompletionOutput) # type: ignore + CompletionOutput_tc = from_dataclass(_vllm.outputs.CompletionOutput) # type: ignore cls._CompletionOutput_tc = CompletionOutput_tc return CompletionOutput_tc