Skip to content

Commit

Permalink
revamp setup
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jul 16, 2024
1 parent 33e47d8 commit 2f2dc7b
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/scripts/unittest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ echo '::endgroup::'
python test/smoke_test.py

# We explicitly ignore the video tests until we resolve https://github.com/pytorch/vision/issues/8162
pytest --ignore-glob="*test_video*" --junit-xml="${RUNNER_TEST_RESULTS_DIR}/test-results.xml" -v --durations=25
pytest test/test_ops.py test/test_image.py --junit-xml="${RUNNER_TEST_RESULTS_DIR}/test-results.xml" -v --durations=25
198 changes: 196 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import shutil
import subprocess
import sys
import warnings
from pathlib import Path

import torch
from pkg_resources import DistributionNotFound, get_distribution, parse_version
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDA_HOME, CUDAExtension
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDA_HOME, CUDAExtension, ROCM_HOME


def read(*names, **kwargs):
Expand Down Expand Up @@ -124,6 +126,193 @@ def find_library(name, vision_include):
return library_found, conda_installed, include_folder, lib_folder


FORCE_CUDA = os.getenv("FORCE_CUDA", "0") == "1"
print(f"{FORCE_CUDA = }")
FORCE_MPS = os.getenv("FORCE_MPS", "0") == "1"
print(f"{FORCE_MPS = }")
DEBUG = os.getenv("DEBUG", "0") == "1"
print(f"{DEBUG = }")
USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
print(f"{USE_PNG = }")
USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
print(f"{USE_JPEG = }")

ROOT_DIR = Path(__file__).absolute().parent
CSRS_DIR = ROOT_DIR / "torchvision/csrc"
IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
BUILD_CUDA_SOURCES = (torch.cuda.is_available() and ((CUDA_HOME is not None) or IS_ROCM)) or FORCE_CUDA


def get_macros_and_flags():
define_macros = []
extra_compile_args = {"cxx": []}
if BUILD_CUDA_SOURCES:
if IS_ROCM:
define_macros += [("WITH_HIP", None)]
nvcc_flags = []
else:
define_macros += [("WITH_CUDA", None)]
if nvcc_flags == "":
nvcc_flags = []
else:
nvcc_flags = nvcc_flags.split(" ")
extra_compile_args["nvcc"] = nvcc_flags

if sys.platform == "win32":
define_macros += [("torchvision_EXPORTS", None)]
extra_compile_args["cxx"].append("/MP")

if DEBUG:
extra_compile_args["cxx"].append("-g")
extra_compile_args["cxx"].append("-O0")
if "nvcc" in extra_compile_args:
# we have to remove "-OX" and "-g" flag if exists and append
nvcc_flags = extra_compile_args["nvcc"]
extra_compile_args["nvcc"] = [f for f in nvcc_flags if not ("-O" in f or "-g" in f)]
extra_compile_args["nvcc"].append("-O0")
extra_compile_args["nvcc"].append("-g")
else:
extra_compile_args["cxx"].append("-g0")

return define_macros, extra_compile_args


def make_C_extension():

common_sources = (
list(CSRS_DIR.glob("*.cpp")) + list(CSRS_DIR.glob("ops/*.cpp")) + list(CSRS_DIR.glob("ops/autocast/*.cpp"))
)
cpu_sources = (
list(CSRS_DIR.glob("ops/autograd/*.cpp"))
+ list(CSRS_DIR.glob("ops/cpu/*.cpp"))
+ list(CSRS_DIR.glob("ops/quantized/cpu/*.cpp"))
)
mps_sources = list(CSRS_DIR.glob("ops/mps/*.mm"))

if IS_ROCM:
from torch.utils.hipify import hipify_python

hipify_python.hipify(
project_directory=str(ROOT_DIR),
output_directory=str(ROOT_DIR),
includes="torchvision/csrc/ops/cuda/*",
show_detailed=True,
is_pytorch_extension=True,
)
cuda_sources = list(CSRS_DIR.glob("ops/hip/*.hip"))
for header in CSRS_DIR.glob("ops/cuda/*.h"):
shutil.copy(str(header), str(CSRS_DIR / "ops/hip"))
else:
cuda_sources = list(CSRS_DIR.glob("ops/cuda/*.cu"))

