diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 6b52cc31b6d..6e2ec2565e1 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -10,6 +10,7 @@ from PIL import Image from PIL.Image import Image as PILImage from torch import Tensor +from numpy.typing import NDArray try: import accimage @@ -124,7 +125,7 @@ def _is_numpy_image(img: Any) -> bool: return img.ndim in {2, 3} -def to_tensor(pic: Union[PILImage, np.ndarray[int, int]]) -> Tensor: +def to_tensor(pic: Union[PILImage, NDArray[np.uint8]]) -> Tensor: """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This function does not support torchscript.