diff --git a/.github/scripts/pre-build-script-win.sh b/.github/scripts/pre-build-script-win.sh index ae7786ca08a..85188385da3 100644 --- a/.github/scripts/pre-build-script-win.sh +++ b/.github/scripts/pre-build-script-win.sh @@ -1,5 +1,7 @@ #!/bin/bash pip install --upgrade setuptools +${CONDA_RUN} pip install "pybind11[global]" +${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U --no-deps export TORCHRL_BUILD_VERSION=0.10.0 diff --git a/.github/scripts/td_script.sh b/.github/scripts/td_script.sh index cadc7d24f09..10537031554 100644 --- a/.github/scripts/td_script.sh +++ b/.github/scripts/td_script.sh @@ -3,27 +3,17 @@ export TORCHRL_BUILD_VERSION=0.10.0 pip install --upgrade setuptools +# Always install pybind11 - required for building C++ extensions +${CONDA_RUN} pip install "pybind11[global]" + # Check if ARCH is set to aarch64 ARCH=${ARCH:-} # This sets ARCH to an empty string if it's not defined if pip list | grep -q torch; then echo "Torch is installed." - - # ${CONDA_RUN} conda install 'anaconda::cmake>=3.22' -y - - ${CONDA_RUN} pip install "pybind11[global]" - ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U --no-deps elif [[ -n "${SMOKE_TEST_SCRIPT:-}" ]]; then ${CONDA_RUN} ${PIP_INSTALL_TORCH} - # TODO: revert when nightlies of tensordict are fixed - # if [[ "$ARCH" == "aarch64" ]]; then - - -# ${CONDA_RUN} conda install 'anaconda::cmake>=3.22' -y - - ${CONDA_RUN} pip install "pybind11[global]" - ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U --no-deps else echo "Torch is not installed - tensordict will be installed later." diff --git a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml index 84cce31cfe6..7b8216d3828 100644 --- a/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml +++ b/.github/unittest/linux_olddeps/scripts_gym_0_13/environment.yml @@ -28,3 +28,4 @@ dependencies: - av - h5py - numpy<2.0.0 + - pybind11[global] diff --git a/build_nightly.sh b/build_nightly.sh index 890d4ba8cbe..1bd4b95fb4a 100755 --- a/build_nightly.sh +++ b/build_nightly.sh @@ -56,10 +56,10 @@ echo "Installing build dependencies..." PYTHON_VERSION=$(python -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')") if python -c "import sys; exit(0 if sys.version_info < (3, 11) else 1)"; then echo "Using setuptools 65.3.0 for Python $PYTHON_VERSION (compatibility mode)" - python -m pip install wheel setuptools==65.3.0 + python -m pip install wheel setuptools==65.3.0 "pybind11[global]" else echo "Using latest setuptools for Python $PYTHON_VERSION (modern mode)" - python -m pip install wheel setuptools + python -m pip install wheel setuptools "pybind11[global]" fi python setup.py bdist_wheel diff --git a/setup.py b/setup.py index f0e4408bbcd..6381cc57600 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,27 @@ logger = logging.getLogger(__name__) ROOT_DIR = Path(__file__).parent.resolve() + + +def _check_pybind11(): + """Check that pybind11 is installed and provide a clear error message if not. + + Only checks when actually building extensions, not for commands like 'clean'. + """ + # Commands that don't require building C++ extensions + skip_commands = {"clean", "egg_info", "sdist", "--version", "--help", "-h"} + if skip_commands.intersection(sys.argv): + return + if importlib.util.find_spec("pybind11") is None: + raise RuntimeError( + "pybind11 is required to build TorchRL's C++ extensions but was not found.\n" + "Please install it with:\n" + " pip install 'pybind11[global]'\n" + "Then re-run the installation." + ) + + +_check_pybind11() _RELEASE_BRANCH_RE = re.compile(r"^release/v(?P.+)$") _BUILD_INFO_FILE = ROOT_DIR / "build" / ".torchrl_build_info.json" @@ -47,13 +68,14 @@ def _check_and_clean_stale_builds(): f"Python {old_python} -> {current_python_version}. " f"Cleaning stale build artifacts..." ) - # Clean stale .so files for current Python version - so_pattern = ( + # Clean stale extension files for current Python version + ext = ".pyd" if sys.platform == "win32" else ".so" + ext_pattern = ( ROOT_DIR / "torchrl" - / f"_torchrl.cpython-{sys.version_info.major}{sys.version_info.minor}*.so" + / f"_torchrl.cpython-{sys.version_info.major}{sys.version_info.minor}*{ext}" ) - for so_file in glob.glob(str(so_pattern)): + for so_file in glob.glob(str(ext_pattern)): logger.warning(f"Removing stale: {so_file}") os.remove(so_file) # Clean build directory diff --git a/test/compile/test_collectors.py b/test/compile/test_compile_collectors.py similarity index 100% rename from test/compile/test_collectors.py rename to test/compile/test_compile_collectors.py diff --git a/test/compile/test_objectives.py b/test/compile/test_compile_objectives.py similarity index 100% rename from test/compile/test_objectives.py rename to test/compile/test_compile_objectives.py diff --git a/test/compile/test_utils.py b/test/compile/test_compile_utils.py similarity index 100% rename from test/compile/test_utils.py rename to test/compile/test_compile_utils.py diff --git a/test/compile/test_value.py b/test/compile/test_compile_value.py similarity index 100% rename from test/compile/test_value.py rename to test/compile/test_compile_value.py