sources = common_sources + cpu_sources

if BUILD_CUDA_SOURCES:
Extension = CUDAExtension
sources += cuda_sources
else:
Extension = CppExtension
if torch.backends.mps.is_available() or FORCE_MPS:
sources += mps_sources

define_macros, extra_compile_args = get_macros_and_flags()
return Extension(
name="torchvision._C",
sources=sorted(str(s) for s in sources),
include_dirs=[CSRS_DIR],
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)


def find_libpng():
if sys.platform in ("linux", "darwin"):
libpng_config = shutil.which("libpng-config")
if libpng_config is None:
warnings.warn("libpng-config not found")
return False, None, None, None
min_version = parse_version("1.6.0")
png_version = parse_version(
subprocess.run([libpng_config, "--version"], stdout=subprocess.PIPE).stdout.strip().decode("utf-8")
)
if png_version < min_version:
warnings.warn("libpng version {png_version} is less than minimum required version {min_version}")
return False, None, None, None

include_dir = (
subprocess.run([libpng_config, "--I_opts"], stdout=subprocess.PIPE)
.stdout.strip()
.decode("utf-8")
.split("-I")[1]
)
library_dir = subprocess.run([libpng_config, "--libdir"], stdout=subprocess.PIPE).stdout.strip().decode("utf-8")
library = "png"
else: # Windows
pngfix = shutil.which("pngfix")
if pngfix is None:
warnings.warn("pngfix not found")
return False, None, None, None
pngfix_dir = Path(pngfix).absolute().parent

library_dir = str(pngfix_dir / "lib")
include_dir = str(pngfix_dir / "include", "libpng16")
library = "libpng"

return True, include_dir, library_dir, library


def make_image_extension():
include_dirs = os.environ.get("TORCHVISION_INCLUDE", "")
library_dirs = os.environ.get("TORCHVISION_LIBRARY", "")
include_dirs = include_dirs.split(os.pathsep) if include_dirs else []
library_dirs = library_dirs.split(os.pathsep) if library_dirs else []

libraries = []
define_macros, extra_compile_args = get_macros_and_flags()

image_dir = CSRS_DIR / "io/image"
sources = list(image_dir.glob("*.cpp")) + list(image_dir.glob("cpu/*.cpp")) + list(image_dir.glob("cpu/giflib/*.c"))

if IS_ROCM:
sources += list(image_dir.glob("hip/*.cpp"))
# we need to exclude this in favor of the hipified source
sources.remove(image_dir / "image.cpp")
else:
sources += list(image_dir.glob("cuda/*.cpp"))

if USE_PNG:
png_found, png_include_dir, png_library_dir, png_library = find_libpng()
if png_found:
print("Building torchvision with PNG support")
include_dirs.append(png_include_dir)
library_dirs.append(png_library_dir)
libraries.append(png_library)
define_macros += [("PNG_FOUND", 1)]
else:
warnings.warn("Building torchvision without PNG support")


if USE_JPEG:
jpeg_found, jpeg_conda, jpeg_include, jpeg_lib = find_library("jpeglib", include_dirs)

if jpeg_found:
libraries.append("jpeg")
define_macros += [("JPEG_FOUND", 1)]
if jpeg_conda:
include_dirs.append(jpeg_include)
library_dirs.append(jpeg_lib)

return CppExtension(
name="torchvision.image",
sources=sorted(str(s) for s in sources),
include_dirs=include_dirs,
library_dirs=library_dirs,
define_macros=define_macros,
libraries=libraries,
extra_compile_args=extra_compile_args,
)


def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "torchvision", "csrc")
Expand Down Expand Up @@ -537,6 +726,11 @@ def run(self):
with open("README.md") as f:
readme = f.read()

extensions = get_extensions()
print(len(extensions))
extensions[0] = make_C_extension()
extensions[1] = make_image_extension()

setup(
# Metadata
name=package_name,
Expand All @@ -557,7 +751,7 @@ def run(self):
"gdown": ["gdown>=4.7.3"],
"scipy": ["scipy"],
},
ext_modules=get_extensions(),
ext_modules=extensions,
python_requires=">=3.8",
cmdclass={
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
Expand Down

0 comments on commit 2f2dc7b

Please sign in to comment.