Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 120 additions & 81 deletions data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,54 @@
import glob
import imgaug.augmenters as iaa
from perlin import rand_perlin_2d_np
##
# Added path error handling to the test dataset.
# Added RAM preloading of the normal and anomaly-source images in the training
# dataset __init__ to speed up training.
##

class MVTecDRAEMTestDataset(Dataset):

def __init__(self, root_dir, resize_shape=None):
self.root_dir = root_dir
self.images = sorted(glob.glob(root_dir+"/*/*.png"))
self.resize_shape=resize_shape
print(f"Initializing dataset with root_dir: {root_dir}")
self.root_dir = os.path.abspath(root_dir)

if not os.path.exists(self.root_dir):
print(f"ERROR: Dataset folder does NOT exist: {self.root_dir}")
else:
print(f" Dataset folder exists: {self.root_dir}")

self.image_paths = sorted(
glob.glob(os.path.join(self.root_dir, "**", "*.jpg"), recursive=True) +
glob.glob(os.path.join(self.root_dir, "**", "*.png"), recursive=True)
)

if len(self.image_paths) == 0:
print(f" No images found in {self.root_dir}!")
self.image_paths = [os.path.abspath(path) for path in self.image_paths]

print(f"Found {len(self.image_paths)} images in {self.root_dir}.")

self.resize_shape = resize_shape


def __len__(self):
return len(self.images)
return len(self.image_paths)

def transform_image(self, image_path, mask_path):
image = cv2.imread(image_path, cv2.IMREAD_COLOR)
if mask_path is not None:

if image is None:
raise FileNotFoundError(f"Failed to load image from {image_path}")

if mask_path is not None and os.path.exists(mask_path):
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
print(f"Failed to load mask from {mask_path}, using a blank mask instead.")
mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
else:
mask = np.zeros((image.shape[0],image.shape[1]))
if self.resize_shape != None:
print(f"Mask file not found for {image_path}, using a blank mask.")
mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

if self.resize_shape is not None:
image = cv2.resize(image, dsize=(self.resize_shape[1], self.resize_shape[0]))
mask = cv2.resize(mask, dsize=(self.resize_shape[1], self.resize_shape[0]))

def __getitem__(self, idx):
    """Return the test sample at position ``idx``.

    Args:
        idx (int or torch.Tensor): Index into ``self.image_paths``.

    Returns:
        dict: ``'image'`` and ``'mask'`` as produced below, ``'has_anomaly'``
        (float32 array holding 0 for images inside a ``good`` folder and 1
        otherwise) and ``'idx'``.
    """
    if torch.is_tensor(idx):
        idx = idx.tolist()

    img_path = self.image_paths[idx]
    dir_path, file_name = os.path.split(img_path)
    base_dir = os.path.basename(dir_path)
    is_good = base_dir == 'good'

    # Only defect images have a ground-truth mask; it is looked up in the
    # "ground_truth" folder that sits next to the dataset root.
    mask_path = None
    if not is_good:
        file_base = os.path.splitext(file_name)[0]
        mask_path = os.path.abspath(os.path.join(
            os.path.dirname(self.root_dir),
            "ground_truth",
            base_dir,
            f"{file_base}_mask.png"
        ))
        if not os.path.exists(mask_path):
            print(f" WARNING: Mask not found at {mask_path}. Using empty mask.")
            mask_path = None

    image, mask = self.transform_image(img_path, mask_path)

    if mask_path is None and self.resize_shape is not None:
        # No ground truth: use an explicit all-zero mask of the target size.
        # When resize_shape is None the blank mask that transform_image
        # returned is kept, because indexing None would raise TypeError.
        mask = np.zeros((1, self.resize_shape[0], self.resize_shape[1]), dtype=np.float32)

    has_anomaly = np.array([0 if is_good else 1], dtype=np.float32)

    sample = {'image': image, 'has_anomaly': has_anomaly, 'mask': mask, 'idx': idx}
    return sample



class MVTecDRAEMTrainDataset(Dataset):

