Skip to content

Commit

Permalink
clean up and comments
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Aug 30, 2024
1 parent 6547d2c commit af7d711
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 16 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1"
USE_HEIC = os.getenv("TORCHVISION_USE_HEIC", "1") == "1"
USE_HEIC = os.getenv("TORCHVISION_USE_HEIC", "0") == "1" # TODO enable by default!
USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1" # TODO enable by default!
USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
NVCC_FLAGS = os.getenv("NVCC_FLAGS", None)
Expand Down
17 changes: 7 additions & 10 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision.io.image import (
_decode_avif,
_decode_heic,
decode_gif,
decode_heic,
decode_image,
decode_jpeg,
decode_png,
Expand Down Expand Up @@ -890,7 +890,6 @@ def test_decode_webp(decode_fun, scripted):
img = decode_fun(encoded_bytes)
assert img.shape == (3, 100, 100)
assert img[None].is_contiguous(memory_format=torch.channels_last)
img += 123 # Make sure the underlying data wasn't accidentally freed


# This test is skipped because it requires webp images that we're not including
Expand Down Expand Up @@ -928,13 +927,12 @@ def test_decode_avif(decode_fun, scripted):
img = decode_fun(encoded_bytes)
assert img.shape == (3, 100, 100)
assert img[None].is_contiguous(memory_format=torch.channels_last)
img += 123 # Make sure the underlying data wasn't accidentally freed


# @pytest.mark.xfail(reason="AVIF support not enabled yet.")
# Note: decode_image fails because some of these files have a (valid) signature
# we don't recognize. We should probably use libmagic....
@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_heic))
@pytest.mark.parametrize("decode_fun", (_decode_avif, _decode_heic))
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
"mode, pil_mode",
Expand Down Expand Up @@ -981,9 +979,9 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename)
assert img.shape[0] == 3
if mode == ImageReadMode.RGB_ALPHA:
assert img.shape[0] == 4

if img.dtype == torch.uint16:
img = F.to_dtype(img, dtype=torch.uint8, scale=True)

try:
from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode))
except RuntimeError as e:
Expand All @@ -998,22 +996,22 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename)
g = make_grid([img, from_pil])
F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png"))

is_decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "decode_heic"
if mode == ImageReadMode.RGB and not is_decode_heic:
is__decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic"
if mode == ImageReadMode.RGB and not is__decode_heic:
# We don't compare torchvision's AVIF against PIL for RGB because
# results look pretty different on RGBA images (other images are fine).
# The result on torchvision basically just plainly ignores the alpha
# channel, resuting in transparent pixels looking dark. PIL seems to be
# using a sort of k-nn thing (Take a look at the resuting images)
return
if filename.name == "sofa_grid1x5_420.avif" and is_decode_heic:
if filename.name == "sofa_grid1x5_420.avif" and is__decode_heic:
return

torch.testing.assert_close(img, from_pil, rtol=0, atol=3)


# @pytest.mark.xfail(reason="HEIC support not enabled yet.")
@pytest.mark.parametrize("decode_fun", (decode_heic, decode_image))
@pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_heic(decode_fun, scripted):
encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".heic")))
Expand All @@ -1022,7 +1020,6 @@ def test_decode_heic(decode_fun, scripted):
img = decode_fun(encoded_bytes)
assert img.shape == (3, 100, 100)
assert img[None].is_contiguous(memory_format=torch.channels_last)
img += 123 # Make sure the underlying data wasn't accidentally freed


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions torchvision/csrc/io/image/cpu/decode_avif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ torch::Tensor decode_avif(
result == AVIF_RESULT_OK,
"avifDecoderParse failed: ",
avifResultToString(result));
printf("avif num images = %d\n", decoder->imageCount);
TORCH_CHECK(
decoder->imageCount == 1, "Avif file contains more than one image");

Expand Down
30 changes: 29 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_heic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,34 @@ torch::Tensor decode_heic(
int bit_depth = 0;

try {
// TODO: error on image sequences
heif::Context ctx;
ctx.read_from_memory_without_copy(
encoded_data.data_ptr<uint8_t>(), encoded_data.numel());

// TODO properly error on (or support) image sequences. Right now, I think
// this function will always return the first image in a sequence, which is
// inconsistent with decode_gif (which returns a batch) and with decode_avif
// (which errors loudly).
// Why? I'm struggling to make sense of
// ctx.get_number_of_top_level_images(). It disagrees with libavif's
// imageCount. For example on some of the libavif test images:
//
// - colors-animated-12bpc-keyframes-0-2-3.avif
// avif num images = 5
// heif num images = 1 // Why is this 1 when clearly this is supposed to
// be a sequence?
// - sofa_grid1x5_420.avif
// avif num images = 1
// heif num images = 6 // If we were to error here we won't be able to
// decode this image which is otherwise properly
// decoded by libavif.
// I can't find a libheif function that does what we need here, or at least
// that agrees with libavif.

// TORCH_CHECK(
// ctx.get_number_of_top_level_images() == 1,
// "heic file contains more than one image");

heif::ImageHandle handle = ctx.get_primary_image_handle();
bit_depth = handle.get_luma_bits_per_pixel();

Expand All @@ -71,6 +94,8 @@ torch::Tensor decode_heic(
chroma = return_rgb ? heif_chroma_interleaved_RGB
: heif_chroma_interleaved_RGBA;
} else {
// TODO: This, along with our 10bits -> 16bits range mapping down below,
// may not work on BE platforms
chroma = return_rgb ? heif_chroma_interleaved_RRGGBB_LE
: heif_chroma_interleaved_RRGGBBAA_LE;
}
Expand All @@ -79,6 +104,9 @@ torch::Tensor decode_heic(

decoded_data = img.get_plane(heif_channel_interleaved, &stride);
} catch (const heif::Error& err) {
// We need this try/catch block and call TORCH_CHECK, because libheif may
// otherwise throw heif::Error that would just be reported as "An unknown
// exception occurred" when we move back to Python.
TORCH_CHECK(false, "decode_heif failed: ", err.get_message());
}
TORCH_CHECK(decoded_data != nullptr, "Something went wrong during decoding.");
Expand Down
4 changes: 3 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_heic.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
namespace vision {
namespace image {

C10_EXPORT torch::Tensor decode_heic(const torch::Tensor& data, ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
C10_EXPORT torch::Tensor decode_heic(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

} // namespace image
} // namespace vision
1 change: 0 additions & 1 deletion torchvision/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
)
from .image import (
decode_gif,
decode_heic,
decode_image,
decode_jpeg,
decode_png,
Expand Down
24 changes: 22 additions & 2 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,27 @@ def _decode_avif(
return torch.ops.image.decode_avif(input, mode.value)


def decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
"""
Decode an HEIC image into a 3 dimensional RGB[A] Tensor.
The values of the output tensor are in uint8 in [0, 255] for most images. If
the image has a bit-depth of more than 8, then the output tensor is uint16
in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
``scale=True`` after this function to convert the decoded image into a uint8
or float tensor.
Args:
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
the raw bytes of the HEIC image.
mode (ImageReadMode): The read mode used for optionally
converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
Returns:
Decoded image (Tensor[image_channels, image_height, image_width])
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_heic)
_log_api_usage_once(_decode_heic)
return torch.ops.image.decode_heic(input, mode.value)

0 comments on commit af7d711

Please sign in to comment.