Skip to content

Commit 43c30c6

Browse files
Configurable iou threshold (#24)
## Problem In certain projects, non-max suppression (NMS) is causing detection issues. By default, points are merged during inference at an IoU of 0.45. However, relevant objects may be only some px apart, preventing the model from detecting all objects. Simply reducing point size isn't viable as it would exclude small references during training. ## Changes - Made NMS IoU threshold configurable in both Trainer and Detector - Fixed bug in TensorRT-based YOLOv5 inference that prevented detection of boxes directly at the top or left image edges - Note: This edge detection bug did not affect point detection
1 parent 051048b commit 43c30c6

File tree

8 files changed

+156
-97
lines changed

8 files changed

+156
-97
lines changed

README.md

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ There are two variants of the detector:
3535
- to be deployed on a regular Linux computer, e.g. running Ubuntu (referred to as cloud-detectors)
3636
- to be deployed on a Jetson Nano running Linux4Tegra (L4T)
3737

38+
Mandatory parameters are those described in [Zauberzeug Learning Loop Node Library](https://github.com/zauberzeug/learning_loop_node).
39+
Besides, the following parameters may be set:
40+
41+
| Name | Purpose | Value | Default | Required only with ./docker.sh |
42+
| -------------- | ----------------------------------------- | ------------------------- | ------- | ------------------------------ |
43+
| LINKLL | Link the node library into the container? | TRUE or FALSE | FALSE | Yes |
44+
| DETECTOR_NAME | Will be the name of the container | String | - | Yes |
45+
| WEIGHT_TYPE | Data type to convert weights to | String [FP32, FP16, INT8] | FP16 | No |
46+
| IOU_THRESHOLD | IoU threshold for NMS | Float | 0.45 | No |
47+
| CONF_THRESHOLD | Confidence threshold for NMS | Float | 0.2 | No |
48+
3849
### Cloud-Detector
3950

4051
New images can be pulled with `docker pull zauberzeug/yolov5-detector:nlvX.Y.Z-cloud`, where `X.Y.Z` is the version of the node-lib used.
@@ -43,13 +54,6 @@ Legacy images can be pulled with `docker pull zauberzeug/yolov5-detector:cloud`.
4354
Pulled images can be run with the `docker.sh` script by calling `./docker.sh run-image`.
4455
Local builds can be run with `./docker.sh run`.
4556
If the container does not use the GPU, try `./docker.sh d`.
46-
Mandatory parameters are those described in [Zauberzeug Learning Loop Node Library](https://github.com/zauberzeug/learning_loop_node). Besides, the following parameters may be set:
47-
48-
| Name | Purpose | Value | Default | Required only with ./docker.sh |
49-
| ------------- | ----------------------------------------- | ------------------------- | ------- | ------------------------------ |
50-
| LINKLL | Link the node library into the container? | TRUE or FALSE | FALSE | Yes |
51-
| DETECTOR_NAME | Will be the name of the container | String | - | Yes |
52-
| WEIGHT_TYPE | Data type to convert weights to | String [FP32, FP16, INT8] | FP16 | No |
5357

5458
### L4T-Detector
5559

detector/docker.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ fi
2828

2929
# ========================== BUILD CONFIGURATION / IMAGE SELECTION =======================
3030

31-
SEMANTIC_VERSION=0.1.11
31+
SEMANTIC_VERSION=0.1.12
3232
NODE_LIB_VERSION=0.14.0
3333
build_args=" --build-arg NODE_LIB_VERSION=$NODE_LIB_VERSION"
3434

detector/yolov5.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111

1212
import cv2
1313
import numpy as np
14-
import pycuda.driver as cuda
15-
import tensorrt as trt
16-
from pycuda._driver import Error as CudaError
14+
import pycuda.driver as cuda # type: ignore # pylint: disable=import-error
15+
import tensorrt as trt # type: ignore # pylint: disable=import-error
16+
from pycuda._driver import ( # type: ignore # pylint: disable=import-error
17+
Error as CudaError,
18+
)
1719

18-
CONF_THRESH = 0.2
19-
IOU_THRESHOLD = 0.4
2020
LEN_ALL_RESULT = 38001
2121
LEN_ONE_RESULT = 38
2222

@@ -40,17 +40,24 @@ class YoLov5TRT():
4040
description: A YOLOv5 class that warps TensorRT ops, preprocess and postprocess ops.
4141
"""
4242

43-
def __init__(self, engine_file_path: str):
43+
def __init__(self, engine_file_path: str, iou_threshold: float, conf_threshold: float):
44+
logging.info('Initializing YOLOv5 TRT engine with iou_threshold: %s, conf_threshold: %s',
45+
iou_threshold, conf_threshold)
4446
# Create a Context on this device,
4547
try:
4648
cuda.init()
47-
except CudaError as e:
48-
logging.exception('cuda init error:', e)
49+
except CudaError:
50+
logging.exception('cuda init error:')
4951
self.cuda_init_error = True
5052
return
5153

5254
self.cuda_init_error = False
5355

56+
self.iou_threshold = iou_threshold
57+
"""a iou threshold to filter detections during nms"""
58+
self.conf_threshold = conf_threshold
59+
"""a confidence threshold to filter detections during nms"""
60+
5461
self.ctx = cuda.Device(0).make_context()
5562
stream = cuda.Stream()
5663
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
@@ -107,7 +114,7 @@ def check_cuda_init_error(self):
107114
def infer(self, image_raw):
108115
self.check_cuda_init_error()
109116

110-
threading.Thread.__init__(self)
117+
threading.Thread.__init__(self) # type: ignore
111118
# Make self the active context, pushing it on top of the context stack.
112119
self.ctx.push()
113120
# Restore
@@ -259,8 +266,7 @@ def _post_process(self, output, origin_h, origin_w):
259266
pred = np.reshape(output[1:], (-1, LEN_ONE_RESULT))[:num, :]
260267
pred = pred[:, :6]
261268
# Do nms
262-
boxes = self._non_max_suppression(
263-
pred, origin_h, origin_w, conf_thres=CONF_THRESH, nms_thres=IOU_THRESHOLD)
269+
boxes = self._non_max_suppression(pred, origin_h, origin_w)
264270
result_boxes = boxes[:, :4] if len(boxes) else np.array([])
265271
result_scores = boxes[:, 4] if len(boxes) else np.array([])
266272
result_classid = boxes[:, 5] if len(boxes) else np.array([])
@@ -309,19 +315,21 @@ def bbox_iou(self, box1, box2, x1y1x2y2=True):
309315

310316
return iou
311317

312-
def _non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, nms_thres=0.4):
318+
def _non_max_suppression(self, prediction, origin_h, origin_w):
313319
"""
314320
description: Removes detections with lower object confidence score than 'conf_thres' and performs
315321
Non-Maximum Suppression to further filter detections.
316322
param:
317323
prediction: detections, (x1, y1, x2, y2, conf, cls_id)
318324
origin_h: original image height
319325
origin_w: original image width
320-
conf_thres: a confidence threshold to filter detections
321-
nms_thres: a iou threshold to filter detections
322326
return:
323327
boxes: output after nms with the shape (x1, y1, x2, y2, conf, cls_id)
324328
"""
329+
330+
conf_thres = self.conf_threshold
331+
nms_thres = self.iou_threshold
332+
325333
# Get the boxes that score > CONF_THRESH
326334
boxes = prediction[prediction[:, 4] >= conf_thres]
327335
# Trandform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2]

detector/yolov5_detector.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def __init__(self) -> None:
2323
assert self.weight_type in ['FP16', 'FP32', 'INT8'], 'WEIGHT_TYPE must be one of FP16, FP32, INT8'
2424
self.log = logging.getLogger('Yolov5Detector')
2525
self.log.setLevel(logging.INFO)
26+
self.iou_threshold = float(os.getenv('IOU_THRESHOLD', '0.45'))
27+
self.conf_threshold = float(os.getenv('CONF_THRESHOLD', '0.2'))
2628

2729
def init(self) -> None:
2830
assert self.model_info is not None, 'model_info must be set before calling init()'
@@ -37,30 +39,40 @@ def init(self) -> None:
3739
self.yolov5 = None
3840
self.log.info('destroyed old yolov5 instance')
3941

40-
self.yolov5 = yolov5.YoLov5TRT(engine_file)
42+
self.yolov5 = yolov5.YoLov5TRT(engine_file, self.iou_threshold, self.conf_threshold)
4143
for _ in range(3):
4244
warmup = yolov5.warmUpThread(self.yolov5)
4345
warmup.start()
4446
warmup.join()
4547

4648
@staticmethod
4749
def clip_box(
48-
x: float, y: float, width: float, height: float, img_width: int, img_height: int) -> Tuple[
49-
float, float, float, float]:
50-
'''make sure the box is within the image
51-
x,y is the center of the box
50+
x1: float, y1: float, width: float, height: float, img_width: int, img_height: int) -> Tuple[
51+
int, int, int, int]:
52+
'''Clips a box defined by top-left corner (x1, y1), width, and height
53+
to stay within image boundaries (img_width, img_height).
54+
Returns the clipped (x1, y1, width, height) as ints.
5255
'''
53-
left = max(0, x - 0.5 * width)
54-
top = max(0, y - 0.5 * height)
55-
right = min(img_width, x + 0.5 * width)
56-
bottom = min(img_height, y + 0.5 * height)
56+
x2 = x1 + width
57+
y2 = y1 + height
5758

58-
x = 0.5 * (left + right)
59-
y = 0.5 * (top + bottom)
60-
width = right - left
61-
height = bottom - top
59+
# Clip coordinates
60+
clipped_x1 = round(max(0.0, x1))
61+
clipped_y1 = round(max(0.0, y1))
62+
clipped_x2 = round(min(float(img_width), x2))
63+
clipped_y2 = round(min(float(img_height), y2))
6264

63-
return x, y, width, height
65+
# Recalculate dimensions
66+
clipped_width = clipped_x2 - clipped_x1
67+
clipped_height = clipped_y2 - clipped_y1
68+
69+
# Ensure width and height are non-negative
70+
if clipped_width < 0:
71+
clipped_width = 0
72+
if clipped_height < 0:
73+
clipped_height = 0
74+
75+
return clipped_x1, clipped_y1, clipped_width, clipped_height
6476

6577
@staticmethod
6678
def clip_point(x: float, y: float, img_width: int, img_height: int) -> Tuple[float, float]:
@@ -87,13 +99,13 @@ def evaluate(self, image: bytes) -> ImageMetadata:
8799
skipped_detections.append((category.name, detection))
88100
continue
89101
if category.type == CategoryType.Box:
90-
x, y, w, h = self.clip_box(x, y, w, h, im_width, im_height)
102+
clipped_x1, clipped_y1, clipped_w, clipped_h = self.clip_box(x, y, w, h, im_width, im_height)
91103
image_metadata.box_detections.append(
92104
BoxDetection(category_name=category.name,
93-
x=round(x),
94-
y=round(y),
95-
width=round(x+w)-round(x),
96-
height=round(y+h)-round(y),
105+
x=clipped_x1,
106+
y=clipped_y1,
107+
width=clipped_w,
108+
height=clipped_h,
97109
category_id=category.id,
98110
model_name=self.model_info.version,
99111
confidence=probability))

0 commit comments

Comments
 (0)