Skip to content

Commit

Permalink
2024-08-20 nightly release (9e78fe2)
Browse files Browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Aug 20, 2024
1 parent d3585f4 commit 311f28f
Show file tree
Hide file tree
Showing 18 changed files with 246 additions and 78 deletions.
15 changes: 15 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ option(WITH_JPEG "Enable features requiring LibJPEG." ON)
# untested. Since building from cmake is very low pri anyway, this is OK. If
# you're a user and you need this, please open an issue (and a PR!).
option(WITH_WEBP "Enable features requiring LibWEBP." OFF)
# Same here
option(WITH_AVIF "Enable features requiring LibAVIF." OFF)

if(WITH_CUDA)
enable_language(CUDA)
Expand Down Expand Up @@ -41,6 +43,11 @@ if (WITH_WEBP)
find_package(WEBP REQUIRED)
endif()

# Optional AVIF support, enabled through the WITH_AVIF option.
# Defines the AVIF_FOUND preprocessor symbol for sources compiled from this
# directory, then locates the AVIF package. Because of REQUIRED, configuration
# stops with an error if the package cannot be found.
if (WITH_AVIF)
add_definitions(-DAVIF_FOUND)
find_package(AVIF REQUIRED)
endif()

function(CUDA_CONVERT_FLAGS EXISTING_TARGET)
get_property(old_flags TARGET ${EXISTING_TARGET} PROPERTY INTERFACE_COMPILE_OPTIONS)
if(NOT "${old_flags}" STREQUAL "")
Expand Down Expand Up @@ -117,6 +124,10 @@ if (WITH_WEBP)
target_link_libraries(${PROJECT_NAME} PRIVATE ${WEBP_LIBRARIES})
endif()

# Link the AVIF libraries into the main project target when WITH_AVIF is set.
# PRIVATE keeps them out of the target's link interface.
# NOTE(review): assumes find_package(AVIF) populates AVIF_LIBRARIES - confirm.
if (WITH_AVIF)
target_link_libraries(${PROJECT_NAME} PRIVATE ${AVIF_LIBRARIES})
endif()

set_target_properties(${PROJECT_NAME} PROPERTIES
EXPORT_NAME TorchVision
INSTALL_RPATH ${TORCH_INSTALL_PREFIX}/lib)
Expand All @@ -135,6 +146,10 @@ if (WITH_WEBP)
include_directories(${WEBP_INCLUDE_DIRS})
endif()

# Add the AVIF header locations to this directory's include search path when
# WITH_AVIF is set.
# NOTE(review): assumes find_package(AVIF) populates AVIF_INCLUDE_DIRS - confirm.
if (WITH_AVIF)
include_directories(${AVIF_INCLUDE_DIRS})
endif()

