diff --git a/src/anomalib/data/base/datamodule.py b/src/anomalib/data/base/datamodule.py
index a7fd7bb23e..9471c16bb7 100644
--- a/src/anomalib/data/base/datamodule.py
+++ b/src/anomalib/data/base/datamodule.py
@@ -157,6 +157,11 @@ def _create_val_split(self) -> None:
             # converted from random training sample
             self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed)
             self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data)
+        elif self.val_split_mode == ValSplitMode.FROM_TRAIN:
+            # randomly sampled from the training set
+            self.train_data, self.val_data = random_split(
+                self.train_data, self.val_split_ratio, label_aware=True, seed=self.seed
+            )
         elif self.val_split_mode != ValSplitMode.NONE:
             raise ValueError(f"Unknown validation split mode: {self.val_split_mode}")

diff --git a/src/anomalib/data/utils/split.py b/src/anomalib/data/utils/split.py
index e78d49ed5f..30659b2ea0 100644
--- a/src/anomalib/data/utils/split.py
+++ b/src/anomalib/data/utils/split.py
@@ -46,6 +46,7 @@ class ValSplitMode(str, Enum):

     NONE = "none"
     SAME_AS_TEST = "same_as_test"
     FROM_TEST = "from_test"
+    FROM_TRAIN = "from_train"
     SYNTHETIC = "synthetic"

diff --git a/src/anomalib/models/efficient_ad/config.yaml b/src/anomalib/models/efficient_ad/config.yaml
index ebe6c2be58..743cdb8618 100644
--- a/src/anomalib/models/efficient_ad/config.yaml
+++ b/src/anomalib/models/efficient_ad/config.yaml
@@ -15,7 +15,7 @@ dataset:
     eval: null
   test_split_mode: from_dir # options: [from_dir, synthetic]
   test_split_ratio: 0.2 # fraction of train images held out testing (usage depends on test_split_mode)
-  val_split_mode: same_as_test # options: [same_as_test, from_test, synthetic]
+  val_split_mode: from_train # options: [same_as_test, from_test, from_train, synthetic]
   val_split_ratio: 0.5 # fraction of train/test images held out for validation (usage depends on val_split_mode)

 model:
diff --git a/src/anomalib/models/efficient_ad/lightning_model.py b/src/anomalib/models/efficient_ad/lightning_model.py
index 09922ba7fc..fc3cbea151 100644
--- a/src/anomalib/models/efficient_ad/lightning_model.py
+++ b/src/anomalib/models/efficient_ad/lightning_model.py
@@ -59,10 +59,10 @@ class EfficientAd(AnomalyModule):
         model_size (str): size of student and teacher model
         lr (float): learning rate
         weight_decay (float): optimizer weight decay
-        padding (bool): use padding in convoluional layers
-        pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the
-            output anomaly maps so that their size matches the size in the padding = True case.
+        padding (bool): use padding in the convolutional layers of the student/teacher architecture
         batch_size (int): batch size for imagenet dataloader
+        pretraining_images_dir (str): path to the folder of images used to pretrain the teacher model.
+            The code refers to this dataset as "imagenette", but any image folder can be used.
     """

     def __init__(
@@ -73,8 +73,8 @@ def __init__(
         lr: float = 0.0001,
         weight_decay: float = 0.00001,
         padding: bool = False,
-        pad_maps: bool = True,
         batch_size: int = 1,
+        pretraining_images_dir: str = "./datasets/imagenette",
     ) -> None:
         super().__init__()

@@ -84,12 +84,12 @@ def __init__(
             input_size=image_size,
             model_size=model_size,
             padding=padding,
-            pad_maps=pad_maps,
         )
         self.batch_size = batch_size
         self.image_size = image_size
         self.lr = lr
         self.weight_decay = weight_decay
+        self.pretraining_images_dir = pretraining_images_dir

         self.prepare_pretrained_model()
         self.prepare_imagenette_data()
@@ -115,8 +115,8 @@ def prepare_imagenette_data(self) -> None:
             ]
         )

-        imagenet_dir = Path("./datasets/imagenette")
+        imagenet_dir = Path(self.pretraining_images_dir)
         if not imagenet_dir.is_dir():
-            download_and_extract(imagenet_dir, IMAGENETTE_DOWNLOAD_INFO)
+            raise FileNotFoundError(f"Pretraining image directory not found: {imagenet_dir}")
         imagenet_dataset = ImageFolder(imagenet_dir, transform=TransformsWrapper(t=self.data_transforms_imagenet))
         self.imagenet_loader = DataLoader(imagenet_dataset, batch_size=self.batch_size, shuffle=True, pin_memory=True)
@@ -290,7 +291,6 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
             lr=hparams.model.lr,
             weight_decay=hparams.model.weight_decay,
             padding=hparams.model.padding,
-            pad_maps=hparams.model.pad_maps,
             image_size=hparams.dataset.image_size,
             batch_size=hparams.dataset.train_batch_size,
         )
diff --git a/src/anomalib/models/efficient_ad/torch_model.py b/src/anomalib/models/efficient_ad/torch_model.py
index 950087785b..7fb2d6bd95 100644
--- a/src/anomalib/models/efficient_ad/torch_model.py
+++ b/src/anomalib/models/efficient_ad/torch_model.py
@@ -219,9 +219,7 @@ class EfficientAdModel(nn.Module):
         pretrained_models_dir (str): path to the pretrained model weights
         input_size (tuple): size of input images
         model_size (str): size of student and teacher model
-        padding (bool): use padding in convoluional layers
-        pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the
-            output anomaly maps so that their size matches the size in the padding = True case.
+        padding (bool): use padding in the convolutional layers of the student/teacher architecture
         device (str): which device the model should be loaded on
     """

@@ -231,11 +229,10 @@ def __init__(
         input_size: tuple[int, int],
         model_size: EfficientAdModelSize = EfficientAdModelSize.S,
         padding=False,
-        pad_maps=True,
     ) -> None:
         super().__init__()

-        self.pad_maps = pad_maps
+        self.padding = padding
         self.teacher: PDN_M | PDN_S
         self.student: PDN_M | PDN_S

@@ -340,9 +337,13 @@ def forward(self, batch: Tensor, batch_imagenet: Tensor = None) -> Tensor | dict
             (ae_output - student_output[:, self.teacher_out_channels :]) ** 2, dim=1, keepdim=True
         )

-        if self.pad_maps:
+        if not self.padding:
+            # Without padding in the teacher/student convolutions, the output anomaly maps are
+            # smaller than in the padded case and misaligned with the input image, so pad the
+            # score maps by 4 pixels on each side. See github.com/openvinotoolkit/anomalib/discussions/1368
             map_st = F.pad(map_st, (4, 4, 4, 4))
             map_stae = F.pad(map_stae, (4, 4, 4, 4))
+
         map_st = F.interpolate(map_st, size=(self.input_size[0], self.input_size[1]), mode="bilinear")
         map_stae = F.interpolate(map_stae, size=(self.input_size[0], self.input_size[1]), mode="bilinear")
