Skip to content

Commit 6b56de1

Browse files
zy1gitNicolasHug
andauthored
Implement Flip transforms with CVCUDA backend (#9277)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent dccf466 commit 6b56de1

File tree

5 files changed

+143
-27
lines changed

5 files changed

+143
-27
lines changed

test/common_utils.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
2121
from torchvision import io, tv_tensors
2222
from torchvision.transforms._functional_tensor import _max_value as get_max_value
23-
from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
23+
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
24+
from torchvision.transforms.v2.functional._utils import _is_cvcuda_available, _is_cvcuda_tensor
2425
from torchvision.utils import _Image_fromarray
2526

2627

@@ -284,8 +285,24 @@ def __init__(
284285
mae=False,
285286
**other_parameters,
286287
):
287-
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
288-
actual, expected = (to_image(input) for input in [actual, expected])
288+
# Convert PIL images to tv_tensors.Image (regardless of what the other is)
289+
if isinstance(actual, PIL.Image.Image):
290+
actual = to_image(actual)
291+
if isinstance(expected, PIL.Image.Image):
292+
expected = to_image(expected)
293+
294+
if _is_cvcuda_available():
295+
if _is_cvcuda_tensor(actual):
296+
actual = cvcuda_to_tensor(actual)
297+
# Remove batch dimension if it's 1 for easier comparison against 3D PIL images
298+
if actual.shape[0] == 1:
299+
actual = actual[0]
300+
actual = actual.cpu()
301+
if _is_cvcuda_tensor(expected):
302+
expected = cvcuda_to_tensor(expected)
303+
if expected.shape[0] == 1:
304+
expected = expected[0]
305+
expected = expected.cpu()
289306

290307
super().__init__(actual, expected, **other_parameters)
291308
self.mae = mae
@@ -400,8 +417,8 @@ def make_image_pil(*args, **kwargs):
400417
return to_pil_image(make_image(*args, **kwargs))
401418

402419

403-
def make_image_cvcuda(*args, **kwargs):
404-
return to_cvcuda_tensor(make_image(*args, **kwargs))
420+
def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
421+
return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))
405422

406423

407424
def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
@@ -541,5 +558,9 @@ def ignore_jit_no_profile_information_warning():
541558
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
542559
# them.
543560
with warnings.catch_warnings():
544-
warnings.filterwarnings("ignore", message=re.escape("operator() profile_node %"), category=UserWarning)
561+
warnings.filterwarnings(
562+
"ignore",
563+
message=re.escape("operator() profile_node %"),
564+
category=UserWarning,
565+
)
545566
yield

test/test_transforms_v2.py

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,6 +1240,10 @@ def test_kernel_video(self):
12401240
make_image_tensor,
12411241
make_image_pil,
12421242
make_image,
1243+
pytest.param(
1244+
make_image_cvcuda,
1245+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1246+
),
12431247
make_bounding_boxes,
12441248
make_segmentation_mask,
12451249
make_video,
@@ -1255,13 +1259,20 @@ def test_functional(self, make_input):
12551259
(F.horizontal_flip_image, torch.Tensor),
12561260
(F._geometry._horizontal_flip_image_pil, PIL.Image.Image),
12571261
(F.horizontal_flip_image, tv_tensors.Image),
1262+
pytest.param(
1263+
F._geometry._horizontal_flip_image_cvcuda,
1264+
None,
1265+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1266+
),
12581267
(F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes),
12591268
(F.horizontal_flip_mask, tv_tensors.Mask),
12601269
(F.horizontal_flip_video, tv_tensors.Video),
12611270
(F.horizontal_flip_keypoints, tv_tensors.KeyPoints),
12621271
],
12631272
)
12641273
def test_functional_signature(self, kernel, input_type):
1274+
if kernel is F._geometry._horizontal_flip_image_cvcuda:
1275+
input_type = _import_cvcuda().Tensor
12651276
check_functional_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
12661277

