Cleanup/refactor of decoders and related tests #8617

Merged: 22 commits, Aug 30, 2024

Changes from all commits
100 changes: 61 additions & 39 deletions test/test_image.py
@@ -42,6 +42,19 @@
IS_WINDOWS = sys.platform in ("win32", "cygwin")
IS_MACOS = sys.platform == "darwin"
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
WEBP_TEST_IMAGES_DIR = os.environ.get("WEBP_TEST_IMAGES_DIR", "")
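# To exercise the webp-vs-PIL comparison tests below, point this at a local
# directory of .webp files, e.g. (the path and -k filter here are just an illustration):
#   WEBP_TEST_IMAGES_DIR=~/webp_samples pytest test/test_image.py -k webp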

# Hacky way of figuring out whether we compiled with libavif/libheif (those are
# currently disabled by default)
try:
_decode_avif(torch.arange(10, dtype=torch.uint8))
except Exception as e:
DECODE_AVIF_ENABLED = "torchvision not compiled with libavif support" not in str(e)

try:
_decode_heic(torch.arange(10, dtype=torch.uint8))
except Exception as e:
DECODE_HEIC_ENABLED = "torchvision not compiled with libheif support" not in str(e)
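# Note: the probes above rely on the call always raising (10 arbitrary bytes are
# never a valid AVIF/HEIC payload), so the flags are only ever assigned inside
# the `except` blocks.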


def _get_safe_image_name(name):
@@ -149,17 +162,6 @@ def test_invalid_exif(tmpdir, size):
torch.testing.assert_close(expected, output)


def test_decode_jpeg_errors():
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))

with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
decode_jpeg(torch.empty((100,), dtype=torch.float16))

with pytest.raises(RuntimeError, match="Not a JPEG file"):
decode_jpeg(torch.empty((100), dtype=torch.uint8))


def test_decode_bad_huffman_images():
# sanity check: make sure we can decode the bad Huffman encoding
bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg"))
@@ -235,10 +237,6 @@ def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):


def test_decode_png_errors():
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
decode_png(torch.empty((), dtype=torch.uint8))
with pytest.raises(RuntimeError, match="Content is not png"):
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
with pytest.raises(RuntimeError, match="Out of bound read in decode_png"):
decode_png(read_file(os.path.join(DAMAGED_PNG, "sigsegv.png")))
with pytest.raises(RuntimeError, match="Content is too small for png"):
@@ -864,20 +862,28 @@ def test_decode_gif(tmpdir, name, scripted):
torch.testing.assert_close(tv_frame, pil_frame, atol=0, rtol=0)


@pytest.mark.parametrize("decode_fun", (decode_gif, decode_webp))
def test_decode_gif_webp_errors(decode_fun):
decode_fun_and_match = [
(decode_png, "Content is not png"),
(decode_jpeg, "Not a JPEG file"),
(decode_gif, re.escape("DGifOpenFileName() failed - 103")),
(decode_webp, "WebPGetFeatures failed."),
]
if DECODE_AVIF_ENABLED:
decode_fun_and_match.append((_decode_avif, "BMFF parsing failed"))
if DECODE_HEIC_ENABLED:
decode_fun_and_match.append((_decode_heic, "Invalid input: No 'ftyp' box"))


@pytest.mark.parametrize("decode_fun, match", decode_fun_and_match)
def test_decode_bad_encoded_data(decode_fun, match):
encoded_data = torch.randint(0, 256, (100,), dtype=torch.uint8)
with pytest.raises(RuntimeError, match="Input tensor must be 1-dimensional"):
decode_fun(encoded_data[None])
with pytest.raises(RuntimeError, match="Input tensor must have uint8 data type"):
decode_fun(encoded_data.float())
with pytest.raises(RuntimeError, match="Input tensor must be contiguous"):
decode_fun(encoded_data[::2])
if decode_fun is decode_gif:
expected_match = re.escape("DGifOpenFileName() failed - 103")
elif decode_fun is decode_webp:
expected_match = "WebPGetFeatures failed."
with pytest.raises(RuntimeError, match=expected_match):
with pytest.raises(RuntimeError, match=match):
decode_fun(encoded_data)


@@ -890,21 +896,27 @@ def test_decode_webp(decode_fun, scripted):
img = decode_fun(encoded_bytes)
assert img.shape == (3, 100, 100)
assert img[None].is_contiguous(memory_format=torch.channels_last)
img += 123 # make sure image buffer wasn't freed by underlying decoding lib


# This test is skipped because it requires webp images that we're not including
# within the repo. The test images were downloaded from the different pages of
# https://developers.google.com/speed/webp/gallery
# Note that converting an RGBA image to RGB leads to bad results because the
# transparent pixels aren't necessarily set to "black" or "white", they can be
# random stuff. This is consistent with PIL results.
@pytest.mark.skip(reason="Need to download test images first")
# This test is skipped by default because it requires webp images that we're not
# including within the repo. The test images were downloaded manually from the
# different pages of https://developers.google.com/speed/webp/gallery
@pytest.mark.skipif(not WEBP_TEST_IMAGES_DIR, reason="WEBP_TEST_IMAGES_DIR is not set")
@pytest.mark.parametrize("decode_fun", (decode_webp, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
"mode, pil_mode", ((ImageReadMode.RGB, "RGB"), (ImageReadMode.RGB_ALPHA, "RGBA"), (ImageReadMode.UNCHANGED, None))
"mode, pil_mode",
(
# Note that converting an RGBA image to RGB leads to bad results because the
# transparent pixels aren't necessarily set to "black" or "white", they can be
# random stuff. This is consistent with PIL results.
(ImageReadMode.RGB, "RGB"),
(ImageReadMode.RGB_ALPHA, "RGBA"),
(ImageReadMode.UNCHANGED, None),
),
)
@pytest.mark.parametrize("filename", Path("/home/nicolashug/webp_samples").glob("*.webp"))
@pytest.mark.parametrize("filename", Path(WEBP_TEST_IMAGES_DIR).glob("*.webp"), ids=lambda p: p.name)
def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename):
encoded_bytes = read_file(filename)
if scripted:
@@ -915,9 +927,10 @@ def test_decode_webp_against_pil(decode_fun, scripted, mode, pil_mode, filename)
pil_img = Image.open(filename).convert(pil_mode)
from_pil = F.pil_to_tensor(pil_img)
assert_equal(img, from_pil)
img += 123 # make sure image buffer wasn't freed by underlying decoding lib


@pytest.mark.xfail(reason="AVIF support not enabled yet.")
@pytest.mark.skipif(not DECODE_AVIF_ENABLED, reason="AVIF support not enabled.")
@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_avif(decode_fun, scripted):
@@ -927,12 +940,20 @@ def test_decode_avif(decode_fun, scripted):
img = decode_fun(encoded_bytes)
assert img.shape == (3, 100, 100)
assert img[None].is_contiguous(memory_format=torch.channels_last)
img += 123 # make sure image buffer wasn't freed by underlying decoding lib


@pytest.mark.xfail(reason="AVIF and HEIC support not enabled yet.")
# Note: decode_image fails because some of these files have a (valid) signature
# we don't recognize. We should probably use libmagic....
@pytest.mark.parametrize("decode_fun", (_decode_avif, _decode_heic))
decode_funs = []
if DECODE_AVIF_ENABLED:
decode_funs.append(_decode_avif)
if DECODE_HEIC_ENABLED:
decode_funs.append(_decode_heic)


@pytest.mark.skipif(not decode_funs, reason="Built without avif and heic support.")
@pytest.mark.parametrize("decode_fun", decode_funs)
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
"mode, pil_mode",
@@ -945,7 +966,7 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename)
@pytest.mark.parametrize(
"filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"), ids=lambda p: p.name
)
def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename):
def test_decode_avif_heic_against_pil(decode_fun, scripted, mode, pil_mode, filename):
if "reversed_dimg_order" in str(filename):
# Pillow properly decodes this one, but we don't (order of parts of the
# image is wrong). This is due to a bug that was recently fixed in
@@ -996,21 +1017,21 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename)
g = make_grid([img, from_pil])
F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png"))

is__decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic"
if mode == ImageReadMode.RGB and not is__decode_heic:
is_decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic"
if mode == ImageReadMode.RGB and not is_decode_heic:
# We don't compare torchvision's AVIF against PIL for RGB because
# results look pretty different on RGBA images (other images are fine).
# The result on torchvision basically just ignores the alpha
# channel, resulting in transparent pixels looking dark. PIL seems to be
# using a sort of k-nn thing (take a look at the resulting images).
return
if filename.name == "sofa_grid1x5_420.avif" and is__decode_heic:
if filename.name == "sofa_grid1x5_420.avif" and is_decode_heic:
return

torch.testing.assert_close(img, from_pil, rtol=0, atol=3)


@pytest.mark.xfail(reason="HEIC support not enabled yet.")
@pytest.mark.skipif(not DECODE_HEIC_ENABLED, reason="HEIC support not enabled yet.")
@pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_heic(decode_fun, scripted):
@@ -1020,6 +1041,7 @@ def test_decode_heic(decode_fun, scripted):
img = decode_fun(encoded_bytes)
assert img.shape == (3, 100, 100)
assert img[None].is_contiguous(memory_format=torch.channels_last)
img += 123 # make sure image buffer wasn't freed by underlying decoding lib


if __name__ == "__main__":
43 changes: 43 additions & 0 deletions torchvision/csrc/io/image/common.cpp
@@ -0,0 +1,43 @@

#include "common.h"
#include <torch/torch.h>

namespace vision {
namespace image {

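// Shared sanity checks for all decoder entry points: the encoded bytes must be
// a contiguous, non-empty, 1-dimensional uint8 tensor.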
void validate_encoded_data(const torch::Tensor& encoded_data) {
TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
TORCH_CHECK(
encoded_data.dtype() == torch::kU8,
"Input tensor must have uint8 data type, got ",
encoded_data.dtype());
TORCH_CHECK(
encoded_data.dim() == 1 && encoded_data.numel() > 0,
"Input tensor must be 1-dimensional and non-empty, got ",
encoded_data.dim(),
" dims and ",
encoded_data.numel(),
" numels.");
}

bool should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(
ImageReadMode mode,
bool has_alpha) {
  // Return true if the calling decoding function should return a 3-channel RGB
  // tensor, and false if it should return a 4-channel RGBA tensor.
  // This function ignores the requested "grayscale" modes and treats them as
  // "unchanged", so it should only be used on decoders that don't support
  // grayscale outputs.

if (mode == IMAGE_READ_MODE_RGB) {
return true;
}
if (mode == IMAGE_READ_MODE_RGB_ALPHA) {
return false;
}
// From here we assume mode is "unchanged", even for grayscale ones.
return !has_alpha;
}
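
// A typical call site in a decoder looks like the sketch below, where
// `has_alpha` stands for whatever the underlying library reports for the
// parsed image:
//
//   auto return_rgb =
//       should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(
//           mode, has_alpha);
//   auto num_channels = return_rgb ? 3 : 4;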

} // namespace image
} // namespace vision
@@ -1,6 +1,7 @@
#pragma once

#include <stdint.h>
#include <torch/torch.h>

namespace vision {
namespace image {
Expand All @@ -13,5 +14,11 @@ const ImageReadMode IMAGE_READ_MODE_GRAY_ALPHA = 2;
const ImageReadMode IMAGE_READ_MODE_RGB = 3;
const ImageReadMode IMAGE_READ_MODE_RGB_ALPHA = 4;

void validate_encoded_data(const torch::Tensor& encoded_data);

bool should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(
ImageReadMode mode,
bool has_alpha);

} // namespace image
} // namespace vision
26 changes: 5 additions & 21 deletions torchvision/csrc/io/image/cpu/decode_avif.cpp
@@ -1,4 +1,5 @@
#include "decode_avif.h"
#include "../common.h"

#if AVIF_FOUND
#include "avif/avif.h"
@@ -33,16 +34,7 @@ torch::Tensor decode_avif(
// Refer there for more detail about what each function does, and which
// structure/data is available after which call.

TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
TORCH_CHECK(
encoded_data.dtype() == torch::kU8,
"Input tensor must have uint8 data type, got ",
encoded_data.dtype());
TORCH_CHECK(
encoded_data.dim() == 1,
"Input tensor must be 1-dimensional, got ",
encoded_data.dim(),
" dims.");
validate_encoded_data(encoded_data);

DecoderPtr decoder(avifDecoderCreate());
TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder.");
@@ -60,6 +52,7 @@ torch::Tensor decode_avif(
result == AVIF_RESULT_OK,
"avifDecoderParse failed: ",
avifResultToString(result));
printf("avif num images = %d\n", decoder->imageCount);
TORCH_CHECK(
decoder->imageCount == 1, "Avif file contains more than one image");

@@ -78,18 +71,9 @@ torch::Tensor decode_avif(
auto use_uint8 = (decoder->image->depth <= 8);
rgb.depth = use_uint8 ? 8 : 16;

if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB &&
mode != IMAGE_READ_MODE_RGB_ALPHA) {
// Other modes aren't supported, but we don't error or even warn because we
// have generic entry points like decode_image which may support all modes,
// it just depends on the underlying decoder.
mode = IMAGE_READ_MODE_UNCHANGED;
}

// If return_rgb is false it means we return rgba - nothing else.
auto return_rgb =
(mode == IMAGE_READ_MODE_RGB ||
(mode == IMAGE_READ_MODE_UNCHANGED && !decoder->alphaPresent));
should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(
mode, decoder->alphaPresent);

auto num_channels = return_rgb ? 3 : 4;
rgb.format = return_rgb ? AVIF_RGB_FORMAT_RGB : AVIF_RGB_FORMAT_RGBA;
2 changes: 1 addition & 1 deletion torchvision/csrc/io/image/cpu/decode_avif.h
@@ -1,7 +1,7 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"
#include "../common.h"

namespace vision {
namespace image {
12 changes: 2 additions & 10 deletions torchvision/csrc/io/image/cpu/decode_gif.cpp
@@ -1,5 +1,6 @@
#include "decode_gif.h"
#include <cstring>
#include "../common.h"
#include "giflib/gif_lib.h"

namespace vision {
@@ -34,16 +35,7 @@ torch::Tensor decode_gif(const torch::Tensor& encoded_data) {
// Refer over there for more details on the libgif API, API ref, and a
// detailed description of the GIF format.

TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
TORCH_CHECK(
encoded_data.dtype() == torch::kU8,
"Input tensor must have uint8 data type, got ",
encoded_data.dtype());
TORCH_CHECK(
encoded_data.dim() == 1,
"Input tensor must be 1-dimensional, got ",
encoded_data.dim(),
" dims.");
validate_encoded_data(encoded_data);

int error = D_GIF_SUCCEEDED;

25 changes: 4 additions & 21 deletions torchvision/csrc/io/image/cpu/decode_heic.cpp
@@ -1,4 +1,5 @@
#include "decode_heic.h"
#include "../common.h"

#if HEIC_FOUND
#include "libheif/heif_cxx.h"
@@ -19,26 +20,8 @@ torch::Tensor decode_heic(
torch::Tensor decode_heic(
const torch::Tensor& encoded_data,
ImageReadMode mode) {
TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
TORCH_CHECK(
encoded_data.dtype() == torch::kU8,
"Input tensor must have uint8 data type, got ",
encoded_data.dtype());
TORCH_CHECK(
encoded_data.dim() == 1,
"Input tensor must be 1-dimensional, got ",
encoded_data.dim(),
" dims.");

if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB &&
mode != IMAGE_READ_MODE_RGB_ALPHA) {
// Other modes aren't supported, but we don't error or even warn because we
// have generic entry points like decode_image which may support all modes,
// it just depends on the underlying decoder.
mode = IMAGE_READ_MODE_UNCHANGED;
}
validate_encoded_data(encoded_data);

// If return_rgb is false it means we return rgba - nothing else.
auto return_rgb = true;

int height = 0;
@@ -82,8 +65,8 @@ torch::Tensor decode_heic(
bit_depth = handle.get_luma_bits_per_pixel();

return_rgb =
(mode == IMAGE_READ_MODE_RGB ||
(mode == IMAGE_READ_MODE_UNCHANGED && !handle.has_alpha_channel()));
should_this_return_rgb_or_rgba_let_me_know_in_the_comments_down_below_guys_see_you_in_the_next_video(
mode, handle.has_alpha_channel());

height = handle.get_height();
width = handle.get_width();