diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index 8a58746c7..3c8f0ee47 100644 --- a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -60,7 +60,7 @@ void TensorrtConfig::Register(ParseOptions *po) { bool TensorrtConfig::Validate() const { if (trt_max_workspace_size < 0) { - SHERPA_ONNX_LOGE("trt_max_workspace_size: %d is not valid.", + SHERPA_ONNX_LOGE("trt_max_workspace_size: %lld is not valid.", trt_max_workspace_size); return false; } diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h index ff9607909..fdc875e0a 100644 --- a/sherpa-onnx/csrc/provider-config.h +++ b/sherpa-onnx/csrc/provider-config.h @@ -27,7 +27,7 @@ struct CudaConfig { }; struct TensorrtConfig { - int32_t trt_max_workspace_size = 2147483647; + int64_t trt_max_workspace_size = 2147483647; int32_t trt_max_partition_iterations = 10; int32_t trt_min_subgraph_size = 5; bool trt_fp16_enable = true; @@ -39,7 +39,7 @@ struct TensorrtConfig { bool trt_dump_subgraphs = false; TensorrtConfig() = default; - TensorrtConfig(int32_t trt_max_workspace_size, + TensorrtConfig(int64_t trt_max_workspace_size, int32_t trt_max_partition_iterations, int32_t trt_min_subgraph_size, bool trt_fp16_enable, diff --git a/sherpa-onnx/python/csrc/tensorrt-config.cc b/sherpa-onnx/python/csrc/tensorrt-config.cc index 87962a2d3..ae48a945b 100644 --- a/sherpa-onnx/python/csrc/tensorrt-config.cc +++ b/sherpa-onnx/python/csrc/tensorrt-config.cc @@ -14,7 +14,7 @@ void PybindTensorrtConfig(py::module *m) { using PyClass = TensorrtConfig; py::class_(*m, "TensorrtConfig") .def(py::init<>()) - .def(py::init([](int32_t trt_max_workspace_size, + .def(py::init([](int64_t trt_max_workspace_size, int32_t trt_max_partition_iterations, int32_t trt_min_subgraph_size, bool trt_fp16_enable,