12671278
@pytest.mark.parametrize(
@@ -1270,6 +1281,10 @@ def test_functional_signature(self, kernel, input_type):
12701281
make_image_tensor,
12711282
make_image_pil,
12721283
make_image,
1284+
pytest.param(
1285+
make_image_cvcuda,
1286+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1287+
),
12731288
make_bounding_boxes,
12741289
make_segmentation_mask,
12751290
make_video,
@@ -1283,13 +1298,23 @@ def test_transform(self, make_input, device):
12831298
@pytest.mark.parametrize(
12841299
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
12851300
)
1286-
def test_image_correctness(self, fn):
1287-
image = make_image(dtype=torch.uint8, device="cpu")
1288-
1301+
@pytest.mark.parametrize(
1302+
"make_input",
1303+
[
1304+
make_image,
1305+
pytest.param(
1306+
make_image_cvcuda,
1307+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1308+
),
1309+
],
1310+
)
1311+
def test_image_correctness(self, fn, make_input):
1312+
image = make_input()
12891313
actual = fn(image)
1290-
expected = F.to_image(F.horizontal_flip(F.to_pil_image(image)))
1291-
1292-
torch.testing.assert_close(actual, expected)
1314+
if make_input is make_image_cvcuda:
1315+
image = F.cvcuda_to_tensor(image)[0].cpu()
1316+
expected = F.horizontal_flip(F.to_pil_image(image))
1317+
assert_equal(actual, expected)
12931318

12941319
def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
12951320
affine_matrix = np.array(
@@ -1345,6 +1370,10 @@ def test_keypoints_correctness(self, fn):
13451370
make_image_tensor,
13461371
make_image_pil,
13471372
make_image,
1373+
pytest.param(
1374+
make_image_cvcuda,
1375+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1376+
),
13481377
make_bounding_boxes,
13491378
make_segmentation_mask,
13501379
make_video,
@@ -1354,11 +1383,8 @@ def test_keypoints_correctness(self, fn):
13541383
@pytest.mark.parametrize("device", cpu_and_cuda())
13551384
def test_transform_noop(self, make_input, device):
13561385
input = make_input(device=device)
1357-
13581386
transform = transforms.RandomHorizontalFlip(p=0)
1359-
13601387
output = transform(input)
1361-
13621388
assert_equal(output, input)
13631389

13641390

@@ -1856,6 +1882,10 @@ def test_kernel_video(self):
18561882
make_image_tensor,
18571883
make_image_pil,
18581884
make_image,
1885+
pytest.param(
1886+
make_image_cvcuda,
1887+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1888+
),
18591889
make_bounding_boxes,
18601890
make_segmentation_mask,
18611891
make_video,
@@ -1871,13 +1901,20 @@ def test_functional(self, make_input):
18711901
(F.vertical_flip_image, torch.Tensor),
18721902
(F._geometry._vertical_flip_image_pil, PIL.Image.Image),
18731903
(F.vertical_flip_image, tv_tensors.Image),
1904+
pytest.param(
1905+
F._geometry._vertical_flip_image_cvcuda,
1906+
None,
1907+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1908+
),
18741909
(F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes),
18751910
(F.vertical_flip_mask, tv_tensors.Mask),
18761911
(F.vertical_flip_video, tv_tensors.Video),
18771912
(F.vertical_flip_keypoints, tv_tensors.KeyPoints),
18781913
],
18791914
)
18801915
def test_functional_signature(self, kernel, input_type):
1916+
if kernel is F._geometry._vertical_flip_image_cvcuda:
1917+
input_type = _import_cvcuda().Tensor
18811918
check_functional_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)
18821919

