diff --git a/README.md b/README.md
index 7e56b60..c61d36d 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,16 @@ to download the MVTec and the DTD datasets to the **datasets** folder in the pro
 ```
 ./scripts/download_dataset.sh
 ```
+If the download_dataset.sh or pretrained.sh scripts throw a "permission denied" error, make them executable with:
+```
+chmod +x ./scripts/download_dataset.sh ./scripts/pretrained.sh
+```
+Now download the dataset again:
+
+```
+./scripts/download_dataset.sh
+```
 
 ## Training
 Pass the folder containing the training dataset to the **train_DRAEM.py** script as the --data_path argument and the
@@ -51,4 +60,18 @@ with pretrained models can be run with:
 ```
 python test_DRAEM.py --gpu_id 0 --base_model_name "DRAEM_seg_large_ae_large_0.0001_800_bs8" --data_path ./datasets/mvtec/ --checkpoint_path ./checkpoints/DRAEM_checkpoints/
 ```
 
+## Inference
+For inference, the dataset loader was modified to load only two images per class; the anomaly heatmap predicted by the model is displayed using Matplotlib.
+
+The inference script requires the --gpu_id argument, the name of the checkpoint files (--base_model_name) for trained models, the
+location of the MVTec anomaly detection dataset (--data_path) and the folder where the checkpoint files are located (--checkpoint_path).
+With pretrained models it can be run with:
+
+```
+python visualize_DRAEM.py --gpu_id 0 --base_model_name "DRAEM_seg_large_ae_large_0.0001_800_bs8" --data_path ./datasets/mvtec/ --checkpoint_path ./checkpoints/DRAEM_checkpoints/
+```
+## Inference Results
+![Screenshot](images/result_1.PNG)
+![Screenshot](images/result_2.PNG)
+
diff --git a/data_loader.py b/data_loader.py
index 5e306a1..09decc0 100644
--- a/data_loader.py
+++ b/data_loader.py
@@ -60,4 +60,63 @@ def __getitem__(self, idx):
 
         return sample
 
+class MVTecDRAEM_Test_Visual_Dataset(Dataset):
+    """MVTec test-set loader for visualization: keeps at most two images
+    per class and returns each image with its ground-truth mask and label."""
+
+    def __init__(self, root_dir, resize_shape=None):
+        self.root_dir = root_dir
+        # Only the first two images per class are kept so visualization stays quick.
+        self.images = sorted(glob.glob(root_dir + "/*/*.png"))[:2]
+        self.resize_shape = resize_shape
+
+    def __len__(self):
+        return len(self.images)
+
+    def transform_image(self, image_path, mask_path):
+        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
+        if mask_path is not None:
+            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
+        else:
+            # Anomaly-free images get an all-zero mask.
+            mask = np.zeros((image.shape[0], image.shape[1]))
+        if self.resize_shape is not None:
+            image = cv2.resize(image, dsize=(self.resize_shape[1], self.resize_shape[0]))
+            mask = cv2.resize(mask, dsize=(self.resize_shape[1], self.resize_shape[0]))
+
+        image = image / 255.0
+        mask = mask / 255.0
+
+        image = np.array(image).reshape((image.shape[0], image.shape[1], 3)).astype(np.float32)
+        mask = np.array(mask).reshape((mask.shape[0], mask.shape[1], 1)).astype(np.float32)
+
+        # HWC -> CHW, the layout PyTorch expects.
+        image = np.transpose(image, (2, 0, 1))
+        mask = np.transpose(mask, (2, 0, 1))
+        return image, mask
+
+    def __getitem__(self, idx):
+        if torch.is_tensor(idx):
+            idx = idx.tolist()
+
+        img_path = self.images[idx]
+        dir_path, file_name = os.path.split(img_path)
+        base_dir = os.path.basename(dir_path)
+        if base_dir == 'good':
+            image, mask = self.transform_image(img_path, None)
+            has_anomaly = np.array([0], dtype=np.float32)
+        else:
+            # Defect masks live in ../../ground_truth/<defect_type>/<name>_mask.png.
+            mask_path = os.path.join(dir_path, '../../ground_truth/')
+            mask_path = os.path.join(mask_path, base_dir)
+            mask_file_name = file_name.split(".")[0] + "_mask.png"
+            mask_path = os.path.join(mask_path, mask_file_name)
+            image, mask = self.transform_image(img_path, mask_path)
+            has_anomaly = np.array([1], dtype=np.float32)
+
+        sample = {'image': image, 'has_anomaly': has_anomaly, 'mask': mask, 'idx': idx}
+
+        return sample
+
+
 class MVTecDRAEMTrainDataset(Dataset):
