From d7490146c0605b59e9d177b26fefbaa5f9142916 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Wed, 6 Nov 2024 21:59:52 +0400 Subject: [PATCH] Remove explicit keepdim --- torchgeo/datamodules/nasa_marine_debris.py | 2 -- torchgeo/datamodules/vhr10.py | 4 ---- 2 files changed, 6 deletions(-) diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index add6b32b063..43511e9d5f8 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -48,8 +48,6 @@ def __init__( self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] self.collate_fn = collate_fn_detection diff --git a/torchgeo/datamodules/vhr10.py b/torchgeo/datamodules/vhr10.py index 8e9d56c7f7a..db4f3437667 100644 --- a/torchgeo/datamodules/vhr10.py +++ b/torchgeo/datamodules/vhr10.py @@ -59,13 +59,10 @@ def __init__( data_keys=None, keepdim=True, ) - self.train_aug.keepdim = True # type: ignore[attr-defined] self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - # https://github.com/kornia/kornia/issues/2848 - self.aug.keepdim = True # type: ignore[attr-defined] def setup(self, stage: str) -> None: """Set up datasets. @@ -76,7 +73,6 @@ def setup(self, stage: str) -> None: self.kwargs['transforms'] = K.AugmentationSequential( K.Resize(self.patch_size), data_keys=None, keepdim=True ) - self.kwargs['transforms'].keepdim = True self.dataset = VHR10(**self.kwargs) generator = torch.Generator().manual_seed(0) self.train_dataset, self.val_dataset, self.test_dataset = random_split(