Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/1370 efficientad validation and pretraining images #1376

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/anomalib/data/base/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ def _create_val_split(self) -> None:
# converted from random training sample
self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed)
self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data)
elif self.val_split_mode == ValSplitMode.FROM_TRAIN:
# randomly sampled from training set
self.train_data, self.val_data = random_split(
self.train_data, self.val_split_ratio, label_aware=True, seed=self.seed
)
elif self.val_split_mode != ValSplitMode.NONE:
raise ValueError(f"Unknown validation split mode: {self.val_split_mode}")

Expand Down
1 change: 1 addition & 0 deletions src/anomalib/data/utils/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class ValSplitMode(str, Enum):
NONE = "none"
SAME_AS_TEST = "same_as_test"
FROM_TEST = "from_test"
FROM_TRAIN = "from_train"
SYNTHETIC = "synthetic"


Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/efficient_ad/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dataset:
eval: null
test_split_mode: from_dir # options: [from_dir, synthetic]
test_split_ratio: 0.2 # fraction of train images held out testing (usage depends on test_split_mode)
val_split_mode: same_as_test # options: [same_as_test, from_test, synthetic]
val_split_mode: from_train # options: [same_as_test, from_test, synthetic]
val_split_ratio: 0.5 # fraction of train/test images held out for validation (usage depends on val_split_mode)

model:
Expand Down
14 changes: 7 additions & 7 deletions src/anomalib/models/efficient_ad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ class EfficientAd(AnomalyModule):
model_size (str): size of student and teacher model
lr (float): learning rate
weight_decay (float): optimizer weight decay
padding (bool): use padding in convoluional layers
pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the
output anomaly maps so that their size matches the size in the padding = True case.
padding (bool): use padding in convoluional layers of the student/teacher architecture
jpcbertoldo marked this conversation as resolved.
Show resolved Hide resolved
batch_size (int): batch size for imagenet dataloader
pretraining_images_dir (str): path to folder with images used to pretrain the teacher model
and the code is calling it "imagenette", but it could be any dataset.
"""

def __init__(
Expand All @@ -73,8 +73,8 @@ def __init__(
lr: float = 0.0001,
weight_decay: float = 0.00001,
padding: bool = False,
pad_maps: bool = True,
batch_size: int = 1,
pretraining_images_dir: str = "./datasets/imagenette",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to make this parameter optional? That way, if it is None you can still do the automatic download of the Imagenet dataset.

) -> None:
super().__init__()

Expand All @@ -84,12 +84,12 @@ def __init__(
input_size=image_size,
model_size=model_size,
padding=padding,
pad_maps=pad_maps,
)
self.batch_size = batch_size
self.image_size = image_size
self.lr = lr
self.weight_decay = weight_decay
self.pretraining_images_dir = pretraining_images_dir

self.prepare_pretrained_model()
self.prepare_imagenette_data()
Expand All @@ -115,8 +115,9 @@ def prepare_imagenette_data(self) -> None:
]
)

imagenet_dir = Path("./datasets/imagenette")
imagenet_dir = Path(self.pretraining_images_dir)
if not imagenet_dir.is_dir():
raise FileNotFoundError(f"Imagenette dataset not found at {imagenet_dir}")
download_and_extract(imagenet_dir, IMAGENETTE_DOWNLOAD_INFO)
imagenet_dataset = ImageFolder(imagenet_dir, transform=TransformsWrapper(t=self.data_transforms_imagenet))
self.imagenet_loader = DataLoader(imagenet_dataset, batch_size=self.batch_size, shuffle=True, pin_memory=True)
Expand Down Expand Up @@ -280,7 +281,6 @@ def __init__(self, hparams: DictConfig | ListConfig) -> None:
lr=hparams.model.lr,
weight_decay=hparams.model.weight_decay,
padding=hparams.model.padding,
pad_maps=hparams.model.pad_maps,
image_size=hparams.dataset.image_size,
batch_size=hparams.dataset.train_batch_size,
)
Expand Down
13 changes: 7 additions & 6 deletions src/anomalib/models/efficient_ad/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,7 @@ class EfficientAdModel(nn.Module):
pretrained_models_dir (str): path to the pretrained model weights
input_size (tuple): size of input images
model_size (str): size of student and teacher model
padding (bool): use padding in convoluional layers
pad_maps (bool): relevant if padding is set to False. In this case, pad_maps = True pads the
output anomaly maps so that their size matches the size in the padding = True case.
padding (bool): use padding in convoluional layers of the student/teacher architecture
device (str): which device the model should be loaded on
"""

Expand All @@ -231,11 +229,10 @@ def __init__(
input_size: tuple[int, int],
model_size: EfficientAdModelSize = EfficientAdModelSize.S,
padding=False,
pad_maps=True,
) -> None:
super().__init__()

self.pad_maps = pad_maps
self.padding = padding
self.teacher: PDN_M | PDN_S
self.student: PDN_M | PDN_S

Expand Down Expand Up @@ -340,9 +337,13 @@ def forward(self, batch: Tensor, batch_imagenet: Tensor = None) -> Tensor | dict
(ae_output - student_output[:, self.teacher_out_channels :]) ** 2, dim=1, keepdim=True
)

if self.pad_maps:
if not self.padding:
# when the teacher/student architecture does not use padding, the output anomaly maps
# are smaller than if they had it, so the maps are misaligned. To fix this, we pad
# score maps by 4 pixels on each side. See github.com/openvinotoolkit/anomalib/discussions/1368
map_st = F.pad(map_st, (4, 4, 4, 4))
map_stae = F.pad(map_stae, (4, 4, 4, 4))

map_st = F.interpolate(map_st, size=(self.input_size[0], self.input_size[1]), mode="bilinear")
map_stae = F.interpolate(map_stae, size=(self.input_size[0], self.input_size[1]), mode="bilinear")

Expand Down