Skip to content

Commit

Permalink
Add nvjpeg support
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jul 16, 2024
1 parent 0e895e6 commit 4a74e88
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def find_libjpeg():
print(f"{USE_PNG = }")
USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
print(f"{USE_JPEG = }")
USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
print(f"{USE_NVJPEG = }")

TORCHVISION_INCLUDE = os.environ.get("TORCHVISION_INCLUDE", "")
TORCHVISION_LIBRARY = os.environ.get("TORCHVISION_LIBRARY", "")
Expand Down Expand Up @@ -320,6 +322,8 @@ def make_image_extension():
sources.remove(image_dir / "image.cpp")
else:
sources += list(image_dir.glob("cuda/*.cpp"))

Extension = CppExtension

if USE_PNG:
png_found, png_include_dir, png_library_dir, png_library = find_libpng()
Expand Down Expand Up @@ -348,8 +352,20 @@ def make_image_extension():
define_macros += [("JPEG_FOUND", 1)]
else:
warnings.warn("Building torchvision without JPEG support")

if USE_NVJPEG and torch.cuda.is_available():
nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()

if nvjpeg_found:
print("Building torchvision with NVJPEG image support")
libraries.append("nvjpeg")
image_macros += [("NVJPEG_FOUND", 1)]
Extension = CUDAExtension
else:
warnings.warn("Building torchvision without NVJPEG support")


return CppExtension(
return Extension(
name="torchvision.image",
sources=sorted(str(s) for s in sources),
include_dirs=include_dirs,
Expand Down

0 comments on commit 4a74e88

Please sign in to comment.