set(TORCHVISION_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchVision" CACHE STRING "install path for TorchVisionConfig.cmake")

configure_package_config_file(cmake/TorchVisionConfig.cmake.in
Expand Down
137 changes: 73 additions & 64 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,17 @@
USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1"
USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1" # TODO enable by default!
USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
NVCC_FLAGS = os.getenv("NVCC_FLAGS", None)
USE_FFMPEG = os.getenv("TORCHVISION_USE_FFMPEG", "1") == "1"
USE_VIDEO_CODEC = os.getenv("TORCHVISION_USE_VIDEO_CODEC", "1") == "1"
# Note: the GPU video decoding stuff used to be called "video codec", which
# isn't an accurate or descriptive name considering there are at least 2 other
# video decoding backends in torchvision. I'm renaming this to "gpu video
# decoder" where possible, keeping user facing names (like the env var below) to
# the old scheme for BC.
USE_GPU_VIDEO_DECODER = os.getenv("TORCHVISION_USE_VIDEO_CODEC", "1") == "1"
# Same here: "use ffmpeg" was used to denote "use cpu video decoder".
USE_CPU_VIDEO_DECODER = os.getenv("TORCHVISION_USE_FFMPEG", "1") == "1"

TORCHVISION_INCLUDE = os.environ.get("TORCHVISION_INCLUDE", "")
TORCHVISION_LIBRARY = os.environ.get("TORCHVISION_LIBRARY", "")
Expand All @@ -43,10 +50,11 @@
print(f"{USE_PNG = }")
print(f"{USE_JPEG = }")
print(f"{USE_WEBP = }")
print(f"{USE_AVIF = }")
print(f"{USE_NVJPEG = }")
print(f"{NVCC_FLAGS = }")
print(f"{USE_FFMPEG = }")
print(f"{USE_VIDEO_CODEC = }")
print(f"{USE_CPU_VIDEO_DECODER = }")
print(f"{USE_GPU_VIDEO_DECODER = }")
print(f"{TORCHVISION_INCLUDE = }")
print(f"{TORCHVISION_LIBRARY = }")
print(f"{IS_ROCM = }")
Expand Down Expand Up @@ -326,6 +334,21 @@ def make_image_extension():
else:
warnings.warn("Building torchvision without WEBP support")

if USE_AVIF:
avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h")
if avif_found:
print("Building torchvision with AVIF support")
print(f"{avif_include_dir = }")
print(f"{avif_library_dir = }")
if avif_include_dir is not None and avif_library_dir is not None:
# if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add.
include_dirs.append(avif_include_dir)
library_dirs.append(avif_library_dir)
libraries.append("avif")
define_macros += [("AVIF_FOUND", 1)]
else:
warnings.warn("Building torchvision without AVIF support")

if USE_NVJPEG and torch.cuda.is_available():
nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()

Expand All @@ -351,28 +374,21 @@ def make_image_extension():
def make_video_decoders_extensions():
print("Building video decoder extensions")

# Locating ffmpeg
ffmpeg_exe = shutil.which("ffmpeg")
has_ffmpeg = ffmpeg_exe is not None
ffmpeg_version = None
# FIXME: Building torchvision with ffmpeg on MacOS or with Python 3.9
# FIXME: causes crash. See the following GitHub issues for more details.
# FIXME: https://github.com/pytorch/pytorch/issues/65000
# FIXME: https://github.com/pytorch/vision/issues/3367
build_without_extensions_msg = "Building without video decoders extensions."
if sys.platform != "linux" or (sys.version_info.major == 3 and sys.version_info.minor == 9):
has_ffmpeg = False
if has_ffmpeg:
try:
# This is to check if ffmpeg is installed properly.
ffmpeg_version = subprocess.check_output(["ffmpeg", "-version"])
except subprocess.CalledProcessError:
print("Building torchvision without ffmpeg support")
print(" Error fetching ffmpeg version, ignoring ffmpeg.")
has_ffmpeg = False
# FIXME: Building torchvision with ffmpeg on MacOS or with Python 3.9
# FIXME: causes crash. See the following GitHub issues for more details.
# FIXME: https://github.com/pytorch/pytorch/issues/65000
# FIXME: https://github.com/pytorch/vision/issues/3367
print("Can only build video decoder extensions on linux and Python != 3.9")
return []

use_ffmpeg = USE_FFMPEG and has_ffmpeg
ffmpeg_exe = shutil.which("ffmpeg")
if ffmpeg_exe is None:
print(f"{build_without_extensions_msg} Couldn't find ffmpeg binary.")
return []

if use_ffmpeg:
def find_ffmpeg_libraries():
ffmpeg_libraries = {"libavcodec", "libavformat", "libavutil", "libswresample", "libswscale"}

ffmpeg_bin = os.path.dirname(ffmpeg_exe)
Expand All @@ -399,18 +415,23 @@ def make_video_decoders_extensions():
library_found |= len(glob.glob(full_path)) > 0

if not library_found:
print("Building torchvision without ffmpeg support")
print(f" {library} header files were not found, disabling ffmpeg support")
use_ffmpeg = False
else:
print("Building torchvision without ffmpeg support")
print(f"{build_without_extensions_msg}")
print(f"{library} header files were not found.")
return None, None

return ffmpeg_include_dir, ffmpeg_library_dir

ffmpeg_include_dir, ffmpeg_library_dir = find_ffmpeg_libraries()
if ffmpeg_include_dir is None or ffmpeg_library_dir is None:
return []

print("Found ffmpeg:")
print(f" ffmpeg include path: {ffmpeg_include_dir}")
print(f" ffmpeg library_dir: {ffmpeg_library_dir}")

extensions = []
if use_ffmpeg:
print("Building torchvision with ffmpeg support")
print(f" ffmpeg version: {ffmpeg_version}")
print(f" ffmpeg include path: {ffmpeg_include_dir}")
print(f" ffmpeg library_dir: {ffmpeg_library_dir}")
if USE_CPU_VIDEO_DECODER:
print("Building with CPU video decoder support")

# TorchVision base decoder + video reader
video_reader_src_dir = os.path.join(ROOT_DIR, "torchvision", "csrc", "io", "video_reader")
Expand All @@ -427,6 +448,7 @@ def make_video_decoders_extensions():

extensions.append(
CppExtension(
# This is an awful name. It should be "cpu_video_decoder". Keeping for BC.
"torchvision.video_reader",
combined_src,
include_dirs=[
Expand All @@ -450,25 +472,24 @@ def make_video_decoders_extensions():
)
)

# Locating video codec
# CUDA_HOME should be set to the cuda root directory.
# TORCHVISION_INCLUDE and TORCHVISION_LIBRARY should include the location to
# video codec header files and libraries respectively.
video_codec_found = (
BUILD_CUDA_SOURCES
and CUDA_HOME is not None
and any([os.path.exists(os.path.join(folder, "cuviddec.h")) for folder in TORCHVISION_INCLUDE])
and any([os.path.exists(os.path.join(folder, "nvcuvid.h")) for folder in TORCHVISION_INCLUDE])
and any([os.path.exists(os.path.join(folder, "libnvcuvid.so")) for folder in TORCHVISION_LIBRARY])
)
if USE_GPU_VIDEO_DECODER:
# Locating GPU video decoder headers and libraries
# CUDA_HOME should be set to the cuda root directory.
# TORCHVISION_INCLUDE and TORCHVISION_LIBRARY should include the locations
# to the headers and libraries below
if not (
BUILD_CUDA_SOURCES
and CUDA_HOME is not None
and any([os.path.exists(os.path.join(folder, "cuviddec.h")) for folder in TORCHVISION_INCLUDE])
and any([os.path.exists(os.path.join(folder, "nvcuvid.h")) for folder in TORCHVISION_INCLUDE])
and any([os.path.exists(os.path.join(folder, "libnvcuvid.so")) for folder in TORCHVISION_LIBRARY])
and any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir])
):
print("Could not find necessary dependencies. Refer the setup.py to check which ones are needed.")
print("Building without GPU video decoder support")
return extensions
print("Building torchvision with GPU video decoder support")

use_video_codec = USE_VIDEO_CODEC and video_codec_found
if (
use_video_codec
and use_ffmpeg
and any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir])
):
print("Building torchvision with video codec support")
gpu_decoder_path = os.path.join(CSRS_DIR, "io", "decoder", "gpu")
gpu_decoder_src = glob.glob(os.path.join(gpu_decoder_path, "*.cpp"))
cuda_libs = os.path.join(CUDA_HOME, "lib64")
Expand All @@ -477,7 +498,7 @@ def make_video_decoders_extensions():
_, extra_compile_args = get_macros_and_flags()
extensions.append(
CUDAExtension(
"torchvision.Decoder",
"torchvision.gpu_decoder",
gpu_decoder_src,
include_dirs=[CSRS_DIR] + TORCHVISION_INCLUDE + [gpu_decoder_path] + [cuda_inc] + ffmpeg_include_dir,
library_dirs=ffmpeg_library_dir + TORCHVISION_LIBRARY + [cuda_libs],
Expand All @@ -498,18 +519,6 @@ def make_video_decoders_extensions():
extra_compile_args=extra_compile_args,
)
)
else:
print("Building torchvision without video codec support")
if (
use_video_codec
and use_ffmpeg
and not any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir])
):
print(
" The installed version of ffmpeg is missing the header file 'bsf.h' which is "
" required for GPU video decoding. Please install the latest ffmpeg from conda-forge channel:"
" `conda install -c conda-forge ffmpeg`."
)

