-
Notifications
You must be signed in to change notification settings - Fork 174
/
eval.py
36 lines (29 loc) · 1.13 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import csv
import torch
from validate import validate
from networks.resnet import resnet50
from options.test_options import TestOptions
from eval_config import *
# Running tests
#
# Evaluate a single trained checkpoint on every configured test set, print the
# accuracy / average precision for each one, and write all results to a CSV.
# model_path, dataroot, vals, multiclass and results_dir are not defined in
# this script; presumably they come from `from eval_config import *`.
opt = TestOptions().parse(print_options=False)
model_name = os.path.basename(model_path).replace('.pth', '')
header = "{} model testing on...".format(model_name)
rows = [[header],
        ['testset', 'accuracy', 'avg precision']]

# Create the output directory up front so that a missing directory is reported
# before the evaluation runs instead of after all the work has been done.
os.makedirs(results_dir, exist_ok=True)

print(header)

# The checkpoint is the same for every test set, so build the network and load
# its weights once instead of repeating the disk read and GPU upload per set.
model = resnet50(num_classes=1)
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict['model'])
model.cuda()

for v_id, val in enumerate(vals):
    opt.dataroot = '{}/{}'.format(dataroot, val)
    # When multiclass[v_id] is truthy, each entry of the test-set directory is
    # passed on as a class name; otherwise a single unnamed class is used.
    opt.classes = os.listdir(opt.dataroot) if multiclass[v_id] else ['']
    opt.no_resize = True  # testing without resizing by default

    # Cheap to repeat; keeps the model in eval mode for every test set even if
    # validate() were to change the mode (validate is not visible here).
    model.eval()

    # validate returns six values; only the first two are used here.
    acc, ap, _, _, _, _ = validate(model, opt)
    rows.append([val, acc, ap])
    print("({}) acc: {}; ap: {}".format(val, acc, ap))

csv_name = results_dir + '/{}.csv'.format(model_name)
# newline='' lets the csv module control line endings itself; without it,
# platforms that translate '\n' (e.g. Windows) get blank rows between records.
with open(csv_name, 'w', newline='') as f:
    csv_writer = csv.writer(f, delimiter=',')
    csv_writer.writerows(rows)