Add support for Yolov10 #546

Open: wants to merge 4 commits into base: master
1 change: 1 addition & 0 deletions README.md
@@ -47,6 +47,7 @@ NVIDIA DeepStream SDK 7.0 / 6.4 / 6.3 / 6.2 / 6.1.1 / 6.1 / 6.0.1 / 6.0 / 5.1 c
* [YOLOv6 usage](docs/YOLOv6.md)
* [YOLOv7 usage](docs/YOLOv7.md)
* [YOLOv8 usage](docs/YOLOv8.md)
* [YOLOv10 usage](docs/YOLOv10.md)
* [YOLOR usage](docs/YOLOR.md)
* [YOLOX usage](docs/YOLOX.md)
* [DAMO-YOLO usage](docs/DAMOYOLO.md)
28 changes: 28 additions & 0 deletions config_infer_primary_yoloV10.txt
@@ -0,0 +1,28 @@
[property]
gpu-id=0
net-scale-factor=0.0039215697906911373
model-color-format=0
onnx-file=yolov10s.onnx
model-engine-file=yolov10s.onnx_b1_gpu0_fp32.engine
#int8-calib-file=calib.table
labelfile-path=labels.txt
batch-size=1
## 0=FP32, 1=INT8, 2=FP16 mode
network-mode=0
num-detected-classes=80
interval=0
gie-unique-id=1
process-mode=1
network-type=0
## 1=DBSCAN, 2=NMS, 3= DBSCAN+NMS Hybrid, 4 = None(No clustering)
cluster-mode=4
maintain-aspect-ratio=1
symmetric-padding=1
#workspace-size=2000
parse-bbox-func-name=NvDsInferParseYoloE
#parse-bbox-func-name=NvDsInferParseYoloCuda
custom-lib-path=nvdsinfer_custom_impl_Yolo/libnvdsinfer_custom_impl_Yolo.so
#engine-create-func-name=NvDsInferYoloCudaEngineGet

[class-attrs-all]
pre-cluster-threshold=0.45
201 changes: 201 additions & 0 deletions docs/YOLOv10.md
@@ -0,0 +1,201 @@
# YOLOv10 usage

**NOTE**: The yaml file is not required.

- [YOLOv10 usage](#yolov10-usage)
  - [Convert model](#convert-model)
    - [1. Download the YOLOv10 repo and install the requirements](#1-download-the-yolov10-repo-and-install-the-requirements)
    - [2. Copy conversor](#2-copy-conversor)
    - [3. Download the model](#3-download-the-model)
    - [4. Convert model](#4-convert-model)
    - [5. Copy generated files](#5-copy-generated-files)
  - [Compile the lib](#compile-the-lib)
  - [Edit the config\_infer\_primary\_yoloV10 file](#edit-the-config_infer_primary_yolov10-file)
  - [Edit the deepstream\_app\_config file](#edit-the-deepstream_app_config-file)
  - [Testing the model](#testing-the-model)

##

### Convert model

#### 1. Download the YOLOv10 repo and install the requirements

```
git clone https://github.com/THU-MIG/yolov10.git
cd yolov10
pip3 install -r requirements.txt
pip3 install -e .
```

**NOTE**: Using a Python virtualenv is recommended.

#### 2. Copy conversor

Copy the `export_yoloV10.py` file from the `DeepStream-Yolo/utils` directory to the `yolov10` folder.

#### 3. Download the model

Download the `pt` file from [YOLOv10](https://github.com/THU-MIG/yolov10/releases) releases (example for YOLOv10s)

```
wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10s.pt
```

**NOTE**: You can use your custom model.

#### 4. Convert model

Generate the ONNX model file (example for YOLOv10s)

```
python3 export_yoloV10.py -w yolov10s.pt --dynamic
```

**NOTE**: To change the inference size (default: 640)

```
-s SIZE
--size SIZE
-s HEIGHT WIDTH
--size HEIGHT WIDTH
```

Example for 1280

```
-s 1280
```

or

```
-s 1280 1280
```
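
The two spellings are equivalent because a single value is expanded to a square `HEIGHT WIDTH` pair. A minimal sketch of that normalization, modeled on the `--size` handling in `export_yoloV10.py`; the helper name `normalize_size` and its length check are hypothetical and not part of the script:

```python
import argparse


def normalize_size(size):
    """Return [HEIGHT, WIDTH] from a one- or two-element size list."""
    if len(size) == 1:
        # A single value means a square input: [1280] -> [1280, 1280]
        return size * 2
    if len(size) == 2:
        return list(size)
    # Hypothetical guard: the export script itself does not validate this.
    raise ValueError('size expects SIZE or HEIGHT WIDTH')


parser = argparse.ArgumentParser()
parser.add_argument('-s', '--size', nargs='+', type=int, default=[640])

print(normalize_size(parser.parse_args(['-s', '1280']).size))
print(normalize_size(parser.parse_args(['-s', '1280', '1280']).size))
print(normalize_size(parser.parse_args([]).size))
```

The first two calls print the same pair, and the last one shows the default of 640 expanded the same way.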

**NOTE**: To simplify the ONNX model (DeepStream >= 6.0)

```
--simplify
```

**NOTE**: To use dynamic batch-size (DeepStream >= 6.1)

```
--dynamic
```

**NOTE**: To use static batch-size (example for batch-size = 4)

```
--batch 4
```

**NOTE**: To change the maximum number of detections (example for max_det = 300)

```
--max_det 300
```

**NOTE**: If you are using DeepStream 5.1, remove the `--dynamic` arg and use opset 12 or lower. The default opset is 16.

```
--opset 12
```
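
The flag constraints in the notes above can be checked before running the export. The following is only a sketch under the rules stated in this section; the function `check_export_flags` and its version-tuple argument are hypothetical and do not exist in the export script:

```python
def check_export_flags(deepstream, dynamic=False, batch=1, opset=16, simplify=False):
    """Return a list of problems with the chosen export flags (empty if none).

    `deepstream` is a version tuple such as (6, 1) or (5, 1).
    """
    problems = []
    if dynamic and batch > 1:
        # The export script rejects this combination as well.
        problems.append('--dynamic and --batch cannot be combined')
    if dynamic and deepstream < (6, 1):
        problems.append('--dynamic needs DeepStream >= 6.1')
    if simplify and deepstream < (6, 0):
        problems.append('--simplify needs DeepStream >= 6.0')
    if deepstream == (5, 1) and opset > 12:
        problems.append('DeepStream 5.1 needs --opset 12 or lower')
    return problems


for problem in check_export_flags((5, 1), dynamic=True):
    print(problem)
```

Returning a list instead of raising on the first problem lets a caller report every conflicting flag at once.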

#### 5. Copy generated files

Copy the generated ONNX model file and labels.txt file (if generated) to the `DeepStream-Yolo` folder.

##

### Compile the lib

1. Open the `DeepStream-Yolo` folder and compile the lib

2. Set the `CUDA_VER` according to your DeepStream version

```
export CUDA_VER=XY.Z
```

* x86 platform

```
DeepStream 7.0 / 6.4 = 12.2
DeepStream 6.3 = 12.1
DeepStream 6.2 = 11.8
DeepStream 6.1.1 = 11.7
DeepStream 6.1 = 11.6
DeepStream 6.0.1 / 6.0 = 11.4
DeepStream 5.1 = 11.1
```

* Jetson platform

```
DeepStream 7.0 / 6.4 = 12.2
DeepStream 6.3 / 6.2 / 6.1.1 / 6.1 = 11.4
DeepStream 6.0.1 / 6.0 / 5.1 = 10.2
```

3. Make the lib

```
make -C nvdsinfer_custom_impl_Yolo clean && make -C nvdsinfer_custom_impl_Yolo
```
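
For scripted builds, the version tables from step 2 can be expressed as a lookup. This is a sketch that only restates those tables; the names `CUDA_VER_TABLE` and `cuda_ver` are hypothetical:

```python
# CUDA version per platform and DeepStream version, copied from the tables above.
CUDA_VER_TABLE = {
    'x86': {
        '7.0': '12.2', '6.4': '12.2', '6.3': '12.1', '6.2': '11.8',
        '6.1.1': '11.7', '6.1': '11.6', '6.0.1': '11.4', '6.0': '11.4',
        '5.1': '11.1',
    },
    'jetson': {
        '7.0': '12.2', '6.4': '12.2', '6.3': '11.4', '6.2': '11.4',
        '6.1.1': '11.4', '6.1': '11.4', '6.0.1': '10.2', '6.0': '10.2',
        '5.1': '10.2',
    },
}


def cuda_ver(platform, deepstream):
    """Look up the CUDA_VER value for a platform and DeepStream version."""
    try:
        return CUDA_VER_TABLE[platform][deepstream]
    except KeyError:
        raise ValueError('unknown platform/version: %s %s' % (platform, deepstream)) from None


# Print the export line for DeepStream 6.3 on x86.
print('export CUDA_VER=' + cuda_ver('x86', '6.3'))
```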

##

### Edit the config_infer_primary_yoloV10 file

Edit the `config_infer_primary_yoloV10.txt` file according to your model (example for YOLOv10s with 80 classes)

```
[property]
...
onnx-file=yolov10s.onnx
...
num-detected-classes=80
...
parse-bbox-func-name=NvDsInferParseYoloE
...
```
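
One plausible sanity check after editing the file is to compare `num-detected-classes` with the number of entries in `labels.txt`. The sketch below assumes one class name per line (the layout written by the export script) and parses an inline sample instead of the real files; the helper name `count_classes` is hypothetical:

```python
import configparser


def count_classes(config_text, labels_text):
    """Return (num-detected-classes from the config, number of label lines)."""
    parser = configparser.ConfigParser(interpolation=None)
    parser.read_string(config_text)
    expected = parser.getint('property', 'num-detected-classes')
    labels = [line for line in labels_text.splitlines() if line.strip()]
    return expected, len(labels)


# Inline sample data for illustration only (three classes instead of 80).
sample_config = """
[property]
onnx-file=yolov10s.onnx
## 0=FP32, 1=INT8, 2=FP16 mode
network-mode=0
num-detected-classes=3
parse-bbox-func-name=NvDsInferParseYoloE
"""
sample_labels = "person\nbicycle\ncar\n"

expected, found = count_classes(sample_config, sample_labels)
print('match' if expected == found else 'mismatch: %d vs %d' % (expected, found))
```

To check real files, read their contents with `open(...).read()` and pass the strings to the same function.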

**NOTE**: **YOLOv10** resizes the input with center padding. For better accuracy, use

```
[property]
...
maintain-aspect-ratio=1
symmetric-padding=1
...
```
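
To illustrate what center padding means, the sketch below scales a frame to fit the network input while keeping its aspect ratio and splits the leftover space evenly between both sides. It is a generic illustration, not DeepStream's actual implementation, and the function name `center_padding` is hypothetical:

```python
def center_padding(src_w, src_h, net_w, net_h):
    """Return the scale and (left, top, right, bottom) padding in pixels."""
    # Keep the aspect ratio: use the smaller of the two scale factors.
    scale = min(net_w / src_w, net_h / src_h)
    new_w, new_h = round(src_w * scale), round(src_h * scale)
    pad_x, pad_y = net_w - new_w, net_h - new_h
    # Symmetric padding: half of the leftover space on each side.
    left, top = pad_x // 2, pad_y // 2
    return scale, (left, top, pad_x - left, pad_y - top)


scale, padding = center_padding(1920, 1080, 640, 640)
print(scale, padding)
```

For a 1920x1080 frame and a 640x640 input, the resized image is 640x360, so the 280 leftover rows are split into 140 above and 140 below.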

##

### Edit the deepstream_app_config file

```
...
[primary-gie]
...
config-file=config_infer_primary_yoloV10.txt
```

##

### Testing the model

```
deepstream-app -c deepstream_app_config.txt
```

**NOTE**: The TensorRT engine file may take a very long time to generate (sometimes more than 10 minutes).

**NOTE**: For more information about custom models configuration (`batch-size`, `network-mode`, etc), please check the [`docs/customModels.md`](customModels.md) file.
124 changes: 124 additions & 0 deletions utils/export_yoloV10.py
@@ -0,0 +1,124 @@
import os
import sys
import argparse
import warnings
import onnx
import torch
import torch.nn as nn
from copy import deepcopy
from ultralytics import YOLO
from ultralytics.utils.torch_utils import select_device
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder, v10Detect


class DeepStreamOutput(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Last dimension: four box coordinates, then score, then class index
        boxes = x[:, :, :4]
        scores, classes = x[:, :, 4], x[:, :, 5]
        classes = classes.float()
        return boxes, scores, classes


def suppress_warnings():
    warnings.filterwarnings('ignore', category=torch.jit.TracerWarning)
    warnings.filterwarnings('ignore', category=UserWarning)
    warnings.filterwarnings('ignore', category=DeprecationWarning)


def yolov10_export(weights, device, max_det):
    model = YOLO(weights)
    model = deepcopy(model.model).to(device)
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    model.float()
    model = model.fuse()
    for k, m in model.named_modules():
        if isinstance(m, (Detect, RTDETRDecoder)):
            m.dynamic = False
            m.export = True
            m.format = 'onnx'
        if isinstance(m, v10Detect):
            # Passed in explicitly instead of read from a module-level `args`
            m.max_det = max_det
        elif isinstance(m, C2f):
            m.forward = m.forward_split
    return model


def main(args):
    suppress_warnings()

    print('\nStarting: %s' % args.weights)

    print('Opening YOLOv10 model\n')

    device = select_device('cpu')
    model = yolov10_export(args.weights, device, args.max_det)

    if len(model.names.keys()) > 0:
        print('\nCreating labels.txt file')
        # One class name per line; the context manager closes the file on error too
        with open('labels.txt', 'w') as f:
            for name in model.names.values():
                f.write(name + '\n')

    model = nn.Sequential(model, DeepStreamOutput())

    # A single size value means a square input: [640] -> [640, 640]
    img_size = args.size * 2 if len(args.size) == 1 else args.size

    onnx_input_im = torch.zeros(args.batch, 3, *img_size).to(device)
    onnx_output_file = os.path.basename(args.weights).split('.pt')[0] + '.onnx'

    # Only the batch axis is dynamic, and only when --dynamic is set
    dynamic_axes = {
        'input': {
            0: 'batch'
        },
        'boxes': {
            0: 'batch'
        },
        'scores': {
            0: 'batch'
        },
        'classes': {
            0: 'batch'
        }
    }

    print('\nExporting the model to ONNX')
    torch.onnx.export(model, onnx_input_im, onnx_output_file, verbose=False, opset_version=args.opset,
                      do_constant_folding=True, input_names=['input'], output_names=['boxes', 'scores', 'classes'],
                      dynamic_axes=dynamic_axes if args.dynamic else None)

    if args.simplify:
        print('Simplifying the ONNX model')
        import onnxsim
        model_onnx = onnx.load(onnx_output_file)
        model_onnx, _ = onnxsim.simplify(model_onnx)
        onnx.save(model_onnx, onnx_output_file)

    print('Done: %s\n' % onnx_output_file)


def parse_args():
    parser = argparse.ArgumentParser(description='DeepStream YOLOv10 conversion')
    parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')
    parser.add_argument('-s', '--size', nargs='+', type=int, default=[640], help='Inference size [H,W] (default [640])')
    parser.add_argument('--opset', type=int, default=16, help='ONNX opset version')
    parser.add_argument('--simplify', action='store_true', help='ONNX simplify model')
    parser.add_argument('--dynamic', action='store_true', help='Dynamic batch-size')
    parser.add_argument('--batch', type=int, default=1, help='Static batch-size')
    parser.add_argument('--max_det', type=int, default=300, help='Max detections per image')
    args = parser.parse_args()
    if not os.path.isfile(args.weights):
        raise SystemExit('Invalid weights file')
    if args.dynamic and args.batch > 1:
        raise SystemExit('Cannot set dynamic batch-size and static batch-size at same time')
    return args


if __name__ == '__main__':
    args = parse_args()
    sys.exit(main(args))