Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature Request] PadToSquare: Square Padding to Preserve Aspect Ratios When Resizing Images with Varied Shapes in torchvision.transforms.v2 #8701

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ Others
v2.RandomHorizontalFlip
v2.RandomVerticalFlip
v2.Pad
v2.PadToSquare
v2.RandomZoomOut
v2.RandomRotation
v2.RandomAffine
Expand Down
38 changes: 38 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3928,6 +3928,44 @@ def test_bounding_boxes_correctness(self, padding, format, dtype, device, fn):
assert_equal(actual, expected)


class TestPadToSquare:
    """Checks ``PadToSquare`` on a wide, a tall and an already-square image."""

    @pytest.mark.parametrize(
        "image",
        [make_image(size, device="cpu", dtype=torch.uint8) for size in [(3, 10), (10, 3), (10, 10)]],
    )
    def test__get_params(self, image):
        # The computed padding must be a non-negative [left, top, right, bottom]
        # list that brings both dimensions up to the longer side.
        params = transforms.PadToSquare()._get_params([image])
        assert "padding" in params

        padding = params["padding"]
        assert len(padding) == 4
        assert not any(p < 0 for p in padding)

        left, top, right, bottom = padding
        height, width = F.get_size(image)
        longest_side = max(height, width)
        assert longest_side == height + top + bottom
        assert longest_side == width + left + right

    @pytest.mark.parametrize(
        "image, expected_output_shape",
        [
            (make_image(size, device="cpu", dtype=torch.uint8), [10, 10])
            for size in [(3, 10), (10, 3), (10, 10)]
        ],
    )
    def test_pad_square_correctness(self, image, expected_output_shape):
        # End-to-end: applying the transform yields the expected square size.
        padded = transforms.PadToSquare()(image)

        assert F.get_size(padded) == expected_output_shape


class TestCenterCrop:
INPUT_SIZE = (17, 11)
OUTPUT_SIZES = [(3, 5), (5, 3), (4, 4), (21, 9), (13, 15), (19, 14), 3, (4,), [5], INPUT_SIZE]
Expand Down
1 change: 1 addition & 0 deletions torchvision/transforms/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ElasticTransform,
FiveCrop,
Pad,
PadToSquare,
RandomAffine,
RandomCrop,
RandomHorizontalFlip,
Expand Down
78 changes: 78 additions & 0 deletions torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,84 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]


class PadToSquare(Transform):
    """Pad a non-square input to make it square by padding the shorter side to match the longer side.

    The padding is split as evenly as possible between the two ends of the shorter dimension. When the
    required amount is odd, the extra pixel goes to the bottom (or right) side. Inputs that are already
    square receive zero padding on every side.

    Args:
        fill (number or tuple or dict, optional): Pixel fill value used when the ``padding_mode`` is constant.
            Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively.
            Fill value can be also a dictionary mapping data type to the fill value, e.g.
            ``fill={tv_tensors.Image: 127, tv_tensors.Mask: 0}`` where ``Image`` will be filled with 127 and
            ``Mask`` will be filled with 0.
        padding_mode (str, optional): Type of padding. Should be: constant, edge, reflect or symmetric.
            Default is "constant".

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value at the edge of the image.

            - reflect: pads with reflection of image without repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
              will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image repeating the last value on the edge.
              For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
              will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Example:
        >>> import torch
        >>> from torchvision.transforms.v2 import PadToSquare
        >>> rectangular_image = torch.randint(0, 256, (3, 224, 168), dtype=torch.uint8)
        >>> transform = PadToSquare(padding_mode='constant', fill=0)
        >>> square_image = transform(rectangular_image)
        >>> print(square_image.size())
        torch.Size([3, 224, 224])
    """

    def __init__(
        self,
        fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = 0,
        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
    ) -> None:
        super().__init__()

        # `_check_padding_mode_arg` is the single point of validation for the mode;
        # a second membership test placed after it could never be reached.
        _check_padding_mode_arg(padding_mode)
        self.padding_mode = padding_mode
        self.fill = _setup_fill_arg(fill)

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Compute the per-side padding that makes the inputs square.

        Returns:
            dict: ``{"padding": [left, top, right, bottom]}`` with non-negative integers.
        """
        orig_height, orig_width = query_size(flat_inputs)
        target_size = max(orig_height, orig_width)

        # The longer side already equals `target_size`, so at most one of these is non-zero.
        pad_height = target_size - orig_height
        pad_width = target_size - orig_width

        # Floor division puts the smaller half first; the remainder goes to bottom/right.
        pad_top = pad_height // 2
        pad_left = pad_width // 2
        pad_bottom = pad_height - pad_top
        pad_right = pad_width - pad_left

        # The padding needs to be in the format [left, top, right, bottom]
        return dict(padding=[pad_left, pad_top, pad_right, pad_bottom])

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Pad a single input using the padding computed by ``_get_params``."""
        # The fill value is looked up per input type so that a dict-valued `fill` is honoured.
        fill = _get_fill(self.fill, type(inpt))
        return self._call_kernel(F.pad, inpt, padding=params["padding"], padding_mode=self.padding_mode, fill=fill)


class RandomZoomOut(_RandomApplyTransform):
""" "Zoom out" transformation from
`"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.
Expand Down