18831920
@pytest.mark.parametrize(
@@ -1886,6 +1923,10 @@ def test_functional_signature(self, kernel, input_type):
18861923
make_image_tensor,
18871924
make_image_pil,
18881925
make_image,
1926+
pytest.param(
1927+
make_image_cvcuda,
1928+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1929+
),
18891930
make_bounding_boxes,
18901931
make_segmentation_mask,
18911932
make_video,
@@ -1897,13 +1938,23 @@ def test_transform(self, make_input, device):
18971938
check_transform(transforms.RandomVerticalFlip(p=1), make_input(device=device))
18981939

18991940
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
1900-
def test_image_correctness(self, fn):
1901-
image = make_image(dtype=torch.uint8, device="cpu")
1902-
1941+
@pytest.mark.parametrize(
1942+
"make_input",
1943+
[
1944+
make_image,
1945+
pytest.param(
1946+
make_image_cvcuda,
1947+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
1948+
),
1949+
],
1950+
)
1951+
def test_image_correctness(self, fn, make_input):
1952+
image = make_input()
19031953
actual = fn(image)
1904-
expected = F.to_image(F.vertical_flip(F.to_pil_image(image)))
1905-
1906-
torch.testing.assert_close(actual, expected)
1954+
if make_input is make_image_cvcuda:
1955+
image = F.cvcuda_to_tensor(image)[0].cpu()
1956+
expected = F.vertical_flip(F.to_pil_image(image))
1957+
assert_equal(actual, expected)
19071958

19081959
def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
19091960
affine_matrix = np.array(
@@ -1955,6 +2006,10 @@ def test_keypoints_correctness(self, fn):
19552006
make_image_tensor,
19562007
make_image_pil,
19572008
make_image,
2009+
pytest.param(
2010+
make_image_cvcuda,
2011+
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
2012+
),
19582013
make_bounding_boxes,
19592014
make_segmentation_mask,
19602015
make_video,
@@ -1964,11 +2019,8 @@ def test_keypoints_correctness(self, fn):
19642019
@pytest.mark.parametrize("device", cpu_and_cuda())
19652020
def test_transform_noop(self, make_input, device):
19662021
input = make_input(device=device)
1967-
19682022
transform = transforms.RandomVerticalFlip(p=0)
1969-
19702023
output = transform(input)
1971-
19722024
assert_equal(output, input)
19732025

19742026

torchvision/transforms/v2/_geometry.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torchvision.ops.boxes import box_iou
1212
from torchvision.transforms.functional import _get_perspective_coeffs
1313
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
14-
from torchvision.transforms.v2.functional._utils import _FillType
14+
from torchvision.transforms.v2.functional._utils import _FillType, _is_cvcuda_available, _is_cvcuda_tensor
1515

1616
from ._transform import _RandomApplyTransform
1717
from ._utils import (
@@ -30,6 +30,8 @@
3030
query_size,
3131
)
3232

33+
CVCUDA_AVAILABLE = _is_cvcuda_available()
34+
3335

3436
class RandomHorizontalFlip(_RandomApplyTransform):
3537
"""Horizontally flip the input with a given probability.
@@ -45,6 +47,9 @@ class RandomHorizontalFlip(_RandomApplyTransform):
4547

4648
_v1_transform_cls = _transforms.RandomHorizontalFlip
4749

50+
if CVCUDA_AVAILABLE:
51+
_transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)
52+
4853
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
4954
return self._call_kernel(F.horizontal_flip, inpt)
5055

@@ -63,6 +68,9 @@ class RandomVerticalFlip(_RandomApplyTransform):
6368

6469
_v1_transform_cls = _transforms.RandomVerticalFlip
6570

71+
if CVCUDA_AVAILABLE:
72+
_transformed_types = _RandomApplyTransform._transformed_types + (_is_cvcuda_tensor,)
73+
6674
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
6775
return self._call_kernel(F.vertical_flip, inpt)
6876

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numbers
33
import warnings
44
from collections.abc import Sequence
5-
from typing import Any, Optional, Union
5+
from typing import Any, Optional, TYPE_CHECKING, Union
66

77
import PIL.Image
88
import torch
@@ -26,7 +26,18 @@
2626

2727
from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format
2828

29-
from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
29+
from ._utils import (
30+
_FillTypeJIT,
31+
_get_kernel,
32+
_import_cvcuda,
33+
_is_cvcuda_available,
34+
_register_five_ten_crop_kernel_internal,
35+
_register_kernel_internal,
36+
)
37+
38+
CVCUDA_AVAILABLE = _is_cvcuda_available()
39+
if TYPE_CHECKING:
40+
import cvcuda # type: ignore[import-not-found]
3041

3142

3243
def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
@@ -62,6 +73,14 @@ def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
6273
return _FP.hflip(image)
6374

6475

76+
def _horizontal_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
77+
return _import_cvcuda().flip(image, flipCode=1)
78+
79+
80+
if CVCUDA_AVAILABLE:
81+
_register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(_horizontal_flip_image_cvcuda)
82+
83+
6584
@_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
6685
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
6786
return horizontal_flip_image(mask)
@@ -150,6 +169,14 @@ def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
150169
return _FP.vflip(image)
151170

152171

172+
def _vertical_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
173+
return _import_cvcuda().flip(image, flipCode=0)
174+
175+
176+
if CVCUDA_AVAILABLE:
177+
_register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(_vertical_flip_image_cvcuda)
178+
179+
153180
@_register_kernel_internal(vertical_flip, tv_tensors.Mask)
154181
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
155182
return vertical_flip_image(mask)

torchvision/transforms/v2/functional/_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,11 @@ def _is_cvcuda_available():
169169
return True
170170
except ImportError:
171171
return False
172+
173+
174+
def _is_cvcuda_tensor(inpt: Any) -> bool:
175+
try:
176+
cvcuda = _import_cvcuda()
177+
return isinstance(inpt, cvcuda.Tensor)
178+
except ImportError:
179+
return False

0 commit comments

Comments
 (0)