def __init__(self, root_dir, anomaly_source_path, resize_shape=None):
    """Index the training images and preload them, resized, into RAM.

    Args:
        root_dir (string): Directory with all the normal images
            (``*.jpg`` / ``*.png`` directly inside it).
        anomaly_source_path (string): Directory whose sub-folders hold the
            anomaly source images (``*/*.jpg`` / ``*/*.png``).
        resize_shape (tuple): ``(height, width)`` every image is resized
            to while it is cached. Required despite the ``None`` default.

    Raises:
        TypeError: If ``resize_shape`` is ``None``.
        FileNotFoundError: If an indexed image cannot be decoded.
    """
    if resize_shape is None:
        # Preloading resizes every image, so a shape is mandatory. TypeError
        # is the exception type the bare subscript on None used to raise.
        raise TypeError("resize_shape must be a (height, width) pair, got None")

    self.root_dir = root_dir
    self.resize_shape = resize_shape

    # No "**" in these patterns, so recursive=True would have no effect.
    self.image_paths = sorted(
        glob.glob(os.path.join(root_dir, "*.jpg")) +
        glob.glob(os.path.join(root_dir, "*.png"))
    )

    self.anomaly_source_paths = sorted(
        glob.glob(os.path.join(anomaly_source_path, "*/*.jpg")) +
        glob.glob(os.path.join(anomaly_source_path, "*/*.png"))
    )

    print(f"Preloading {len(self.image_paths)} normal images into RAM")
    self.image_cache = self._preload_images(self.image_paths)

    print(f"Preloading {len(self.anomaly_source_paths)} anomaly images into RAM")
    self.anomaly_cache = self._preload_images(self.anomaly_source_paths)

    print(f" Done, loaded {len(self.image_cache)} normal images and {len(self.anomaly_cache)} anomaly images.")

    # Pool from which randAugmenter picks its augmenters.
    self.augmenters = [
        iaa.GammaContrast((0.5, 2.0), per_channel=True),
        iaa.MultiplyAndAddToBrightness(mul=(0.8, 1.2), add=(-30, 30)),
        iaa.pillike.EnhanceSharpness(),
        iaa.AddToHueAndSaturation((-50, 50), per_channel=True),
        iaa.Solarize(0.5, threshold=(32, 128)),
        iaa.Posterize(),
        iaa.Invert(),
        iaa.pillike.Autocontrast(),
        iaa.pillike.Equalize(),
        iaa.Affine(rotate=(-45, 45))
    ]
    self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])

def _preload_images(self, paths):
    """Read and resize every file in ``paths``; return a path -> array dict."""
    # cv2.resize expects (width, height); resize_shape is (height, width).
    target_size = (self.resize_shape[1], self.resize_shape[0])
    cache = {}
    for path in paths:
        loaded = cv2.imread(path)
        if loaded is None:
            # Fail with the offending path instead of an opaque cv2.resize error.
            raise FileNotFoundError(f"Failed to load image from {path}")
        cache[path] = cv2.resize(loaded, target_size)
    return cache


def __len__(self):
return len(self.image_paths)


def randAugmenter(self):
    """Build a sequential pipeline of three distinct, randomly picked augmenters."""
    # Sampling without replacement guarantees three different list positions.
    picked = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False)
    first, second, third = (self.augmenters[position] for position in picked)
    return iaa.Sequential([first, second, third])

def augment_image(self, image, anomaly_source_path):
def augment_image(self, image, anomaly_image):
aug = self.randAugmenter()
perlin_scale = 6
min_perlin_scale = 0
anomaly_source_img = cv2.imread(anomaly_source_path)
anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(self.resize_shape[1], self.resize_shape[0]))

anomaly_img_augmented = aug(image=anomaly_source_img)
anomaly_img_augmented = aug(image=anomaly_image)
perlin_scalex = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0])
perlin_scaley = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0])

Expand All @@ -125,43 +168,39 @@ def augment_image(self, image, anomaly_source_path):

beta = torch.rand(1).numpy()[0] * 0.8

