diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index d27b2682055..157d4faaaea 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3022,12 +3022,18 @@ def test_errors(self): with pytest.raises(ValueError, match="Please provide only two dimensions"): transforms.RandomCrop([10, 12, 14]) - with pytest.raises(TypeError, match="Got inappropriate padding arg"): + with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"): transforms.RandomCrop([10, 12], padding="abc") with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"): transforms.RandomCrop([10, 12], padding=[-0.7, 0, 0.7]) + with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"): + transforms.RandomCrop([10, 12], padding=0.5) + + with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"): + transforms.RandomCrop([10, 12], padding=[0.5, 0.5]) + with pytest.raises(TypeError, match="Got inappropriate fill arg"): transforms.RandomCrop([10, 12], padding=1, fill="abc") @@ -3892,12 +3898,18 @@ def test_transform(self, make_input): check_transform(transforms.Pad(padding=[1]), make_input()) def test_transform_errors(self): - with pytest.raises(TypeError, match="Got inappropriate padding arg"): + with pytest.raises(ValueError, match="Padding must be"): transforms.Pad("abc") - with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"): + with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4 element of tuple or list"): transforms.Pad([-0.7, 0, 0.7]) + with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4 element of tuple or list"): + transforms.Pad(0.5) + + with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4 element of tuple or list"): + transforms.Pad(padding=[0.5, 0.5]) + with pytest.raises(TypeError, match="Got inappropriate fill arg"): transforms.Pad(12, fill="abc") diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index dd65ca4d9c9..92d5bc1a2ca 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -81,11 +81,13 @@ def _get_fill(fill_dict, inpt_type): def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None: - if not isinstance(padding, (numbers.Number, tuple, list)): - raise TypeError("Got inappropriate padding arg") - if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]: - raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple") + err_msg = f"Padding must be an int or a 1, 2, or 4 element of tuple or list, got {padding}." + if isinstance(padding, (tuple, list)): + if len(padding) not in [1, 2, 4] or not all(isinstance(p, int) for p in padding): + raise ValueError(err_msg) + elif not isinstance(padding, int): + raise ValueError(err_msg) # TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)