diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index f2e1d9ad06a..fb1f9a4d174 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -79,11 +79,13 @@ def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor: @_register_kernel_internal(grayscale_to_rgb, torch.Tensor) @_register_kernel_internal(grayscale_to_rgb, tv_tensors.Image) def grayscale_to_rgb_image(image: torch.Tensor) -> torch.Tensor: + # rgb_to_grayscale can be used to add channels so we reuse that function. return _rgb_to_grayscale_image(image, num_output_channels=3, preserve_dtype=True) @_register_kernel_internal(grayscale_to_rgb, PIL.Image.Image) def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: + # to_grayscale can expand channels from 1 to 3 so we reuse that function. return _FP.to_grayscale(image, num_output_channels=3)