From 7a206dfaaf62641b1d47f03cc4ec48f540ab8608 Mon Sep 17 00:00:00 2001 From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Tue, 26 Sep 2023 20:20:51 +0200 Subject: [PATCH 1/5] enable EfficientAd option for pretraining image dir --- src/anomalib/models/efficient_ad/lightning_model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/anomalib/models/efficient_ad/lightning_model.py b/src/anomalib/models/efficient_ad/lightning_model.py index a35a65f963..1870b70aaf 100644 --- a/src/anomalib/models/efficient_ad/lightning_model.py +++ b/src/anomalib/models/efficient_ad/lightning_model.py @@ -63,6 +63,9 @@ class EfficientAd(AnomalyModule): pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the output anomaly maps so that their size matches the size in the padding = True case. batch_size (int): batch size for imagenet dataloader + pretraining_images_dir (str): path to folder with images used to pretrain the teacher model + TODO note in PR: the vocabulary is not consistent with the paper, where they call it "pretraining dataset" + and the code is calling it "imagenette", but it could be any dataset. 
""" def __init__( @@ -75,6 +78,7 @@ def __init__( padding: bool = False, pad_maps: bool = True, batch_size: int = 1, + pretraining_images_dir: str = "./datasets/imagenette", ) -> None: super().__init__() @@ -90,6 +94,7 @@ def __init__( self.image_size = image_size self.lr = lr self.weight_decay = weight_decay + self.pretraining_images_dir = pretraining_images_dir self.prepare_pretrained_model() self.prepare_imagenette_data() @@ -115,8 +120,9 @@ def prepare_imagenette_data(self) -> None: ] ) - imagenet_dir = Path("./datasets/imagenette") + imagenet_dir = Path(self.pretraining_images_dir) if not imagenet_dir.is_dir(): + raise FileNotFoundError(f"Imagenette dataset not found at {imagenet_dir}") download_and_extract(imagenet_dir, IMAGENETTE_DOWNLOAD_INFO) imagenet_dataset = ImageFolder(imagenet_dir, transform=TransformsWrapper(t=self.data_transforms_imagenet)) self.imagenet_loader = DataLoader(imagenet_dataset, batch_size=self.batch_size, shuffle=True, pin_memory=True) From d28b7832938a90ce57727c2da3f990457ef0f1eb Mon Sep 17 00:00:00 2001 From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Tue, 26 Sep 2023 17:17:09 +0200 Subject: [PATCH 2/5] add val split mode from_train --- src/anomalib/data/base/datamodule.py | 5 +++++ src/anomalib/data/utils/split.py | 1 + src/anomalib/models/efficient_ad/config.yaml | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/anomalib/data/base/datamodule.py b/src/anomalib/data/base/datamodule.py index a7fd7bb23e..9471c16bb7 100644 --- a/src/anomalib/data/base/datamodule.py +++ b/src/anomalib/data/base/datamodule.py @@ -157,6 +157,11 @@ def _create_val_split(self) -> None: # converted from random training sample self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed) self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data) + elif self.val_split_mode == ValSplitMode.FROM_TRAIN: + # randomly sampled from training set + self.train_data, 
self.val_data = random_split( + self.train_data, self.val_split_ratio, label_aware=True, seed=self.seed + ) elif self.val_split_mode != ValSplitMode.NONE: raise ValueError(f"Unknown validation split mode: {self.val_split_mode}") diff --git a/src/anomalib/data/utils/split.py b/src/anomalib/data/utils/split.py index e78d49ed5f..30659b2ea0 100644 --- a/src/anomalib/data/utils/split.py +++ b/src/anomalib/data/utils/split.py @@ -46,6 +46,7 @@ class ValSplitMode(str, Enum): NONE = "none" SAME_AS_TEST = "same_as_test" FROM_TEST = "from_test" + FROM_TRAIN = "from_train" SYNTHETIC = "synthetic" diff --git a/src/anomalib/models/efficient_ad/config.yaml b/src/anomalib/models/efficient_ad/config.yaml index ebe6c2be58..743cdb8618 100644 --- a/src/anomalib/models/efficient_ad/config.yaml +++ b/src/anomalib/models/efficient_ad/config.yaml @@ -15,7 +15,7 @@ dataset: eval: null test_split_mode: from_dir # options: [from_dir, synthetic] test_split_ratio: 0.2 # fraction of train images held out testing (usage depends on test_split_mode) - val_split_mode: same_as_test # options: [same_as_test, from_test, synthetic] + val_split_mode: from_train # options: [same_as_test, from_test, from_train, synthetic] val_split_ratio: 0.5 # fraction of train/test images held out for validation (usage depends on val_split_mode) model: From aa1ab9986cf85063885b5754d2f9192a0c8fd6c2 Mon Sep 17 00:00:00 2001 From: Joao P C Bertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Wed, 27 Sep 2023 22:07:58 +0200 Subject: [PATCH 3/5] Update src/anomalib/models/efficient_ad/lightning_model.py --- src/anomalib/models/efficient_ad/lightning_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anomalib/models/efficient_ad/lightning_model.py b/src/anomalib/models/efficient_ad/lightning_model.py index 1870b70aaf..6bbff73ef0 100644 --- a/src/anomalib/models/efficient_ad/lightning_model.py +++ b/src/anomalib/models/efficient_ad/lightning_model.py @@ -64,7 +64,6 @@ class EfficientAd(AnomalyModule): output anomaly 
maps so that their size matches the size in the padding = True case. batch_size (int): batch size for imagenet dataloader pretraining_images_dir (str): path to folder with images used to pretrain the teacher model - TODO note in PR: the vocabulary is not consistent with the paper, where they call it "pretraining dataset" and the code is calling it "imagenette", but it could be any dataset. """ From 13d959a0461662d11a72d9cf4b81ba43f6faba73 Mon Sep 17 00:00:00 2001 From: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Fri, 29 Sep 2023 12:40:56 +0200 Subject: [PATCH 4/5] make padding on inference maps automatic when padding=False --- src/anomalib/models/efficient_ad/lightning_model.py | 7 +------ src/anomalib/models/efficient_ad/torch_model.py | 13 +++++++------ 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/anomalib/models/efficient_ad/lightning_model.py b/src/anomalib/models/efficient_ad/lightning_model.py index 6bbff73ef0..da7fa16bb9 100644 --- a/src/anomalib/models/efficient_ad/lightning_model.py +++ b/src/anomalib/models/efficient_ad/lightning_model.py @@ -59,9 +59,7 @@ class EfficientAd(AnomalyModule): model_size (str): size of student and teacher model lr (float): learning rate weight_decay (float): optimizer weight decay - padding (bool): use padding in convoluional layers - pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the - output anomaly maps so that their size matches the size in the padding = True case. + padding (bool): use padding in convoluional layers of the student/teacher architecture batch_size (int): batch size for imagenet dataloader pretraining_images_dir (str): path to folder with images used to pretrain the teacher model and the code is calling it "imagenette", but it could be any dataset. 
@@ -75,7 +73,6 @@ def __init__( lr: float = 0.0001, weight_decay: float = 0.00001, padding: bool = False, - pad_maps: bool = True, batch_size: int = 1, pretraining_images_dir: str = "./datasets/imagenette", ) -> None: @@ -87,7 +84,6 @@ def __init__( input_size=image_size, model_size=model_size, padding=padding, - pad_maps=pad_maps, ) self.batch_size = batch_size self.image_size = image_size @@ -285,7 +281,6 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None: lr=hparams.model.lr, weight_decay=hparams.model.weight_decay, padding=hparams.model.padding, - pad_maps=hparams.model.pad_maps, image_size=hparams.dataset.image_size, batch_size=hparams.dataset.train_batch_size, ) diff --git a/src/anomalib/models/efficient_ad/torch_model.py b/src/anomalib/models/efficient_ad/torch_model.py index 950087785b..7fb2d6bd95 100644 --- a/src/anomalib/models/efficient_ad/torch_model.py +++ b/src/anomalib/models/efficient_ad/torch_model.py @@ -219,9 +219,7 @@ class EfficientAdModel(nn.Module): pretrained_models_dir (str): path to the pretrained model weights input_size (tuple): size of input images model_size (str): size of student and teacher model - padding (bool): use padding in convoluional layers - pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the - output anomaly maps so that their size matches the size in the padding = True case. 
+ padding (bool): use padding in convolutional layers of the student/teacher architecture device (str): which device the model should be loaded on """ @@ -231,11 +229,10 @@ def __init__( input_size: tuple[int, int], model_size: EfficientAdModelSize = EfficientAdModelSize.S, padding=False, - pad_maps=True, ) -> None: super().__init__() - self.pad_maps = pad_maps + self.padding = padding self.teacher: PDN_M | PDN_S self.student: PDN_M | PDN_S @@ -340,9 +337,13 @@ def forward(self, batch: Tensor, batch_imagenet: Tensor = None) -> Tensor | dict (ae_output - student_output[:, self.teacher_out_channels :]) ** 2, dim=1, keepdim=True ) - if self.pad_maps: + if not self.padding: + # when the teacher/student architecture does not use padding, the output anomaly maps + # are smaller than if they had it, so the maps are misaligned. To fix this, we pad + # score maps by 4 pixels on each side. See github.com/openvinotoolkit/anomalib/discussions/1368 map_st = F.pad(map_st, (4, 4, 4, 4)) map_stae = F.pad(map_stae, (4, 4, 4, 4)) + map_st = F.interpolate(map_st, size=(self.input_size[0], self.input_size[1]), mode="bilinear") map_stae = F.interpolate(map_stae, size=(self.input_size[0], self.input_size[1]), mode="bilinear") From 312fe6ffac773af31d0125fe4da8d07de505b8d9 Mon Sep 17 00:00:00 2001 From: Joao P C Bertoldo <24547377+jpcbertoldo@users.noreply.github.com> Date: Tue, 19 Dec 2023 18:16:45 +0100 Subject: [PATCH 5/5] Update src/anomalib/models/efficient_ad/lightning_model.py Co-authored-by: Sean Aubin --- src/anomalib/models/efficient_ad/lightning_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/models/efficient_ad/lightning_model.py b/src/anomalib/models/efficient_ad/lightning_model.py index 7887362cec..fc3cbea151 100644 --- a/src/anomalib/models/efficient_ad/lightning_model.py +++ b/src/anomalib/models/efficient_ad/lightning_model.py @@ -59,7 +59,7 @@ class EfficientAd(AnomalyModule): model_size (str): size of student and teacher model lr 
(float): learning rate weight_decay (float): optimizer weight decay - padding (bool): use padding in convoluional layers of the student/teacher architecture + padding (bool): use padding in the convolutional layers of the student/teacher architecture batch_size (int): batch size for imagenet dataloader pretraining_images_dir (str): path to folder with images used to pretrain the teacher model and the code is calling it "imagenette", but it could be any dataset.