Skip to content

Commit

Permalink
making Kornia transforms work with added temporal dimension.
Browse files Browse the repository at this point in the history
  • Loading branch information
keves1 committed Nov 21, 2024
1 parent 4513e84 commit 7e5ba82
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
8 changes: 5 additions & 3 deletions torchgeo/datamodules/oscd.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,11 @@ def __init__(
self.std = torch.tensor([STD[b] for b in self.bands])

self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
data_keys=['image1', 'image2', 'mask'],
K.VideoSequential(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
),
data_keys=['image', 'mask'],
)

def setup(self, stage: str) -> None:
Expand Down
6 changes: 4 additions & 2 deletions torchgeo/datasets/oscd.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
image1 = self._load_image(files['images1'])
image2 = self._load_image(files['images2'])
mask = self._load_target(str(files['mask']))
sample = {'image1': image1, 'image2': image2, 'mask': mask}
image = torch.stack(tensors=[image1, image2], dim=0)
sample = {'image': image, 'mask': mask}

if self.transforms is not None:
sample = self.transforms(sample)
Expand All @@ -169,7 +170,8 @@ def _load_files(self) -> list[dict[str, str | Sequence[str]]]:
regions = []
labels_root = os.path.join(
self.root,
f'Onera Satellite Change Detection dataset - {self.split.capitalize()} '
f'Onera Satellite Change Detection dataset - {
self.split.capitalize()} '
+ 'Labels',
)
images_root = os.path.join(
Expand Down
8 changes: 7 additions & 1 deletion torchgeo/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,13 @@ def forward(self, batch: dict[str, Any]) -> dict[str, Any]:
batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy')

# Torchmetrics does not support masks with a channel dimension
if 'mask' in batch and batch['mask'].shape[1] == 1:
# Kornia adds a temporal dimension to mask when passed through VideoSequential.
if 'mask' in batch and batch['mask'].ndim == 5:
if batch['mask'].shape[1] == 1:
batch['mask'] = rearrange(batch['mask'], 'b () c h w -> b c h w')
if batch['mask'].shape[1] == 1:
batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w')
elif 'mask' in batch and batch['mask'].shape[1] == 1:
batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w')
if 'masks' in batch and batch['masks'].ndim == 4:
batch['masks'] = rearrange(batch['masks'], '() c h w -> c h w')
Expand Down

0 comments on commit 7e5ba82

Please sign in to comment.