Status: Closed
Labels: bug (Something isn't working), duplicate (This issue or pull request already exists), wontfix (This will not be worked on)
Description
Search before asking
Bug Component
No response
Describe the Bug
On Ubuntu 20.04.6 with CUDA 11.8, cuDNN 8.9, TensorRT 8.6.1.6, OpenCV 4.11.0, and FFmpeg 7.1.1, building with "cmake --build build -j$(nproc) --config Release --target install" fails with an ambiguous implicit type conversion error near line 297 of "TensorRT-YOLO-main/modules/plugin/efficientRotatedNMSPlugin/efficientRotatedNMSInference.cuh". The error output is:
/home/heqingchun/heqingchun/YOLO/test/TensorRT-YOLO-main/modules/plugin/efficientRotatedNMSPlugin/efficientRotatedNMSInference.cuh(297): error: more than one conversion function from "__half" to a built-in type applies:
function "__half::operator float() const"
/usr/local/cuda-11.8/targets/x86_64-linux/include/cuda_fp16.hpp(204): here
function "__half::operator short() const"
/usr/local/cuda-11.8/targets/x86_64-linux/include/cuda_fp16.hpp(222): here
function "__half::operator unsigned short() const"
/usr/local/cuda-11.8/targets/x86_64-linux/include/cuda_fp16.hpp(225): here
function "__half::operator int() const"
/usr/local/cuda-11.8/targets/x86_64-linux/include/cuda_fp16.hpp(228): here
function "__half::operator unsigned int() const"
/usr/local/cuda-11.8/targets/x86_64-linux/include/cuda_fp16.hpp(231): here
function "__half::operator long long() const"
/usr/local/cuda-11.8/targets/x86_64-linux/include/cuda_fp16.hpp(234): here
function "__half::operator unsigned long long() const"
/usr/local/cuda-11.8/targets/x86_64-linux/include/cuda_fp16.hpp(237): here
function "__half::operator __nv_bool() const"
/usr/local/cuda-11.8/targets/x86_64-linux/include/cuda_fp16.hpp(241): here
detected during:
instantiation of "float RotatedBoxCenterSize<T>::probiou(RotatedBoxCenterSize<T> &, RotatedBoxCenterSize<T> &) [with T=__half]"
(242): here
instantiation of "float RotatedBoxCorner<T>::probiou(RotatedBoxCorner<T>, RotatedBoxCorner<T>) [with T=__half]"
/home/heqingchun/heqingchun/YOLO/test/TensorRT-YOLO-main/modules/plugin/efficientRotatedNMSPlugin/efficientRotatedNMSInference.cu(39): here
instantiation of "float IOU(nvinfer1::plugin::EfficientRotatedNMSParameters, RotatedBoxCorner<T>, RotatedBoxCorner<T>) [with T=__half]"
/home/heqingchun/heqingchun/YOLO/test/TensorRT-YOLO-main/modules/plugin/efficientRotatedNMSPlugin/efficientRotatedNMSInference.cu(283): here
instantiation of "void EfficientRotatedNMS(nvinfer1::plugin::EfficientRotatedNMSParameters, const int *, int *, int *, const int *, const T *, const int *, const int *, const Tb *, const Tb *, int *, T *, int *, RotatedBoxCorner<T> *) [with T=__half, Tb=RotatedBoxCorner<__half>]"
/home/heqingchun/heqingchun/YOLO/test/TensorRT-YOLO-main/modules/plugin/efficientRotatedNMSPlugin/efficientRotatedNMSInference.cu(316): here
instantiation of "cudaError_t EfficientRotatedNMSLauncher(nvinfer1::plugin::EfficientRotatedNMSParameters &, int *, int *, int *, int *, T *, int *, int *, const void *, const void *, int *, T *, int *, void *, cudaStream_t) [with T=__half]"
/home/heqingchun/heqingchun/YOLO/test/TensorRT-YOLO-main/modules/plugin/efficientRotatedNMSPlugin/efficientRotatedNMSInference.cu(637): here
instantiation of "pluginStatus_t EfficientRotatedNMSDispatch<T>(nvinfer1::plugin::EfficientRotatedNMSParameters, const void *, const void *, const void *, void *, void *, void *, void *, void *, cudaStream_t) [with T=__half]"
/home/heqingchun/heqingchun/YOLO/test/TensorRT-YOLO-main/modules/plugin/efficientRotatedNMSPlugin/efficientRotatedNMSInference.cu(659): here
4 errors detected in the compilation of "/home/heqingchun/heqingchun/YOLO/test/TensorRT-YOLO-main/modules/plugin/efficientRotatedNMSPlugin/efficientRotatedNMSInference.cu".
make[2]: *** [modules/plugin/CMakeFiles/custom_plugins.dir/build.make:151:modules/plugin/CMakeFiles/custom_plugins.dir/efficientRotatedNMSPlugin/efficientRotatedNMSInference.cu.o] Error 1
make[2]: *** Waiting for unfinished jobs....
make[1]: *** [CMakeFiles/Makefile2:196:modules/plugin/CMakeFiles/custom_plugins.dir/all] Error 2
make: *** [Makefile:136:all] Error 2
Proposed fix:
1. In efficientRotatedNMSInference.cuh, add the following helper functions directly below "#include <cuda_fp16.h>" at the top of the file:
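// Helper functions added for the fix
__device__ inline float to_float(float val) {
    return val;  // float needs no conversion
}
__device__ inline float to_float(__half val) {
    return __half2float(val);  // explicit, unambiguous __half -> float conversion
}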
2. Change the lines near line 297 that trigger the ambiguity:
Before:
float sub_x1_x2 = a.x - b.x;
float sub_y1_y2 = a.y - b.y;
After:
float sub_x1_x2 = to_float(a.x) - to_float(b.x);
float sub_y1_y2 = to_float(a.y) - to_float(b.y);
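To see the patched pattern in isolation, here is a minimal host-only C++ sketch (no CUDA required). `FakeHalf` is a hypothetical stand-in for `__half` that, like the real type, exposes several implicit conversion operators; `Box<T>`, `center_dx`, and `center_dy` are hypothetical simplifications of the plugin's templated box structs. The `FakeHalf` overload of `to_float` plays the role of the `__half2float` overload above.

```cpp
#include <cassert>

// Hypothetical stand-in for CUDA's __half (NOT the real type): it stores a
// float and, like __half, offers several implicit conversions to built-in types.
struct FakeHalf {
    float v;
    operator float() const { return v; }
    operator int() const { return static_cast<int>(v); }
    operator bool() const { return v != 0.0f; }
};

// Overload set mirroring the proposed helpers; the FakeHalf version stands in
// for the __device__ overload that calls __half2float.
inline float to_float(float val) { return val; }
inline float to_float(FakeHalf val) { return val.v; }

// Hypothetical simplified box; the plugin's real structs hold more fields.
template <typename T>
struct Box {
    T x;
    T y;
};

// Patched pattern: convert each operand explicitly, then subtract in float.
// Writing `return a.x - b.x;` instead would not compile for T = FakeHalf,
// because the float and int conversions are equally good candidates.
template <typename T>
float center_dx(const Box<T>& a, const Box<T>& b) {
    return to_float(a.x) - to_float(b.x);
}

template <typename T>
float center_dy(const Box<T>& a, const Box<T>& b) {
    return to_float(a.y) - to_float(b.y);
}
```

For `T = FakeHalf`, `to_float(FakeHalf)` is an exact match and beats reaching `to_float(float)` through a conversion operator; for `T = float`, only `to_float(float)` is viable. Either way overload resolution never has to pick among the built-in conversions.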
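Note:
This ensures that:
1. When T is float, the subtraction is done directly in float.
2. When T is __half, each operand is first converted to float and then subtracted.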
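Why does this fix the problem?
1. It removes the ambiguity: instead of leaving the compiler to choose among __half's implicit conversion operators (float, short, int, long long, bool, and so on), the __half overload converts explicitly with __half2float.
2. Function overloading keeps the code generic over both float and __half.
3. The conversion path is stated explicitly, so the compiler no longer tries the various built-in conversions.
4. The mathematical logic of the original algorithm is unchanged.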
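The ambiguity itself can be checked at compile time with a small C++11 detection trait. This is a host-only sketch, assuming (as the error log suggests) that no `operator-` for `__half` is usable where the error occurs; `FakeHalf` and `has_minus` are hypothetical names, not part of CUDA or the plugin.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical __half stand-in (NOT the real type): several implicit
// conversions to built-in types and no operator- of its own.
struct FakeHalf {
    float v;
    operator float() const { return v; }
    operator int() const { return static_cast<int>(v); }
    operator bool() const { return v != 0.0f; }
};

inline float to_float(float val) { return val; }
inline float to_float(FakeHalf val) { return val.v; }

// Detects whether `a - b` is well-formed for two T operands. An ambiguous
// overload resolution inside decltype is a substitution failure, so the
// primary template (false) is chosen instead of a hard compile error.
template <typename T, typename = void>
struct has_minus : std::false_type {};

template <typename T>
struct has_minus<T, decltype(void(std::declval<T>() - std::declval<T>()))>
    : std::true_type {};

// float has a built-in operator-.
static_assert(has_minus<float>::value, "float - float is well-formed");
// Built-in operator-(float, float) and operator-(int, int) are each reachable
// through a different conversion function, so neither is better: ambiguous.
static_assert(!has_minus<FakeHalf>::value, "FakeHalf - FakeHalf is ambiguous");
// The explicit route always yields float.
static_assert(std::is_same<decltype(to_float(std::declval<FakeHalf>())), float>::value,
              "to_float returns float");
```

This mirrors the compiler's complaint that "more than one conversion function from "__half" to a built-in type applies": every built-in `operator-` candidate needs a user-defined conversion, and conversions through different conversion functions cannot be ranked against each other.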
Environment
- OS: Ubuntu 20.04
- CUDA: 11.8
- cuDNN: 8.9
- TensorRT: 8.6.1.6
- OpenCV: 4.11.0
- FFmpeg: 7.1.1
Bug description confirmation
- I confirm that the bug reproduction steps, code change instructions, and environment information have been provided, and that the problem can be reproduced.