diff --git a/pyproject.toml b/pyproject.toml index f01e4933f4a..254e765da24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,9 @@ dependencies = [ "torchmetrics>=0.10", # torchvision 0.14+ required for torchvision.models.swin_v2_b "torchvision>=0.14", + # typing-extensions 4.5+ required for typing_extensions.deprecated + # can be removed once Python 3.13 is minimum supported version + "typing-extensions>=4.5", ] dynamic = ["version"] diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index c5a5f6dae58..263794d0d46 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -20,6 +20,7 @@ timm==0.4.12 torch==1.13.0 torchmetrics==0.10.0 torchvision==0.14.0 +typing-extensions==4.5.0 # datasets h5py==3.6.0 diff --git a/requirements/required.txt b/requirements/required.txt index f6db590815f..7e9eda0bcc1 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -20,3 +20,4 @@ timm==0.9.7 torch==2.5.1 torchmetrics==1.4.3 torchvision==0.20.1 +typing-extensions==4.12.2 diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 1f2071ae812..8218434b30f 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -82,9 +82,10 @@ def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None: 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } - augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] - ) + with pytest.deprecated_call(): + augs = transforms.AugmentationSequential( + K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] + ) output = augs(batch_gray) assert_matching(output, expected) @@ -105,9 +106,10 @@ def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None: 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } - augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] - ) + with pytest.deprecated_call(): + augs = transforms.AugmentationSequential( + K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] + ) output = augs(batch_rgb) assert_matching(output, expected) @@ -132,9 +134,10 @@ def test_augmentation_sequential_multispectral( 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } - augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] - ) + with pytest.deprecated_call(): + augs = transforms.AugmentationSequential( + K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes'] + ) output = augs(batch_multispectral) assert_matching(output, expected) @@ -159,9 +162,10 @@ def test_augmentation_sequential_image_only( 'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } - augs = transforms.AugmentationSequential( - K.RandomHorizontalFlip(p=1.0), data_keys=['image'] - ) + with pytest.deprecated_call(): + augs = transforms.AugmentationSequential( + K.RandomHorizontalFlip(p=1.0), data_keys=['image'] + ) output = augs(batch_multispectral) assert_matching(output, expected) @@ -191,15 +195,16 @@ def test_sequential_transforms_augmentations( 'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float), 'labels': torch.tensor([[0, 1]]), } - train_transforms = transforms.AugmentationSequential( - indices.AppendNBR(index_nir=0, index_swir=0), - indices.AppendNDBI(index_swir=0, index_nir=0), - indices.AppendNDSI(index_green=0, index_swir=0), - indices.AppendNDVI(index_red=0, index_nir=0), - indices.AppendNDWI(index_green=0, index_nir=0), - K.RandomHorizontalFlip(p=1.0), - data_keys=['image', 'mask', 'boxes'], - ) + with pytest.deprecated_call(): + train_transforms = transforms.AugmentationSequential( + indices.AppendNBR(index_nir=0, index_swir=0), + indices.AppendNDBI(index_swir=0, index_nir=0), + indices.AppendNDSI(index_green=0, index_swir=0), + indices.AppendNDVI(index_red=0, index_nir=0), + indices.AppendNDWI(index_green=0, index_nir=0), + K.RandomHorizontalFlip(p=1.0), + data_keys=['image', 'mask', 'boxes'], + ) output = train_transforms(batch_multispectral) assert_matching(output, expected) @@ -215,9 +220,12 @@ def test_extract_patches() -> None: 'image': torch.randn(size=(b, c, h, w)), 'mask': torch.randint(low=0, high=2, size=(b, h, w)), } - train_transforms = transforms.AugmentationSequential( - _ExtractPatches(window_size=p), same_on_batch=True, data_keys=['image', 'mask'] - ) + with pytest.deprecated_call(): + train_transforms = transforms.AugmentationSequential( + _ExtractPatches(window_size=p), + same_on_batch=True, + data_keys=['image', 'mask'], + ) output = train_transforms(batch) assert batch['image'].shape == (b * num_patches, c, p, p) assert batch['mask'].shape == (b * num_patches, p, p) @@ -229,11 +237,12 @@ def test_extract_patches() -> None: 'image': torch.randn(size=(b, c, h, w)), 'mask': torch.randint(low=0, high=2, size=(b, h, w)), } - train_transforms = transforms.AugmentationSequential( - _ExtractPatches(window_size=p, stride=s), - same_on_batch=True, - data_keys=['image', 'mask'], - ) + with pytest.deprecated_call(): + train_transforms = transforms.AugmentationSequential( + _ExtractPatches(window_size=p, stride=s), + same_on_batch=True, + data_keys=['image', 'mask'], + ) output = train_transforms(batch) assert batch['image'].shape == (b * num_patches, c, p, p) assert batch['mask'].shape == (b * num_patches, p, p) @@ -245,11 +254,12 @@ def test_extract_patches() -> None: 'image': torch.randn(size=(b, c, h, w)), 'mask': torch.randint(low=0, high=2, size=(b, h, w)), } - train_transforms = transforms.AugmentationSequential( - _ExtractPatches(window_size=p, stride=s, keepdim=False), - same_on_batch=True, - data_keys=['image', 'mask'], - ) + with pytest.deprecated_call(): + train_transforms = transforms.AugmentationSequential( + _ExtractPatches(window_size=p, stride=s, keepdim=False), + same_on_batch=True, + data_keys=['image', 'mask'], + ) output = train_transforms(batch) for k, v in output.items(): print(k, v.shape, v.dtype) diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index d8f80bdcaac..6e2c8b007e3 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -13,9 +13,11 @@ from kornia.geometry.boxes import Boxes from torch import Tensor from torch.nn.modules import Module +from typing_extensions import deprecated # TODO: contribute these to Kornia and delete this file +@deprecated('Use kornia.augmentation.AugmentationSequential instead') class AugmentationSequential(Module): """Wrapper around kornia AugmentationSequential to handle input dicts.