-
Notifications
You must be signed in to change notification settings - Fork 2
/
validate.py
114 lines (93 loc) · 3.99 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from torch.utils.data import DataLoader
from configs import config_dict as config_dict
from datasets import dataset_dict as dataset_dict
from torchvision.utils import save_image
import torch
import numpy as np
import time
import argparse
import os
from CaRTS import build_model
from CaRTS.evaluation.metrics import dice_scores, normalized_surface_distances
from evaluate import plot_segmentation, plot_augmentation
import matplotlib.pyplot as plt
import matplotlib.cm as cm
def _str2bool(value):
    """Interpret a command-line string such as ``True``/``false``/``1``/``no`` as a bool.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is ``True`` for
    every non-empty string (including ``"False"``); this converter reads the text.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options of the validation/test script.

    Args:
        argv: optional list of argument strings; ``None`` (default) reads
            ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with ``config``, ``model_path``, ``test`` (bool),
        ``domain``, ``save_dir`` and ``tau``.
    """
    parser = argparse.ArgumentParser(
        epilog="Example of usage:\n python validate.py --config UNet_SegSTRONGC --model_path checkpoints/unet_segstrongc/model_39.pth --test True --domain regular",
        # Keep the "\n" in the epilog instead of re-wrapping it.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--config", type=str, required=True, help="Name of the config file")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint file")
    # Accepts "--test True", "--test False", or a bare "--test" (meaning True).
    parser.add_argument("--test", type=_str2bool, nargs="?", const=True, default=False,
                        help="True for testing, False for validation")
    parser.add_argument("--domain", type=str, default=None,
                        choices=['regular', 'smoke', 'bg_change', 'blood', 'low_brightness'],
                        help="Test/Validate domain")
    parser.add_argument("--save_dir", type=str, default=None, help="Path to save model output")
    parser.add_argument("--tau", type=int, default=5, help="Tolerance in normalized surface distance calculation")
    return parser.parse_args(argv)
def evaluate(model, dataloader, device, tau, save_dir=None):
    """Run ``model`` over ``dataloader`` and print Dice / NSD statistics.

    For every ``(image, gt)`` pair the first prediction of the batch
    (``pred[0]``) and the ground truth are thresholded at 0.5 and scored with
    ``dice_scores`` and ``normalized_surface_distances``. Throughput and the
    mean/std of both metrics are printed at the end.

    Args:
        model: object with ``eval()``; called as ``model(data)`` where ``data``
            holds ``'image'``, ``'gt'`` and ``'iteration'``, and must return a
            dict with a ``'pred'`` entry.
        dataloader: sized iterable of ``(image, gt)`` tensor pairs. Only
            ``pred[0]`` is scored, so it presumably uses ``batch_size=1`` —
            confirm with the caller.
        device: torch device the inputs are moved to.
        tau: tolerance forwarded to ``normalized_surface_distances``.
        save_dir: if given, the directory is created if missing and the binary
            predictions (one per sample) are saved to ``<save_dir>/pred.npy``.
    """
    start = time.time()
    predictions = []  # only filled when save_dir is set
    dice_tools = []
    nsds = []
    model.eval()
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    num_batches = len(dataloader)
    for i, (image, gt) in enumerate(dataloader):
        print("Iteration: ", i, "/", num_batches, end="\r")
        data = {
            'image': image.to(device=device),
            'gt': gt.to(device=device),
            'iteration': i,
        }
        # NOTE(review): inference is not wrapped in torch.no_grad(); the model
        # may rely on autograd at inference time — confirm before adding it.
        pred = model(data)['pred']
        result = (pred[0].cpu().detach().numpy() > 0.5).squeeze()
        mask = (gt.cpu().numpy() > 0.5).squeeze()
        dice_tools.append(dice_scores(result, mask))
        nsds.append(normalized_surface_distances(result, mask, tau))
        if save_dir is not None:
            # Store each prediction exactly once (it was previously appended twice).
            predictions.append(result)
    print()  # move past the "\r" progress line so the summary is not overwritten

    num_samples = len(dice_tools)
    if num_samples == 0:
        print("No samples to evaluate.")
        return

    elapsed = time.time() - start
    print("iteration per Sec: %f" % (num_samples / elapsed))
    print("mean: dice_tool: %f " % np.mean(dice_tools))
    print("std: dice_tool: %f " % np.std(dice_tools))
    print("mean: nsd: %f" % np.mean(nsds))
    print("std: nsd: %f" % np.std(nsds))
    if save_dir is not None:
        np.save(os.path.join(save_dir, "pred.npy"), predictions)
if __name__ == "__main__":
    # Entry point: build the requested split, load the checkpoint, evaluate.
    args = parse_args()
    cfg = config_dict[args.config]

    if torch.cuda.is_available():
        print("use_gpu")
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # --test selects the test split; otherwise the validation split is used.
    dataset_cfg = cfg.test_dataset if args.test else cfg.validation_dataset
    # Override the configured domains only when --domain is given; previously the
    # validation domains were unconditionally replaced with [None] when it was omitted.
    if args.domain is not None:
        dataset_cfg['args']['domains'] = [args.domain]

    dataset = dataset_dict[dataset_cfg['name']](**(dataset_cfg['args']))
    # evaluate() scores only pred[0] of each batch, so keep batch_size=1.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    model = build_model(cfg.model, device)
    model.load_parameters(args.model_path)
    evaluate(model, dataloader, device, args.tau, args.save_dir)