Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,16 @@ to download the MVTec and the DTD datasets to the **datasets** folder in the pro
```
./scripts/download_dataset.sh
```
If the download_dataset.sh or pretrained.sh script fails with a "Permission denied" error, make the script executable first, for example:

```
chmod +x ./scripts/download_dataset.sh
```
Then run the download script again:

```
./scripts/download_dataset.sh
```

## Training
Pass the folder containing the training dataset to the **train_DRAEM.py** script as the --data_path argument and the
Expand Down Expand Up @@ -51,4 +60,18 @@ with pretrained models can be run with:
python test_DRAEM.py --gpu_id 0 --base_model_name "DRAEM_seg_large_ae_large_0.0001_800_bs8" --data_path ./datasets/mvtec/ --checkpoint_path ./checkpoints/DRAEM_checkpoints/
```

## Inference
For inference, the dataset loader was modified to load only two images per class. The model predicts an anomaly heatmap for each image, which is then displayed using Matplotlib.

The inference script requires the --gpu_id argument, the name of the checkpoint files (--base_model_name) for trained models, the
location of the MVTec anomaly detection dataset (--data_path) and the folder where the checkpoint files are located (--checkpoint_path).
Inference with pretrained models can be run with:

```
python visualize_DRAEM.py --gpu_id 0 --base_model_name "DRAEM_seg_large_ae_large_0.0001_800_bs8" --data_path ./datasets/mvtec/ --checkpoint_path ./checkpoints/DRAEM_checkpoints/
```
## Inference Results
![Screenshot](images/result_1.PNG)
![Screenshot](images/result_2.PNG)


54 changes: 54 additions & 0 deletions data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,60 @@ def __getitem__(self, idx):
return sample


class MVTecDRAEM_Test_Visual_Dataset(Dataset):
    """Small MVTec test dataset used to visualise model predictions.

    Collects the ``*.png`` files found one directory level below ``root_dir``
    (``root_dir/<defect_type>/<name>.png``), sorts them, and keeps only the
    first ``max_images`` of them.

    Args:
        root_dir: Directory whose sub-folders contain the test images.
        resize_shape: Optional ``(height, width)`` to resize images and masks
            to. ``None`` keeps the original size.
        max_images: Maximum number of images to keep (default 2). ``None``
            keeps every image that was found.
    """

    def __init__(self, root_dir, resize_shape=None, max_images=2):
        self.root_dir = root_dir
        image_paths = sorted(glob.glob(root_dir + "/*/*.png"))
        if max_images is not None:
            image_paths = image_paths[:max_images]
        self.images = image_paths
        self.resize_shape = resize_shape

    def __len__(self):
        """Return the number of images kept by the dataset."""
        return len(self.images)

    def transform_image(self, image_path, mask_path):
        """Load an image and its mask as float32 channel-first arrays.

        Args:
            image_path: Path of the colour image to read.
            mask_path: Path of the grayscale ground-truth mask, or ``None`` to
                use an all-zero mask with the image's height and width.

        Returns:
            Tuple ``(image, mask)``: ``image`` has shape ``(3, H, W)`` and
            ``mask`` has shape ``(1, H, W)``; both are divided by 255.

        Raises:
            OSError: If the image or the mask cannot be read.
        """
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise OSError("Could not read image: {}".format(image_path))

        if mask_path is not None:
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise OSError("Could not read mask: {}".format(mask_path))
        else:
            mask = np.zeros((image.shape[0], image.shape[1]))

        if self.resize_shape is not None:
            # cv2.resize expects (width, height); resize_shape is (height, width).
            dsize = (self.resize_shape[1], self.resize_shape[0])
            image = cv2.resize(image, dsize=dsize)
            mask = cv2.resize(mask, dsize=dsize)

        image = (image / 255.0).astype(np.float32)
        mask = (mask / 255.0).astype(np.float32)

        # HWC -> CHW for the image; the mask gets a leading channel axis.
        image = np.transpose(image, (2, 0, 1))
        mask = mask[np.newaxis, :, :]
        return image, mask

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        Images located in a folder named ``good`` get an all-zero mask and
        ``has_anomaly == 0``. Every other image gets the mask stored at
        ``../../ground_truth/<defect_type>/<name>_mask.png`` relative to the
        image's folder, and ``has_anomaly == 1``.

        Returns:
            Dict with keys ``'image'``, ``'has_anomaly'``, ``'mask'`` and
            ``'idx'``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.images[idx]
        dir_path, file_name = os.path.split(img_path)
        base_dir = os.path.basename(dir_path)

        if base_dir == 'good':
            image, mask = self.transform_image(img_path, None)
            has_anomaly = np.array([0], dtype=np.float32)
        else:
            # splitext keeps stems that themselves contain dots intact.
            mask_file_name = os.path.splitext(file_name)[0] + "_mask.png"
            mask_path = os.path.join(dir_path, '../../ground_truth/', base_dir, mask_file_name)
            image, mask = self.transform_image(img_path, mask_path)
            has_anomaly = np.array([1], dtype=np.float32)

        sample = {'image': image, 'has_anomaly': has_anomaly, 'mask': mask, 'idx': idx}

        return sample




class MVTecDRAEMTrainDataset(Dataset):

Expand Down
Binary file added images/result_1.PNG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added images/result_2.PNG
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
128 changes: 128 additions & 0 deletions visualize_DRAEM.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import torch
import torch.nn.functional as F
from data_loader import MVTecDRAEM_Test_Visual_Dataset
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
from model_unet import ReconstructiveSubNetwork, DiscriminativeSubNetwork
import os
import matplotlib.pyplot as plt


def test(obj_names, mvtec_path, checkpoint_path, base_model_name):
    """Show test images and predicted anomaly heatmaps for each object class.

    For every name in ``obj_names`` the reconstructive and the discriminative
    sub-networks are restored from ``checkpoint_path``, the visual test
    dataset under ``<mvtec_path>/<obj_name>/test`` is loaded, and for each
    sample the input image and the predicted anomaly heatmap are displayed
    with Matplotlib.

    Args:
        obj_names: Iterable of object class names (sub-folders of
            ``mvtec_path``).
        mvtec_path: Root folder of the dataset.
        checkpoint_path: Folder containing the checkpoint files.
        base_model_name: Checkpoint name prefix. The files loaded are
            ``<base_model_name>_<obj_name>_.pckl`` and
            ``<base_model_name>_<obj_name>__seg.pckl``.

    Returns:
        None. The function only opens Matplotlib windows.
    """
    img_dim = 256

    for obj_name in obj_names:
        run_name = base_model_name + "_" + obj_name + '_'

        # Load the weights on the CPU and let .cuda() move them to the current
        # device, so the GPU selected by the caller is used rather than cuda:0.
        model = ReconstructiveSubNetwork(in_channels=3, out_channels=3)
        model.load_state_dict(
            torch.load(os.path.join(checkpoint_path, run_name + ".pckl"), map_location='cpu'))
        model.cuda()
        model.eval()

        model_seg = DiscriminativeSubNetwork(in_channels=6, out_channels=2)
        model_seg.load_state_dict(
            torch.load(os.path.join(checkpoint_path, run_name + "_seg.pckl"), map_location='cpu'))
        model_seg.cuda()
        model_seg.eval()

        dataset = MVTecDRAEM_Test_Visual_Dataset(
            os.path.join(mvtec_path, obj_name, "test"), resize_shape=[img_dim, img_dim])
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for sample_batched in dataloader:
                gray_batch = sample_batched["image"].cuda()

                # NCHW tensor -> NHWC numpy array for plotting.
                images = gray_batch.permute(0, 2, 3, 1).cpu().numpy()
                for image in images:
                    # The dataset reads images with cv2 (BGR channel order);
                    # Matplotlib expects RGB, so reverse the channel axis.
                    plt.imshow(image[..., ::-1])
                    plt.title('Original Image')
                    plt.show()

                # The segmentation network receives the reconstruction
                # concatenated with the input along the channel axis.
                gray_rec = model(gray_batch)
                joined_in = torch.cat((gray_rec, gray_batch), dim=1)

                out_mask = model_seg(joined_in)
                out_mask_sm = torch.softmax(out_mask, dim=1)

                # Channel 1 of the softmax output is displayed as the heatmap.
                out_mask_cv = out_mask_sm[0, 1, :, :].cpu().numpy()
                plt.imshow(out_mask_cv)
                plt.title('Predicted Anomaly Heatmap')
                plt.show()


if __name__ == "__main__":
    import argparse

    # Command-line interface: all four options are mandatory.
    parser = argparse.ArgumentParser()
    for option, option_type in (
        ('--gpu_id', int),
        ('--base_model_name', str),
        ('--data_path', str),
        ('--checkpoint_path', str),
    ):
        parser.add_argument(option, action='store', type=option_type, required=True)

    args = parser.parse_args()

    # Object classes to visualise, in the order they are processed.
    obj_list = [
        'capsule', 'bottle', 'carpet', 'leather', 'pill',
        'transistor', 'tile', 'cable', 'zipper', 'toothbrush',
        'metal_nut', 'hazelnut', 'screw', 'grid', 'wood',
    ]

    # Run everything on the GPU selected on the command line.
    with torch.cuda.device(args.gpu_id):
        test(
            obj_list,
            args.data_path,
            args.checkpoint_path,
            args.base_model_name,
        )