@@ -104,7 +104,11 @@ def evaluate(self, image: bytes) -> ImageMetadata:
104
104
x , y , br_x , br_y = box
105
105
w = br_x - x
106
106
h = br_y - y
107
- category_idx = int (result_classid [j ])
107
+ category_idx = result_classid [j ]
108
+ if category_idx < 0 or category_idx >= len (self .model_info .categories ):
109
+ self .log .warning ('invalid category index: %d for %d classes' ,
110
+ category_idx , len (self .model_info .categories ))
111
+ continue
108
112
category = self .model_info .categories [category_idx ]
109
113
probability = round (float (result_scores [j ]), 2 )
110
114
@@ -190,7 +194,8 @@ def _post_process(self, pred, origin_h, origin_w, conf_thres, nms_thres):
190
194
"""
191
195
description: postprocess the prediction
192
196
param:
193
- output: A numpy likes [[cx,cy,w,h,conf,cls_id], [cx,cy,w,h,conf,cls_id], ...]
197
+ pred: A numpy array like [[cx,cy,w,h,conf, c0_prob, c1_prob, ...],
198
+ [cx,cy,w,h,conf, c0_prob, c1_prob, ...], ...]
194
199
origin_h: height of original image
195
200
origin_w: width of original image
196
201
conf_thres: confidence threshold
@@ -201,20 +206,26 @@ def _post_process(self, pred, origin_h, origin_w, conf_thres, nms_thres):
201
206
result_classid: final class ids, a numpy array; each element is the class id corresponding to a box
202
207
"""
203
208
209
+ num_classes = pred .shape [1 ] - 5
210
+
204
211
# Do nms
205
212
boxes = self ._non_max_suppression (
206
213
pred , origin_h , origin_w , conf_thres , nms_thres )
207
214
result_boxes = boxes [:, :4 ] if len (boxes ) else np .array ([])
208
215
result_scores = boxes [:, 4 ] if len (boxes ) else np .array ([])
209
- result_classid = boxes [:, 5 ] if len (boxes ) else np .array ([])
216
+ if num_classes > 1 :
217
+ result_classid = np .argmax (boxes [:, 5 :], axis = 1 )
218
+ else :
219
+ result_classid = np .zeros (boxes .shape [0 ], dtype = int )
210
220
return result_boxes , result_scores , result_classid
211
221
212
222
def _non_max_suppression (self , pred , origin_h , origin_w , conf_thres , nms_thres ):
213
223
"""
214
224
description: Removes detections with lower object confidence score than 'conf_thres' and performs
215
225
Non-Maximum Suppression to further filter detections.
216
226
param:
217
- prediction: A numpy likes [[cx,cy,w,h,conf,cls_id], [cx,cy,w,h,conf,cls_id], ...]
227
+ prediction: A numpy array like [[cx,cy,w,h,conf, c0_prob, c1_prob, ...],
228
+ [cx,cy,w,h,conf, c0_prob, c1_prob, ...], ...]
218
229
origin_h: original image height
219
230
origin_w: original image width
220
231
input_size: the input size of the model
@@ -225,6 +236,9 @@ def _non_max_suppression(self, pred, origin_h, origin_w, conf_thres, nms_thres):
225
236
"""
226
237
# Keep only the boxes whose confidence score is >= conf_thres
227
238
boxes = pred [pred [:, 4 ] >= conf_thres ]
239
+ if len (boxes ) == 0 :
240
+ return np .array ([])
241
+ num_classes = boxes .shape [1 ] - 5
228
242
# Transform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2]
229
243
boxes [:, :4 ] = self .xywh2xyxy (origin_h , origin_w , boxes [:, :4 ])
230
244
# Object confidence
@@ -236,7 +250,10 @@ def _non_max_suppression(self, pred, origin_h, origin_w, conf_thres, nms_thres):
236
250
while boxes .shape [0 ]:
237
251
large_overlap = self .bbox_iou (np .expand_dims (
238
252
boxes [0 , :4 ], 0 ), boxes [:, :4 ]) > nms_thres
239
- label_match = boxes [0 , - 1 ].astype (int ) == boxes [:, - 1 ].astype (int )
253
+ if num_classes > 1 :
254
+ label_match = np .argmax (boxes [:, 5 :], axis = 1 ) == np .argmax (boxes [0 , 5 :])
255
+ else :
256
+ label_match = np .ones (boxes .shape [0 ], dtype = bool )
240
257
# Indices of boxes with lower confidence scores, large IOUs and matching labels
241
258
invalid = large_overlap & label_match
242
259
keep_boxes += [boxes [0 ]]
0 commit comments