Convert PyTorch models to TensorRT via ONNX
This repository provides a tool for converting PyTorch models to TensorRT engines via the ONNX intermediate representation, enabling optimized inference with the TensorRT runtime.
The conversion process is controlled by a configuration file (`config.yaml`). Below is an example configuration:
```yaml
net_model: AttentionUnet3D
torch_model_path: ./work/best.pth
onnx_model_path: ./work/best.onnx
trt_model_path: ./work/best.trt
task: onnx2trt # torch2onnx or onnx2trt or torch2trt
use_verify: True
use_fp16: False
in_ch: 1
num_classes: 1
input_shape: !!python/tuple [64, 288, 288]
```
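Note that the `!!python/tuple` tag is PyYAML-specific and is rejected by `yaml.safe_load`; a loader that constructs Python objects is needed. Below is a minimal sketch of reading such a config, assuming PyYAML (how `converter.py` actually loads it may differ):

```python
import yaml

# The !!python/tuple tag requires a loader that builds Python objects;
# yaml.safe_load would raise a ConstructorError on it.
with open("config.yaml", "r") as f:
    cfg = yaml.load(f, Loader=yaml.UnsafeLoader)

print(cfg["net_model"])    # AttentionUnet3D
print(cfg["input_shape"])  # (64, 288, 288), constructed as a tuple
```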
The following environment is required:

- Python 3.10.xx
- PyTorch: `pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116`
- TensorRT==8.5.3.1
- ONNX==1.16.1
- ONNX Runtime==1.18.0
- PyCUDA==2022.1+cuda116
- NumPy==1.23.5
You must also install the dependencies provided in the `whl` folder.
Clone the repository to your local machine:
```bash
git clone https://github.com/yourusername/ONNX-TensorRT.git
cd ONNX-TensorRT
```

Prepare the `config.yaml` file with the necessary configurations, following the example shown above.
The `task` field selects the conversion: `torch2onnx`, `onnx2trt`, or `torch2trt`.
To convert a PyTorch model to ONNX format, use the `converter.py` script. Make sure your `config.yaml` is set up correctly and that `task` is set to `torch2onnx`.
Run the script:
```bash
python converter.py -c path/to/config.yaml
```
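Under the hood, a `torch2onnx` step of this kind usually reduces to a `torch.onnx.export` call. The sketch below uses a stand-in network and illustrative paths and export arguments rather than the repository's actual `AttentionUnet3D` and checkpoint handling:

```python
import torch
import torch.nn as nn

# Stand-in network: in the real tool this would be AttentionUnet3D built from
# the net_model / in_ch / num_classes fields of config.yaml, with weights
# loaded from torch_model_path.
model = nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1)
model.eval()

# Dummy input laid out as (N, C, D, H, W), matching input_shape (64, 288, 288).
dummy = torch.randn(1, 1, 64, 288, 288)

# Export to ONNX; the exact arguments (opset, dynamic axes) used by
# converter.py may differ.
torch.onnx.export(
    model,
    dummy,
    "best.onnx",
    opset_version=13,
    input_names=["input"],
    output_names=["output"],
)
```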
If you have already converted your model to ONNX and want to convert it to TensorRT, set `task` in `config.yaml` to `onnx2trt` and run the `converter.py` script again:

```bash
python converter.py -c path/to/config.yaml
```

If `use_verify` is set to `True` in your configuration file, the conversion process includes a verification step, which checks that the outputs of the ONNX and TensorRT models are within a specified tolerance of the original PyTorch model's outputs.
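For reference, this is the kind of numeric check such verification typically performs, shown here for the PyTorch-vs-ONNX Runtime half with a stand-in model and illustrative tolerances (the repository's own thresholds, and the comparison against the TensorRT engine, are handled inside `converter.py`):

```python
import numpy as np
import onnxruntime as ort
import torch
import torch.nn as nn

# Stand-in model exported to ONNX; the real check would compare the actual
# AttentionUnet3D checkpoint against the files under ./work.
model = nn.Conv3d(1, 1, kernel_size=3, padding=1).eval()
torch.onnx.export(model, torch.randn(1, 1, 64, 288, 288), "verify_demo.onnx",
                  input_names=["input"], output_names=["output"])

# Run the same input through PyTorch and ONNX Runtime.
x = torch.randn(1, 1, 64, 288, 288)
with torch.no_grad():
    torch_out = model(x).numpy()

sess = ort.InferenceSession("verify_demo.onnx", providers=["CPUExecutionProvider"])
onnx_out = sess.run(None, {"input": x.numpy()})[0]

# Compare within a tolerance; the tolerances here are illustrative only.
print("max abs diff:", np.abs(torch_out - onnx_out).max())
assert np.allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)
```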