Skip to content

Commit

Permalink
transforms: Switch to kornia AugmentationSequential (#2008)
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 authored Jun 30, 2024
1 parent c8e1e09 commit 87a5da2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
15 changes: 9 additions & 6 deletions tests/transforms/test_color.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import kornia.augmentation as K
import pytest
import torch
from torch import Tensor

from torchgeo.transforms import AugmentationSequential, RandomGrayscale
from torchgeo.transforms import RandomGrayscale


@pytest.fixture
Expand Down Expand Up @@ -33,12 +34,15 @@ def batch() -> dict[str, Tensor]:
],
)
def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> None:
aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image'])
aug = K.AugmentationSequential(
RandomGrayscale(weights, p=1), keepdim=True, data_keys=None
)
# https://github.com/kornia/kornia/issues/2848
aug.keepdim = True
output = aug(sample)
assert output['image'].shape == sample['image'].shape
assert output['image'].sum() == sample['image'].sum()
for i in range(1, 3):
assert torch.allclose(output['image'][0, 0], output['image'][0, i])
assert torch.allclose(output['image'][0], output['image'][i])


@pytest.mark.parametrize(
Expand All @@ -50,9 +54,8 @@ def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) ->
],
)
def test_random_grayscale_batch(weights: Tensor, batch: dict[str, Tensor]) -> None:
aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image'])
aug = K.AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=None)
output = aug(batch)
assert output['image'].shape == batch['image'].shape
assert output['image'].sum() == batch['image'].sum()
for i in range(1, 3):
assert torch.allclose(output['image'][0, 0], output['image'][0, i])
20 changes: 9 additions & 11 deletions tests/transforms/test_indices.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import kornia.augmentation as K
import pytest
import torch
from torch import Tensor
Expand All @@ -20,7 +21,6 @@
AppendRBNDVI,
AppendSWI,
AppendTriBandNormalizedDifferenceIndex,
AugmentationSequential,
)


Expand All @@ -42,29 +42,27 @@ def batch() -> dict[str, Tensor]:

def test_append_index_sample(sample: dict[str, Tensor]) -> None:
c, h, w = sample['image'].shape
aug = AugmentationSequential(
AppendNormalizedDifferenceIndex(index_a=0, index_b=1),
data_keys=['image', 'mask'],
aug = K.AugmentationSequential(
AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None
)
output = aug(sample)
assert output['image'].shape == (1, c + 1, h, w)


def test_append_index_batch(batch: dict[str, Tensor]) -> None:
b, c, h, w = batch['image'].shape
aug = AugmentationSequential(
AppendNormalizedDifferenceIndex(index_a=0, index_b=1),
data_keys=['image', 'mask'],
aug = K.AugmentationSequential(
AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None
)
output = aug(batch)
assert output['image'].shape == (b, c + 1, h, w)


def test_append_triband_index_batch(batch: dict[str, Tensor]) -> None:
b, c, h, w = batch['image'].shape
aug = AugmentationSequential(
aug = K.AugmentationSequential(
AppendTriBandNormalizedDifferenceIndex(index_a=0, index_b=1, index_c=2),
data_keys=['image', 'mask'],
data_keys=None,
)
output = aug(batch)
assert output['image'].shape == (b, c + 1, h, w)
Expand All @@ -88,7 +86,7 @@ def test_append_normalized_difference_indices(
sample: dict[str, Tensor], index: AppendNormalizedDifferenceIndex
) -> None:
c, h, w = sample['image'].shape
aug = AugmentationSequential(index(0, 1), data_keys=['image', 'mask'])
aug = K.AugmentationSequential(index(0, 1), data_keys=None)
output = aug(sample)
assert output['image'].shape == (1, c + 1, h, w)

Expand All @@ -98,6 +96,6 @@ def test_append_tri_band_normalized_difference_indices(
sample: dict[str, Tensor], index: AppendTriBandNormalizedDifferenceIndex
) -> None:
c, h, w = sample['image'].shape
aug = AugmentationSequential(index(0, 1, 2), data_keys=['image', 'mask'])
aug = K.AugmentationSequential(index(0, 1, 2), data_keys=None)
output = aug(sample)
assert output['image'].shape == (1, c + 1, h, w)

0 comments on commit 87a5da2

Please sign in to comment.