diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 87ae20cdf40..fea9602adeb 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -86,9 +86,11 @@ def __init__( self.std = torch.tensor([STD[b] for b in self.bands]) self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, batch_size), - data_keys=['image1', 'image2', 'mask'], + K.VideoSequential( + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), + ), + data_keys=['image', 'mask'], ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 28f7714a7c6..8a6616534eb 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -150,7 +150,8 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: image1 = self._load_image(files['images1']) image2 = self._load_image(files['images2']) mask = self._load_target(str(files['mask'])) - sample = {'image1': image1, 'image2': image2, 'mask': mask} + image = torch.stack(tensors=[image1, image2], dim=0) + sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -169,7 +170,8 @@ def _load_files(self) -> list[dict[str, str | Sequence[str]]]: regions = [] labels_root = os.path.join( self.root, - f'Onera Satellite Change Detection dataset - {self.split.capitalize()} ' + f'Onera Satellite Change Detection dataset - { + self.split.capitalize()} ' + 'Labels', ) images_root = os.path.join( diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index d8f80bdcaac..e04389d8b18 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -102,7 +102,13 @@ def forward(self, batch: dict[str, Any]) -> dict[str, Any]: batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy') # Torchmetrics does not support masks with a channel dimension - if 'mask' in batch and batch['mask'].shape[1] == 1: + # Kornia adds a temporal dimension to mask when passed through VideoSequential. + if 'mask' in batch and batch['mask'].ndim == 5: + if batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () c h w -> b c h w') + if batch['mask'].shape[1] == 1: + batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') + elif 'mask' in batch and batch['mask'].shape[1] == 1: batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w') if 'masks' in batch and batch['masks'].ndim == 4: batch['masks'] = rearrange(batch['masks'], '() c h w -> c h w')