diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 2818698bce8..c60fbc51c78 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -18,13 +18,11 @@ from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT -def _setup_number_or_seq( - arg: Union[int, float, Sequence[Union[int, float]]], name: str, req_size: int = 2 -) -> Sequence[float]: +def _setup_number_or_seq(arg: Union[int, float, Sequence[Union[int, float]]], name: str) -> Sequence[float]: if not isinstance(arg, (int, float, Sequence)): raise TypeError(f"{name} should be a number or a sequence of numbers. Got {type(arg)}") if isinstance(arg, Sequence) and len(arg) not in (1, req_size): - raise ValueError(f"If {name} is a sequence its length should be 1 or {req_size}. Got {len(arg)}") + raise ValueError(f"If {name} is a sequence its length should be 1 or 2. Got {len(arg)}") if isinstance(arg, Sequence): for element in arg: if not isinstance(element, (int, float)): @@ -32,7 +30,7 @@ def _setup_number_or_seq( if isinstance(arg, (int, float)): arg = [float(arg), float(arg)] - if isinstance(arg, (list, tuple)): + if isinstance(arg, Sequence): if len(arg) == 1: arg = [float(arg[0]), float(arg[0])] else: