Skip to content

Commit

Permalink
Add AVIF decoder (Part 1- this is not public or available yet) (#8596)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Aug 19, 2024
1 parent 0a0f34b commit 9e78fe2
Show file tree
Hide file tree
Showing 11 changed files with 174 additions and 1 deletion.
15 changes: 15 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ option(WITH_JPEG "Enable features requiring LibJPEG." ON)
# untested. Since building from cmake is very low pri anyway, this is OK. If
# you're a user and you need this, please open an issue (and a PR!).
option(WITH_WEBP "Enable features requiring LibWEBP." OFF)
# Same caveat as for WITH_WEBP above: building with AVIF support from CMake is
# untested, which is why the option is OFF by default.
option(WITH_AVIF "Enable features requiring LibAVIF." OFF)

if(WITH_CUDA)
enable_language(CUDA)
Expand Down Expand Up @@ -41,6 +43,11 @@ if (WITH_WEBP)
find_package(WEBP REQUIRED)
endif()

# Optional AVIF support: add the AVIF_FOUND compile definition (the C++
# sources guard their libavif code with `#if AVIF_FOUND`) and locate the AVIF
# package. REQUIRED makes configuration stop if the package is not found.
if(WITH_AVIF)
  add_definitions(-DAVIF_FOUND)
  find_package(AVIF REQUIRED)
endif()

function(CUDA_CONVERT_FLAGS EXISTING_TARGET)
get_property(old_flags TARGET ${EXISTING_TARGET} PROPERTY INTERFACE_COMPILE_OPTIONS)
if(NOT "${old_flags}" STREQUAL "")
Expand Down Expand Up @@ -117,6 +124,10 @@ if (WITH_WEBP)
target_link_libraries(${PROJECT_NAME} PRIVATE ${WEBP_LIBRARIES})
endif()

# Link the project library privately against the AVIF libraries when AVIF
# support is enabled.
# NOTE(review): AVIF_LIBRARIES is presumably populated by
# find_package(AVIF) -- confirm against the find module / package config in use.
if(WITH_AVIF)
  target_link_libraries(${PROJECT_NAME} PRIVATE ${AVIF_LIBRARIES})
endif()

set_target_properties(${PROJECT_NAME} PROPERTIES
EXPORT_NAME TorchVision
INSTALL_RPATH ${TORCH_INSTALL_PREFIX}/lib)
Expand All @@ -135,6 +146,10 @@ if (WITH_WEBP)
include_directories(${WEBP_INCLUDE_DIRS})
endif()

# Make the AVIF headers (e.g. "avif/avif.h") reachable when AVIF support is
# enabled. Kept directory-scoped to mirror the WEBP block above.
if(WITH_AVIF)
  include_directories(${AVIF_INCLUDE_DIRS})
endif()

set(TORCHVISION_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchVision" CACHE STRING "install path for TorchVisionConfig.cmake")

configure_package_config_file(cmake/TorchVisionConfig.cmake.in
Expand Down
17 changes: 17 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1"
USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1" # TODO enable by default!
USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
NVCC_FLAGS = os.getenv("NVCC_FLAGS", None)
# Note: the GPU video decoding stuff used to be called "video codec", which
Expand Down Expand Up @@ -49,6 +50,7 @@
print(f"{USE_PNG = }")
print(f"{USE_JPEG = }")
print(f"{USE_WEBP = }")
print(f"{USE_AVIF = }")
print(f"{USE_NVJPEG = }")
print(f"{NVCC_FLAGS = }")
print(f"{USE_CPU_VIDEO_DECODER = }")
Expand Down Expand Up @@ -332,6 +334,21 @@ def make_image_extension():
else:
warnings.warn("Building torchvision without WEBP support")

if USE_AVIF:
avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h")
if avif_found:
print("Building torchvision with AVIF support")
print(f"{avif_include_dir = }")
print(f"{avif_library_dir = }")
if avif_include_dir is not None and avif_library_dir is not None:
# if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add.
include_dirs.append(avif_include_dir)
library_dirs.append(avif_library_dir)
libraries.append("avif")
define_macros += [("AVIF_FOUND", 1)]
else:
warnings.warn("Building torchvision without AVIF support")

if USE_NVJPEG and torch.cuda.is_available():
nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()

Expand Down
Binary file added test/assets/fakedata/logos/rgb_pytorch.avif
Binary file not shown.
15 changes: 14 additions & 1 deletion test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision.io.image import (
_decode_avif,
decode_gif,
decode_image,
decode_jpeg,
Expand Down Expand Up @@ -873,7 +874,7 @@ def test_decode_gif_webp_errors(decode_fun):
decode_fun(encoded_data[::2])
if decode_fun is decode_gif:
expected_match = re.escape("DGifOpenFileName() failed - 103")
else:
elif decode_fun is decode_webp:
expected_match = "WebPDecodeRGB failed."
with pytest.raises(RuntimeError, match=expected_match):
decode_fun(encoded_data)
Expand All @@ -890,5 +891,17 @@ def test_decode_webp(decode_fun, scripted):
assert img[None].is_contiguous(memory_format=torch.channels_last)


@pytest.mark.xfail(reason="AVIF support not enabled yet.")
@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_avif(decode_fun, scripted):
    """Decode the first AVIF asset found in FAKEDATA_DIR and check the result.

    Runs both eagerly and through ``torch.jit.script``, for the dedicated AVIF
    decoder as well as the generic ``decode_image`` entry point. Marked xfail
    because AVIF support is not enabled yet.
    """
    avif_path = next(get_images(FAKEDATA_DIR, ".avif"))
    raw_bytes = read_file(avif_path)

    decoder = torch.jit.script(decode_fun) if scripted else decode_fun
    decoded = decoder(raw_bytes)

    # The asset is expected to decode to a 3-channel 100x100 image ...
    assert decoded.shape == (3, 100, 100)
    # ... whose memory layout is channels-last once a batch dim is added.
    assert decoded[None].is_contiguous(memory_format=torch.channels_last)


if __name__ == "__main__":
pytest.main([__file__])
92 changes: 92 additions & 0 deletions torchvision/csrc/io/image/cpu/decode_avif.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "decode_avif.h"

#if AVIF_FOUND
#include "avif/avif.h"
#endif // AVIF_FOUND

namespace vision {
namespace image {

#if !AVIF_FOUND
// Fallback compiled when AVIF_FOUND is not set: there is no decoder to call,
// so every invocation fails the check below with an explanatory message.
torch::Tensor decode_avif(const torch::Tensor& /* data */) {
  TORCH_CHECK(
      false, "decode_avif: torchvision not compiled with libavif support");
}
#else

// Deleter that lets std::unique_ptr own an avifDecoder.
// An equivalent definition normally comes from avif_cxx.h, but that header is
// not always present when installing libavif, so it is duplicated here.
struct UniquePtrDeleter {
  void operator()(avifDecoder* ptr) const {
    avifDecoderDestroy(ptr);
  }
};

// Owning decoder handle: avifDecoderDestroy is invoked on the held decoder
// when the handle is destroyed.
using DecoderPtr = std::unique_ptr<avifDecoder, UniquePtrDeleter>;

// Decodes an AVIF-encoded byte buffer into an RGB uint8 tensor.
//
// Args:
//   encoded_data: contiguous, 1-dimensional uint8 tensor holding the encoded
//     bytes.
// Returns:
//   A uint8 tensor of shape (3, height, width). It is a permuted view of a
//   (height, width, 3) buffer, so the channel dimension is innermost in
//   memory.
// Fails (through TORCH_CHECK) when the input tensor is not contiguous, not
// uint8 or not 1-dimensional, when any libavif call reports an error, when
// the file does not contain exactly one image, or when the bit depth is
// greater than 8.
torch::Tensor decode_avif(const torch::Tensor& encoded_data) {
  // This is based on
  // https://github.com/AOMediaCodec/libavif/blob/main/examples/avif_example_decode_memory.c
  // Refer there for more detail about what each function does, and which
  // structure/data is available after which call.

  TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
  TORCH_CHECK(
      encoded_data.dtype() == torch::kU8,
      "Input tensor must have uint8 data type, got ",
      encoded_data.dtype());
  TORCH_CHECK(
      encoded_data.dim() == 1,
      "Input tensor must be 1-dimensional, got ",
      encoded_data.dim(),
      " dims.");

  // The decoder is owned by a unique_ptr so it is destroyed on every exit
  // path, including the ones where a TORCH_CHECK below throws.
  DecoderPtr decoder(avifDecoderCreate());
  TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder.");

  auto result = AVIF_RESULT_UNKNOWN_ERROR;
  result = avifDecoderSetIOMemory(
      decoder.get(), encoded_data.data_ptr<uint8_t>(), encoded_data.numel());
  TORCH_CHECK(
      result == AVIF_RESULT_OK,
      "avifDecoderSetIOMemory failed: ",
      avifResultToString(result));

  result = avifDecoderParse(decoder.get());
  TORCH_CHECK(
      result == AVIF_RESULT_OK,
      "avifDecoderParse failed: ",
      avifResultToString(result));
  TORCH_CHECK(
      decoder->imageCount == 1, "Avif file contains more than one image");
  TORCH_CHECK(
      decoder->image->depth <= 8,
      "avif images with bitdepth > 8 are not supported");

  result = avifDecoderNextImage(decoder.get());
  TORCH_CHECK(
      result == AVIF_RESULT_OK,
      "avifDecoderNextImage failed: ",
      avifResultToString(result));

  // Output buffer in HWC order; libavif writes the RGB pixels straight into
  // it.
  auto out = torch::empty(
      {decoder->image->height, decoder->image->width, 3}, torch::kUInt8);

  avifRGBImage rgb;
  memset(&rgb, 0, sizeof(rgb));
  avifRGBImageSetDefaults(&rgb, decoder->image);
  rgb.format = AVIF_RGB_FORMAT_RGB;
  rgb.pixels = out.data_ptr<uint8_t>();
  // Rows are tightly packed: no padding between consecutive rows.
  rgb.rowBytes = rgb.width * avifRGBImagePixelSize(&rgb);

  result = avifImageYUVToRGB(decoder->image, &rgb);
  TORCH_CHECK(
      result == AVIF_RESULT_OK,
      "avifImageYUVToRGB failed: ",
      avifResultToString(result));

  return out.permute({2, 0, 1}); // return CHW, channels-last
}
#endif // AVIF_FOUND

} // namespace image
} // namespace vision
11 changes: 11 additions & 0 deletions torchvision/csrc/io/image/cpu/decode_avif.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once

#include <torch/types.h>

namespace vision {
namespace image {

// Decodes the AVIF-encoded bytes held in `data` and returns the image as a
// tensor. Raises an error when torchvision was compiled without libavif
// support.
C10_EXPORT torch::Tensor decode_avif(const torch::Tensor& data);

} // namespace image
} // namespace vision
13 changes: 13 additions & 0 deletions torchvision/csrc/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "decode_image.h"

#include "decode_avif.h"
#include "decode_gif.h"
#include "decode_jpeg.h"
#include "decode_png.h"
Expand Down Expand Up @@ -48,6 +49,18 @@ torch::Tensor decode_image(
return decode_gif(data);
}

// We assume the signature of an avif file is
// 0000 0020 6674 7970 6176 6966
// xxxx xxxx f t y p a v i f
// We only check for the "ftyp avif" part.
// This is probably not perfect, but hopefully this should cover most files.
const uint8_t avif_signature[8] = {
0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66}; // == "ftypavif"
TORCH_CHECK(data.numel() >= 12, err_msg);
if ((memcmp(avif_signature, datap + 4, 8) == 0)) {
return decode_avif(data);
}

const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF"
const uint8_t webp_signature_end[7] = {
0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8"
Expand Down
1 change: 1 addition & 0 deletions torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ static auto registry =
.op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
&decode_jpeg)
.op("image::decode_webp", &decode_webp)
.op("image::decode_avif", &decode_avif)
.op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
Expand Down
1 change: 1 addition & 0 deletions torchvision/csrc/io/image/image.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "cpu/decode_avif.h"
#include "cpu/decode_gif.h"
#include "cpu/decode_image.h"
#include "cpu/decode_jpeg.h"
Expand Down
2 changes: 2 additions & 0 deletions torchvision/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
"decode_image",
"decode_jpeg",
"decode_png",
"decode_webp",
"decode_gif",
"encode_jpeg",
"encode_png",
"read_file",
Expand Down
8 changes: 8 additions & 0 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,3 +382,11 @@ def decode_webp(
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_webp)
return torch.ops.image.decode_webp(input)


def _decode_avif(
    input: torch.Tensor,
) -> torch.Tensor:
    """Decode an AVIF image into a tensor.

    Private for now: AVIF support is not public or enabled by default yet.

    Args:
        input (Tensor): the encoded AVIF bytes; passed unchanged to the
            ``image::decode_avif`` operator.

    Returns:
        Tensor: the result of ``torch.ops.image.decode_avif``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        # Log usage under this function's own name, not decode_webp's.
        _log_api_usage_once(_decode_avif)
    return torch.ops.image.decode_avif(input)

0 comments on commit 9e78fe2

Please sign in to comment.