return extensions

Expand Down
Binary file added test/assets/fakedata/logos/rgb_pytorch.avif
Binary file not shown.
15 changes: 14 additions & 1 deletion test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision.io.image import (
_decode_avif,
decode_gif,
decode_image,
decode_jpeg,
Expand Down Expand Up @@ -873,7 +874,7 @@ def test_decode_gif_webp_errors(decode_fun):
decode_fun(encoded_data[::2])
if decode_fun is decode_gif:
expected_match = re.escape("DGifOpenFileName() failed - 103")
else:
elif decode_fun is decode_webp:
expected_match = "WebPDecodeRGB failed."
with pytest.raises(RuntimeError, match=expected_match):
decode_fun(encoded_data)
Expand All @@ -890,5 +891,17 @@ def test_decode_webp(decode_fun, scripted):
assert img[None].is_contiguous(memory_format=torch.channels_last)


@pytest.mark.xfail(reason="AVIF support not enabled yet.")
@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_avif(decode_fun, scripted):
    """Decode the AVIF fixture and check the result's shape and memory layout.

    Parametrized over ``_decode_avif`` and ``decode_image``, each run both
    eagerly and through ``torch.jit.script``. The input is the first ``.avif``
    file yielded by ``get_images(FAKEDATA_DIR, ".avif")``.
    """
    avif_path = next(get_images(FAKEDATA_DIR, ".avif"))
    avif_bytes = read_file(avif_path)
    decoder = torch.jit.script(decode_fun) if scripted else decode_fun
    decoded = decoder(avif_bytes)
    assert decoded.shape == (3, 100, 100)
    # With a leading batch dimension added, the tensor must be contiguous in
    # channels_last memory format.
    assert decoded[None].is_contiguous(memory_format=torch.channels_last)


if __name__ == "__main__":
pytest.main([__file__])
6 changes: 3 additions & 3 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None,


@pytest.mark.skipif(
get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT, reason="video_reader backend not available"
get_video_backend() != "pyav" and not io._HAS_CPU_VIDEO_DECODER, reason="video_reader backend not available"
)
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
class TestVideo:
Expand All @@ -77,14 +77,14 @@ def test_write_read_video(self):
assert_equal(data, lv)
assert info["video_fps"] == 5

@pytest.mark.skipif(not io._HAS_VIDEO_OPT, reason="video_reader backend is not chosen")
@pytest.mark.skipif(not io._HAS_CPU_VIDEO_DECODER, reason="video_reader backend is not chosen")
def test_probe_video_from_file(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
video_info = io._probe_video_from_file(f_name)
assert pytest.approx(2, rel=0.0, abs=0.1) == video_info.video_duration
assert pytest.approx(5, rel=0.0, abs=0.1) == video_info.video_fps

@pytest.mark.skipif(not io._HAS_VIDEO_OPT, reason="video_reader backend is not chosen")
@pytest.mark.skipif(not io._HAS_CPU_VIDEO_DECODER, reason="video_reader backend is not chosen")
def test_probe_video_from_memory(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
with open(f_name, "rb") as fp:
Expand Down
4 changes: 2 additions & 2 deletions test/test_video_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numpy.random import randint
from pytest import approx
from torchvision import set_video_backend
from torchvision.io import _HAS_VIDEO_OPT
from torchvision.io import _HAS_CPU_VIDEO_DECODER


try:
Expand Down Expand Up @@ -263,7 +263,7 @@ def _get_video_tensor(video_dir, video_file):


@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
@pytest.mark.skipif(_HAS_CPU_VIDEO_DECODER is False, reason="Didn't compile with ffmpeg")
class TestVideoReader:
def check_separate_decoding_result(self, tv_result, config):
"""check the decoding results from TorchVision decoder"""
Expand Down
4 changes: 2 additions & 2 deletions test/test_videoapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torchvision
from pytest import approx
from torchvision.datasets.utils import download_url
from torchvision.io import _HAS_VIDEO_OPT, VideoReader
from torchvision.io import _HAS_CPU_VIDEO_DECODER, VideoReader


# WARNING: these tests have been skipped forever on the CI because the video ops
Expand Down Expand Up @@ -62,7 +62,7 @@ def fate(name, path="."):
}


@pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg")
@pytest.mark.skipif(_HAS_CPU_VIDEO_DECODER is False, reason="Didn't compile with ffmpeg")
class TestVideoApi:
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize("test_video", test_videos.keys())
Expand Down
2 changes: 1 addition & 1 deletion torchvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def set_video_backend(backend):
global _video_backend
if backend not in ["pyav", "video_reader", "cuda"]:
raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
if backend == "video_reader" and not io._HAS_VIDEO_OPT:
if backend == "video_reader" and not io._HAS_CPU_VIDEO_DECODER:
# TODO: better messages
message = "video_reader video backend is not available. Please compile torchvision from source and try again"
raise RuntimeError(message)
Expand Down
Loading

0 comments on commit 311f28f

Please sign in to comment.