Skip to content

Commit

Permalink
Let to_tensor return torch.uint16 tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Apr 3, 2024
1 parent 5181a85 commit d3d2f79
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 2 deletions.
3 changes: 2 additions & 1 deletion docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ the tensor dtype. Tensor images with a float dtype are expected to have
values in ``[0, 1]``. Tensor images with an integer dtype are expected to
have values in ``[0, MAX_DTYPE]`` where ``MAX_DTYPE`` is the largest value
that can be represented in that dtype. Typically, images of dtype
``torch.uint8`` are expected to have values in ``[0, 255]``.
``torch.uint8`` are expected to have values in ``[0, 255]``. Note that dtypes
like ``torch.uint16`` or ``torch.uint32`` aren't fully supported.

Use :class:`~torchvision.transforms.v2.ToDtype` to convert both the dtype and
range of the inputs.
Expand Down
12 changes: 12 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5327,6 +5327,18 @@ def test_functional_error(self):
F.pil_to_tensor(object())


@pytest.mark.parametrize("f", [F.to_tensor, F.pil_to_tensor])
def test_I16_to_tensor(f):
# See https://github.com/pytorch/vision/issues/8359
I16_pil_img = PIL.Image.fromarray(np.random.randint(0, 2 ** 16, (10, 10), dtype=np.uint16))
assert I16_pil_img.mode == "I;16"

cm = pytest.warns(UserWarning, match="deprecated") if f is F.to_tensor else contextlib.nullcontext()
with cm:
out = f(I16_pil_img)
assert out.dtype == torch.uint16


class TestLambda:
@pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
@pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)])
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def to_tensor(pic: Union[PILImage, np.ndarray]) -> Tensor:
return torch.from_numpy(nppic).to(dtype=default_float_dtype)

# handle PIL Image
mode_to_nptype = {"I": np.int32, "I;16" if sys.byteorder == "little" else "I;16B": np.int16, "F": np.float32}
mode_to_nptype = {"I": np.int32, "I;16" if sys.byteorder == "little" else "I;16B": np.uint16, "F": np.float32}
img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True))

if pic.mode == "1":
Expand Down

0 comments on commit d3d2f79

Please sign in to comment.