augmented_image = image * (1 - perlin_thr) + (1 - beta) * img_thr + beta * image * (
perlin_thr)
augmented_image = image * (1 - perlin_thr) + (1 - beta) * img_thr + beta * image * perlin_thr

no_anomaly = torch.rand(1).numpy()[0]
if no_anomaly > 0.5:
image = image.astype(np.float32)
return image, np.zeros_like(perlin_thr, dtype=np.float32), np.array([0.0],dtype=np.float32)
return image.astype(np.float32), np.zeros_like(perlin_thr, dtype=np.float32), np.array([0.0], dtype=np.float32)
else:
augmented_image = augmented_image.astype(np.float32)
msk = (perlin_thr).astype(np.float32)
augmented_image = msk * augmented_image + (1-msk)*image
has_anomaly = 1.0
if np.sum(msk) == 0:
has_anomaly=0.0
return augmented_image, msk, np.array([has_anomaly],dtype=np.float32)

def transform_image(self, image_path, anomaly_source_path):
image = cv2.imread(image_path)
image = cv2.resize(image, dsize=(self.resize_shape[1], self.resize_shape[0]))

do_aug_orig = torch.rand(1).numpy()[0] > 0.7
if do_aug_orig:
image = self.rot(image=image)
msk = perlin_thr.astype(np.float32)
augmented_image = msk * augmented_image + (1-msk) * image
has_anomaly = np.array([1.0], dtype=np.float32) if np.sum(msk) > 0 else np.array([0.0], dtype=np.float32)
return augmented_image, msk, has_anomaly

image = np.array(image).reshape((image.shape[0], image.shape[1], image.shape[2])).astype(np.float32) / 255.0
augmented_image, anomaly_mask, has_anomaly = self.augment_image(image, anomaly_source_path)
def transform_image(self, image, anomaly_image):
    """Produce one training tuple from an image and an anomaly source image.

    Returns:
        tuple: ``(image, augmented_image, anomaly_mask, has_anomaly)``; the
        three arrays have their last axis moved to the front.
    """
    # Apply the extra rotation only when the random draw exceeds 0.7.
    rotate_first = torch.rand(1).item() > 0.7
    if rotate_first:
        image = self.rot(image=image)

    # Scale pixel values by 1/255 before synthesising the anomaly.
    image = image.astype(np.float32) / 255.0
    augmented_image, anomaly_mask, has_anomaly = self.augment_image(image, anomaly_image)

    channels_first = (2, 0, 1)
    augmented_image = np.transpose(augmented_image, channels_first)
    image = np.transpose(image, channels_first)
    anomaly_mask = np.transpose(anomaly_mask, channels_first)
    return image, augmented_image, anomaly_mask, has_anomaly

def __getitem__(self, idx):
    """Return the training sample for ``idx`` with a random anomaly source.

    Args:
        idx (int or torch.Tensor): Index into ``self.image_paths``.

    Returns:
        dict: ``'image'``, ``'augmented_image'``, ``'anomaly_mask'`` and
        ``'has_anomaly'`` as returned by ``transform_image``, plus ``'idx'``.
    """
    # Accept tensor indices the same way the test dataset does.
    if torch.is_tensor(idx):
        idx = idx.tolist()

    # Draw an index with torch instead of np.random.choice on the path list:
    # that call rebuilds an array from the whole list on every sample and uses
    # NumPy's global RNG, which DataLoader workers may share a seed for
    # (see the PyTorch data-loading randomness notes).
    anomaly_source_idx = torch.randint(0, len(self.anomaly_source_paths), (1,)).item()

    image = self.image_cache[self.image_paths[idx]]
    anomaly_image = self.anomaly_cache[self.anomaly_source_paths[anomaly_source_idx]]
    image, augmented_image, anomaly_mask, has_anomaly = self.transform_image(image, anomaly_image)

    return {
        'image': image,
        'augmented_image': augmented_image,
        'anomaly_mask': anomaly_mask,
        'has_anomaly': has_anomaly,
        'idx': idx
    }