diff --git a/images/result_1.PNG b/images/result_1.PNG
new file mode 100644
index 0000000..301bc0d
Binary files /dev/null and b/images/result_1.PNG differ
diff --git a/images/result_2.PNG b/images/result_2.PNG
new file mode 100644
index 0000000..24bacbd
Binary files /dev/null and b/images/result_2.PNG differ
diff --git a/visualize_DRAEM.py b/visualize_DRAEM.py
new file mode 100644
index 0000000..7a419bd
--- /dev/null
+++ b/visualize_DRAEM.py
@@ -0,0 +1,72 @@
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from data_loader import MVTecDRAEM_Test_Visual_Dataset
+from model_unet import ReconstructiveSubNetwork, DiscriminativeSubNetwork
+
+
+def test(obj_names, mvtec_path, checkpoint_path, base_model_name):
+    """For every object class, show each test image and the anomaly
+    heatmap predicted by the pretrained DRAEM models with Matplotlib."""
+    for obj_name in obj_names:
+        img_dim = 256
+        run_name = base_model_name + "_" + obj_name + '_'
+
+        # Reconstructive sub-network: restores an anomaly-free version of the input.
+        model = ReconstructiveSubNetwork(in_channels=3, out_channels=3)
+        model.load_state_dict(torch.load(os.path.join(checkpoint_path, run_name + ".pckl"), map_location='cuda:0'))
+        model.cuda()
+        model.eval()
+
+        # Discriminative sub-network: segments anomalies from the (reconstruction, input) pair.
+        model_seg = DiscriminativeSubNetwork(in_channels=6, out_channels=2)
+        model_seg.load_state_dict(torch.load(os.path.join(checkpoint_path, run_name + "_seg.pckl"), map_location='cuda:0'))
+        model_seg.cuda()
+        model_seg.eval()
+
+        dataset = MVTecDRAEM_Test_Visual_Dataset(mvtec_path + obj_name + "/test/", resize_shape=[img_dim, img_dim])
+        dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
+
+        for sample_batched in dataloader:
+            gray_batch = sample_batched["image"].cuda()
+
+            # cv2 loads images as BGR; flip to RGB so Matplotlib shows correct colors.
+            image = gray_batch.permute(0, 2, 3, 1).cpu().numpy()
+            for i in range(image.shape[0]):
+                plt.imshow(image[i][:, :, ::-1])
+                plt.title('Original Image')
+                plt.show()
+
+            gray_rec = model(gray_batch)
+            joined_in = torch.cat((gray_rec.detach(), gray_batch), dim=1)
+            out_mask = model_seg(joined_in)
+            out_mask_sm = torch.softmax(out_mask, dim=1)
+
+            # Channel 1 of the softmax output is the per-pixel anomaly probability.
+            out_mask_cv = out_mask_sm[0, 1, :, :].detach().cpu().numpy()
+            plt.imshow(out_mask_cv)
+            plt.title('Predicted Anomaly Heatmap')
+            plt.show()
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--gpu_id', action='store', type=int, required=True)
+    parser.add_argument('--base_model_name', action='store', type=str, required=True)
+    parser.add_argument('--data_path', action='store', type=str, required=True)
+    parser.add_argument('--checkpoint_path', action='store', type=str, required=True)
+
+    args = parser.parse_args()
+
+    obj_list = ['capsule', 'bottle', 'carpet', 'leather', 'pill',
+                'transistor', 'tile', 'cable', 'zipper', 'toothbrush',
+                'metal_nut', 'hazelnut', 'screw', 'grid', 'wood']
+
+    with torch.cuda.device(args.gpu_id):
+        test(obj_list, args.data_path, args.checkpoint_path, args.base_model_name)