Skip to content

Commit 671a392

Browse files
committed
fix documentation and dealing with multiple categories
1 parent 5c47df1 commit 671a392

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

detector_cpu/yolov5_detector.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,11 @@ def evaluate(self, image: bytes) -> ImageMetadata:
104104
x, y, br_x, br_y = box
105105
w = br_x - x
106106
h = br_y - y
107-
category_idx = int(result_classid[j])
107+
category_idx = result_classid[j]
108+
if category_idx < 0 or category_idx >= len(self.model_info.categories):
109+
self.log.warning('invalid category index: %d for %d classes',
110+
category_idx, len(self.model_info.categories))
111+
continue
108112
category = self.model_info.categories[category_idx]
109113
probability = round(float(result_scores[j]), 2)
110114

@@ -190,7 +194,8 @@ def _post_process(self, pred, origin_h, origin_w, conf_thres, nms_thres):
190194
"""
191195
description: postprocess the prediction
192196
param:
193-
output: A numpy likes [[cx,cy,w,h,conf,cls_id], [cx,cy,w,h,conf,cls_id], ...]
197+
pred: A numpy likes [[cx,cy,w,h,conf, c0_prob, c1_prob, ...],
198+
[cx,cy,w,h,conf, c0_prob, c1_prob, ...], ...]
194199
origin_h: height of original image
195200
origin_w: width of original image
196201
conf_thres: confidence threshold
@@ -201,20 +206,26 @@ def _post_process(self, pred, origin_h, origin_w, conf_thres, nms_thres):
201206
result_classid: finally classid, a numpy, each element is the classid correspoing to box
202207
"""
203208

209+
num_classes = pred.shape[1] - 5
210+
204211
# Do nms
205212
boxes = self._non_max_suppression(
206213
pred, origin_h, origin_w, conf_thres, nms_thres)
207214
result_boxes = boxes[:, :4] if len(boxes) else np.array([])
208215
result_scores = boxes[:, 4] if len(boxes) else np.array([])
209-
result_classid = boxes[:, 5] if len(boxes) else np.array([])
216+
if num_classes > 1:
217+
result_classid = np.argmax(boxes[:, 5:], axis=1)
218+
else:
219+
result_classid = np.zeros(boxes.shape[0], dtype=int)
210220
return result_boxes, result_scores, result_classid
211221

212222
def _non_max_suppression(self, pred, origin_h, origin_w, conf_thres, nms_thres):
213223
"""
214224
description: Removes detections with lower object confidence score than 'conf_thres' and performs
215225
Non-Maximum Suppression to further filter detections.
216226
param:
217-
prediction: A numpy likes [[cx,cy,w,h,conf,cls_id], [cx,cy,w,h,conf,cls_id], ...]
227+
prediction: A numpy likes [[cx,cy,w,h,conf, c0_prob, c1_prob, ...],
228+
[cx,cy,w,h,conf, c0_prob, c1_prob, ...], ...]
218229
origin_h: original image height
219230
origin_w: original image width
220231
input_size: the input size of the model
@@ -225,6 +236,9 @@ def _non_max_suppression(self, pred, origin_h, origin_w, conf_thres, nms_thres):
225236
"""
226237
# Get the boxes that score > CONF_THRESH
227238
boxes = pred[pred[:, 4] >= conf_thres]
239+
if len(boxes) == 0:
240+
return np.array([])
241+
num_classes = boxes.shape[1] - 5
228242
# Trasform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2]
229243
boxes[:, :4] = self.xywh2xyxy(origin_h, origin_w, boxes[:, :4])
230244
# Object confidence
@@ -236,7 +250,10 @@ def _non_max_suppression(self, pred, origin_h, origin_w, conf_thres, nms_thres):
236250
while boxes.shape[0]:
237251
large_overlap = self.bbox_iou(np.expand_dims(
238252
boxes[0, :4], 0), boxes[:, :4]) > nms_thres
239-
label_match = boxes[0, -1].astype(int) == boxes[:, -1].astype(int)
253+
if num_classes > 1:
254+
label_match = np.argmax(boxes[:, 5:], axis=1) == np.argmax(boxes[0, 5:])
255+
else:
256+
label_match = np.ones(boxes.shape[0], dtype=bool)
240257
# Indices of boxes with lower confidence scores, large IOUs and matching labels
241258
invalid = large_overlap & label_match
242259
keep_boxes += [boxes[0]]

0 commit comments

Comments
 (0)