-
Notifications
You must be signed in to change notification settings - Fork 78
/
eval.py
218 lines (201 loc) · 7.73 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""
Run evaluation on a trained model to get mAP and, with --verbose, class-wise AP/AR.
USAGE:
python eval.py --data data_configs/voc.yaml --weights outputs/training/fasterrcnn_convnext_small_voc_15e_noaug/best_model.pth --model fasterrcnn_convnext_small
"""
from datasets import (
create_valid_dataset, create_valid_loader
)
from models.create_fasterrcnn_model import create_model
from torch_utils import utils
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from pprint import pprint
from tqdm import tqdm
import torch
import argparse
import yaml
import torchvision
import time
import numpy as np
# Share tensors between DataLoader worker processes through the file system.
torch.multiprocessing.set_sharing_strategy('file_system')

# Script setup: parse the CLI, read the dataset YAML, build the model and the
# validation dataset/loader.  The names bound here (``args``, ``model``,
# ``valid_loader``, ``DEVICE``, ``CLASSES``, ...) are module-level globals
# used by the rest of the script.
if __name__ == '__main__':
    # Construct the argument parser.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data',
        default='data_configs/test_image_config.yaml',
        help='(optional) path to the data config file'
    )
    parser.add_argument(
        '-m', '--model',
        default='fasterrcnn_resnet50_fpn',
        help='name of the model'
    )
    parser.add_argument(
        '-mw', '--weights',
        default=None,
        help='path to trained checkpoint weights if providing custom YAML file'
    )
    parser.add_argument(
        '-ims', '--imgsz',
        default=640,
        type=int,
        help='image size to feed to the network'
    )
    parser.add_argument(
        '-w', '--workers', default=4, type=int,
        help='number of workers for data processing/transforms/augmentations'
    )
    parser.add_argument(
        '-b', '--batch',
        default=8,
        type=int,
        help='batch size to load the data'
    )
    parser.add_argument(
        '-d', '--device',
        default=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
        help='computation/training device, default is GPU if GPU present'
    )
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help='show class-wise mAP'
    )
    parser.add_argument(
        '-st', '--square-training',
        dest='square_training',
        action='store_true',
        # Adjacent literals instead of backslash continuations, so source
        # indentation never leaks into the help text.
        help='Resize images to square shape instead of aspect ratio resizing '
             'for single image training. For mosaic training, this resizes '
             'single images to square shape first then puts them on a '
             'square canvas.'
    )
    args = vars(parser.parse_args())

    # Load the data configurations
    with open(args['data']) as file:
        data_configs = yaml.safe_load(file)

    # Validation settings and constants.
    try:  # Use test images if present.
        VALID_DIR_IMAGES = data_configs['TEST_DIR_IMAGES']
        VALID_DIR_LABELS = data_configs['TEST_DIR_LABELS']
    except KeyError:  # Else use the validation images.
        VALID_DIR_IMAGES = data_configs['VALID_DIR_IMAGES']
        VALID_DIR_LABELS = data_configs['VALID_DIR_LABELS']
    NUM_CLASSES = data_configs['NC']
    CLASSES = data_configs['CLASSES']
    NUM_WORKERS = args['workers']
    DEVICE = args['device']
    BATCH_SIZE = args['batch']

    # Model configurations
    IMAGE_SIZE = args['imgsz']

    # ``create_model`` is indexed by model name; keep it intact and bind the
    # selected factory to its own name instead of overwriting the import.
    build_model = create_model[args['model']]

    if args['weights'] is None:
        # No checkpoint given: ask the factory for its COCO variant.
        built = build_model(num_classes=NUM_CLASSES, coco_model=True)
        # The factory may return ``(model, coco_flag)`` or only the model;
        # handle both without building the model a second time.
        if isinstance(built, (tuple, list)):
            model, coco_model = built
        else:
            model, coco_model = built, False
        if coco_model:
            COCO_91_CLASSES = data_configs['COCO_91_CLASSES']
            dataset_classes = COCO_91_CLASSES
        else:
            # Previously this path crashed with NameError; fall back to the
            # dataset's own class list so evaluation can still run.
            print(
                f"Model '{args['model']}' did not report a COCO variant; "
                "evaluating with the classes from the data config."
            )
            dataset_classes = CLASSES
    else:
        # Load weights.
        model = build_model(num_classes=NUM_CLASSES, coco_model=False)
        # NOTE(security): depending on the torch version / ``weights_only``
        # setting, torch.load may unpickle arbitrary objects; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(args['weights'], map_location=DEVICE)
        model.load_state_dict(checkpoint['model_state_dict'])
        dataset_classes = CLASSES

    valid_dataset = create_valid_dataset(
        VALID_DIR_IMAGES,
        VALID_DIR_LABELS,
        IMAGE_SIZE,
        dataset_classes,
        square_training=args['square_training']
    )
    model.to(DEVICE).eval()

    valid_loader = create_valid_loader(valid_dataset, BATCH_SIZE, NUM_WORKERS)
@torch.inference_mode()
def evaluate(
    model,
    data_loader,
    device,
    out_dir=None,
    classes=None,
    colors=None
):
    """
    Run ``model`` over ``data_loader`` and compute mean average precision.

    Every batch is moved to ``device`` and passed through the model; ground
    truth and predictions are collected on the CPU and handed to a
    ``MeanAveragePrecision`` metric once all batches are processed.

    :param model: Detection model. Called with a list of image tensors and
        expected to return one dict per image with ``'boxes'``, ``'scores'``
        and ``'labels'`` entries.
    :param data_loader: Iterable of ``(images, targets)`` batches, where each
        target is a dict with ``'boxes'`` and ``'labels'`` tensors. Must
        support ``len()``.
    :param device: Device the images are moved to before inference.
    :param out_dir: Accepted for signature compatibility; not used here.
    :param classes: Accepted for signature compatibility; not used here.
    :param colors: Accepted for signature compatibility; not used here.

    :return: The result of ``MeanAveragePrecision.compute()``.

    Note: per-class metrics are switched on through the script-level
    ``args['verbose']`` flag, so ``args`` must be defined before calling.
    """
    metric = MeanAveragePrecision(class_metrics=args['verbose'])
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = "Test:"
    target = []
    preds = []
    try:
        batches = tqdm(
            metric_logger.log_every(data_loader, 100, header),
            total=len(data_loader)
        )
        for images, targets in batches:
            images = [img.to(device) for img in images]
            # Gradients are already disabled by the inference_mode decorator.
            outputs = model(images)
            for truth, prediction in zip(targets, outputs):
                target.append({
                    'boxes': truth['boxes'].detach().cpu(),
                    'labels': truth['labels'].detach().cpu(),
                })
                preds.append({
                    'boxes': prediction['boxes'].detach().cpu(),
                    'scores': prediction['scores'].detach().cpu(),
                    'labels': prediction['labels'].detach().cpu(),
                })
        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
    finally:
        # Restore the caller's thread count even if inference fails midway.
        torch.set_num_threads(n_threads)
    metric.update(preds, target)
    return metric.compute()
def build_class_table(classes, stats):
    """
    Build the text lines of the per-class AP / AR table.

    :param classes: Class names; index 0 is skipped (the rows start at
        ``classes[1]``).
    :param stats: Mapping with ``'map'`` and ``'mar_100'`` and, when more
        than two class names are given, ``'map_per_class'`` and
        ``'mar_100_per_class'`` sequences.

    :return: List of strings, one per printed line. All lines of a table
        have the same width (73 characters for the multi-class layout,
        62 for the single-class layout).
    """
    pad = ''
    overall_ap = np.array(stats['map'])
    overall_ar = np.array(stats['mar_100'])
    lines = []
    if len(classes) > 2:
        rule = '-' * 73
        lines.append(rule)
        # The index column is 4 characters wide so the header's bars line up
        # with the data rows and the 73-character rules.
        lines.append(f"|    | Class{pad:<16}| AP{pad:<18}| AR{pad:<18}|")
        lines.append(rule)
        # NOTE(review): assumes entry i of the per-class metrics belongs to
        # classes[i + 1] -- confirm against the metric's class ordering.
        for idx, class_name in enumerate(classes[1:]):
            class_ap = np.array(stats['map_per_class'][idx])
            class_ar = np.array(stats['mar_100_per_class'][idx])
            lines.append(
                f"|{idx + 1:<3} | {class_name:<20} | "
                f"{class_ap:.3f}{pad:<15}| {class_ar:.3f}{pad:<15}|"
            )
        lines.append(rule)
        lines.append(
            f"|Avg{pad:<23} | {overall_ap:.3f}{pad:<15}| "
            f"{overall_ar:.3f}{pad:<15}|"
        )
    else:
        rule = '-' * 62
        lines.append(rule)
        lines.append(f"|Class{pad:<10} | AP{pad:<18}| AR{pad:<18}|")
        lines.append(rule)
        # With a single foreground class its AP/AR equal the overall values.
        lines.append(
            f"|{classes[1]:<15} | {overall_ap:.3f}{pad:<15}| "
            f"{overall_ar:.3f}{pad:<15}|"
        )
        lines.append(rule)
        lines.append(
            f"|Avg{pad:<12} | {overall_ap:.3f}{pad:<15}| "
            f"{overall_ar:.3f}{pad:<15}|"
        )
    return lines


# Run the evaluation and report the results.
if __name__ == '__main__':
    stats = evaluate(
        model,
        valid_loader,
        device=DEVICE,
        classes=CLASSES,
    )

    print('\n')
    pprint(stats)

    if args['verbose']:
        print('\n')
        pprint(f"Classes: {CLASSES}")
        print('\n')
        print('AP / AR per class')
        for line in build_class_table(CLASSES, stats):
            print(line)