From 3c1879d8abd851e659aefcd9290db0f06a80757f Mon Sep 17 00:00:00 2001
From: "manickavela1998@gmail.com"
Date: Thu, 6 Jun 2024 07:05:49 +0000
Subject: [PATCH 01/34] onnxrt ep configs

Signed-off-by: manickavela1998@gmail.com
---
 .../csrc/onnxrt-execution-provider-config.cc | 169 ++++++++++++++++++
 .../csrc/onnxrt-execution-provider-config.h | 88 +++++++++
 2 files changed, 257 insertions(+)
 create mode 100644 sherpa-onnx/csrc/onnxrt-execution-provider-config.cc
 create mode 100644 sherpa-onnx/csrc/onnxrt-execution-provider-config.h

diff --git a/sherpa-onnx/csrc/onnxrt-execution-provider-config.cc b/sherpa-onnx/csrc/onnxrt-execution-provider-config.cc
new file mode 100644
index 000000000..65cbcfe05
--- /dev/null
+++ b/sherpa-onnx/csrc/onnxrt-execution-provider-config.cc
@@ -0,0 +1,169 @@
+// sherpa-onnx/csrc/online-transducer-model-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#include "sherpa-onnx/csrc/onnxrt-execution-provider-config.h"
+
+#include <sstream>
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OnnxrtCudaConfig::Register(ParseOptions *po) {
+ po->Register("cuda-device", &device,
+ "Onnxruntime CUDA device index."
+ "Set based on available CUDA device");
+ po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search,
+ "CuDNN convolution algorithm search");
+}
+
+bool OnnxrtCudaConfig::Validate() const {
+
+ if(device > 0) {
+ SHERPA_ONNX_LOGE("device: '%d' is not valid.", device);
+ return false;
+ }
+
+ if(cudnn_conv_algo_search > 0 && cudnn_conv_algo_search < 4) {
+ SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not valid option."
+ "Options : [1,3]. Check OnnxRT docs",
+ cudnn_conv_algo_search);
+ return false;
+ }
+
+ return true;
+}
+
+std::string OnnxrtCudaConfig::ToString() const {
+ std::ostringstream os;
+
+ os << "OnnxrtCudaConfig(";
+ os << "device=\"" << device << "\", ";
+ os << "cudnn_conv_algo_search=\"" << cudnn_conv_algo_search << ")";
+
+ return os.str();
+}
+
+void OnnxrtTensorrtConfig::Register(ParseOptions *po) {
+ po->Register("device", &device,
+ "Onnxruntime CUDA device index."
+ "Set based on available CUDA device");
+ po->Register("trt-max-workspace-size",&trt_max_workspace_size,
+ "");
+ po->Register("trt-max-partition-iterations",&trt_max_partition_iterations,
+ "");
+ po->Register("trt-min-subgraph-size",&trt_min_subgraph_size,
+ "");
+ po->Register("trt-fp16-enable",&trt_fp16_enable,
+ "");
+ po->Register("trt-detailed-build-log",&trt_detailed_build_log,
+ "");
+ po->Register("trt-engine-cache-enable",&trt_engine_cache_enable,
+ "");
+ po->Register("trt-engine-cache-path",&trt_engine_cache_path,
+ "");
+ po->Register("trt-timing-cache-enable",&trt_timing_cache_enable,
+ "");
+ po->Register("trt-timing-cache-path",&trt_timing_cache_path,
+ "");
+}
+
+bool OnnxrtTensorrtConfig::Validate() const {
+
+ if (trt_max_workspace_size > 0) {
+ SHERPA_ONNX_LOGE("trt_max_workspace_size: '%d' is not valid.",
+ trt_max_workspace_size);
+ return false;
+ }
+ if (trt_max_partition_iterations > 0) {
+ SHERPA_ONNX_LOGE("trt_max_partition_iterations: '%d' is not valid.",
+ trt_max_partition_iterations);
+ return false;
+ }
+ if (trt_min_subgraph_size > 0) {
+ SHERPA_ONNX_LOGE("trt_min_subgraph_size: '%d' is not valid.",
+ trt_min_subgraph_size);
+ return false;
+ }
+ if (trt_fp16_enable < 0 || trt_fp16_enable > 1) {
+ SHERPA_ONNX_LOGE("trt_fp16_enable: '%d' is not valid.",trt_fp16_enable);
+ return false;
+ }
+ if (trt_detailed_build_log < 0 || trt_detailed_build_log > 1) {
+ SHERPA_ONNX_LOGE("trt_detailed_build_log: '%d' is not valid.",
+ trt_detailed_build_log);
+ return false;
+ }
+ if (trt_engine_cache_enable < 0 || trt_engine_cache_enable > 1) {
+ SHERPA_ONNX_LOGE("trt_engine_cache_enable: '%d' is not valid.",
+ trt_engine_cache_enable);
+ return false;
+ }
+ if (trt_timing_cache_enable < 0 || trt_timing_cache_enable > 1) {
+ SHERPA_ONNX_LOGE("trt_timing_cache_enable: '%d' is not valid.",
+ trt_timing_cache_enable);
+ return false;
+ }
+
+ if(trt_max_workspace_size > 0) {
+ SHERPA_ONNX_LOGE("trt_max_workspace_size: '%d' is not valid.",device);
+ return false;
+ }
+
+ return true;
+}
+
+std::string OnnxrtTensorrtConfig::ToString() const {
+ std::ostringstream os;
+
+ os << "OnnxrtTensorrtConfig(";
+ os << "device=\"" << device << "\", ";
+ os << "trt_max_workspace_size=\"" << trt_max_workspace_size << "\", ";
+ os << "trt_max_partition_iterations=\"" << trt_max_partition_iterations << "\", ";
+ os << "trt_min_subgraph_size=\"" << trt_min_subgraph_size << "\", ";
+ os << "trt_fp16_enable=\"" << trt_fp16_enable << "\", ";
+ os << "trt_detailed_build_log=\"" << trt_detailed_build_log << "\", ";
+ os << "trt_engine_cache_enable=\"" << trt_engine_cache_enable << "\", ";
+ os << "trt_engine_cache_path=\"" << trt_engine_cache_path.c_str() << "\", ";
+ os << "trt_timing_cache_enable=\"" << trt_timing_cache_enable << "\", ";
+ os << "trt_timing_cache_path=\"" << trt_timing_cache_path.c_str() << ")";
+
+ return os.str();
+}
+
+void OnnxrtExecutionProviderConfig::Register(ParseOptions *po) {
+ po->Register("device", &device,
+ "Onnxruntime CUDA device index."
+ "Set based on available CUDA device");
+ po->Register("cudnn_conv_algo_search", &cudnn_conv_algo_search, "CuDNN convolution algorithm search");
+}
+
+bool OnnxrtExecutionProviderConfig::Validate() const {
+
+ if(device > 0) {
+ SHERPA_ONNX_LOGE("device: '%d' is not valid.", device);
+ return false;
+ }
+
+ if(cudnn_conv_algo_search > 0 && cudnn_conv_algo_search < 4) {
+ SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not valid option."
+ "Options : [1,3]. Check OnnxRT docs",
+ cudnn_conv_algo_search);
+ return false;
+ }
+
+ return true;
+}
+
+std::string OnnxrtExecutionProviderConfig::ToString() const {
+ std::ostringstream os;
+
+ os << "OnnxrtCudaConfig(";
+ os << "device=\"" << device << "\", ";
+ os << "cudnn_conv_algo_search=\"" << cudnn_conv_algo_search << ")";
+
+ return os.str();
+}
+
+} // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/onnxrt-execution-provider-config.h b/sherpa-onnx/csrc/onnxrt-execution-provider-config.h
new file mode 100644
index 000000000..3849946b5
--- /dev/null
+++ b/sherpa-onnx/csrc/onnxrt-execution-provider-config.h
@@ -0,0 +1,88 @@
+// sherpa-onnx/csrc/online-transducer-model-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONNXRT_EXECUTION_PROVIDER_CONFIG_H_
+#define SHERPA_ONNX_CSRC_ONNXRT_EXECUTION_PROVIDER_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+#include "onnxruntime_cxx_api.h" // NOLINT
+
+namespace sherpa_onnx {
+
+struct OnnxrtCudaConfig {
+ uint32_t device = 0;
+ uint32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic;
+
+ OnnxrtCudaConfig() = default;
+ OnnxrtCudaConfig(const uint32_t &device,
+ const uint32_t &cudnn_conv_algo_search)
+ : device(device), cudnn_conv_algo_search(cudnn_conv_algo_search) {}
+
+ void Register(ParseOptions *po);
+ bool Validate() const;
+
+ std::string ToString() const;
+};
+
+struct OnnxrtTensorrtConfig {
+ uint32_t device = 0;
+ uint32_t trt_max_workspace_size = 2147483648;
+ uint32_t trt_max_partition_iterations = 10;
+ uint32_t trt_min_subgraph_size = 5;
+ uint32_t trt_fp16_enable = 1;
+ uint32_t trt_detailed_build_log = 0;
+ uint32_t trt_engine_cache_enable = 1;
+ std::string trt_engine_cache_path = ".";
+ uint32_t trt_timing_cache_enable = 1;
+ std::string trt_timing_cache_path = ".";
+
+ OnnxrtTensorrtConfig() = default;
+ OnnxrtTensorrtConfig(const uint32_t &device,
+ const uint32_t &trt_max_workspace_size,
+ const uint32_t &trt_max_partition_iterations,
+ const uint32_t &trt_min_subgraph_size,
+ const uint32_t &trt_fp16_enable,
+ const uint32_t &trt_detailed_build_log,
+ const uint32_t &trt_engine_cache_enable,
+ const std::string &trt_engine_cache_path,
+ const uint32_t &trt_timing_cache_enable,
+ const std::string &trt_timing_cache_path)
+ : device(device), trt_max_workspace_size(trt_max_workspace_size),
+ trt_max_partition_iterations(trt_max_partition_iterations),
+ trt_min_subgraph_size(trt_min_subgraph_size),
+ trt_fp16_enable(trt_fp16_enable),
+ trt_detailed_build_log(trt_detailed_build_log),
+ trt_engine_cache_enable(trt_engine_cache_enable),
+ trt_engine_cache_path(trt_engine_cache_path),
+ trt_timing_cache_enable(trt_timing_cache_enable),
+ trt_timing_cache_path(trt_timing_cache_path) {}
+
+ void Register(ParseOptions *po);
+ bool Validate() const;
+
+ std::string ToString() const;
+};
+
+struct OnnxrtExecutionProviderConfig {
+ std::string provider = "cpu";
+ OnnxrtCudaConfig onnxrtcuda;
+ OnnxrtTensorrtConfig onnxrttrtconfig;
+
+ OnnxrtExecutionProviderConfig() = default;
+ OnnxrtExecutionProviderConfig(const std::string &provider,
+ const OnnxrtCudaConfig &onnxrtcuda,
+ const OnnxrtTensorrtConfig &onnxrttrtconfig)
+ : provider(provider), onnxrtcuda(onnxrtcuda),
+ onnxrttrtconfig(onnxrttrtconfig) {}
+
+ void Register(ParseOptions *po);
+ bool Validate() const;
+
+ std::string ToString() const;
+};
+
+} // namespace sherpa_onnx
+
+#endif // SHERPA_ONNX_CSRC_ONNXRT_EXECUTION_PROVIDER_CONFIG_H_

From e69d49540e5b4b07f54f76f3e5f3d3a33ca99d21 Mon Sep 17 00:00:00 2001
From: 
"manickavela1998@gmail.com" Date: Mon, 24 Jun 2024 16:54:22 +0000 Subject: [PATCH 02/34] updating config Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/c-api/c-api.cc | 4 +- sherpa-onnx/csrc/CMakeLists.txt | 1 + sherpa-onnx/csrc/online-model-config.cc | 10 +- sherpa-onnx/csrc/online-model-config.h | 6 +- sherpa-onnx/csrc/online-recognizer.cc | 1 + sherpa-onnx/csrc/online-websocket-server.cc | 1 - ...-provider-config.cc => provider-config.cc} | 116 +++++++++--------- ...on-provider-config.h => provider-config.h} | 63 +++++----- sherpa-onnx/csrc/provider.h | 1 + sherpa-onnx/csrc/session.cc | 65 +++++++--- 10 files changed, 150 insertions(+), 118 deletions(-) rename sherpa-onnx/csrc/{onnxrt-execution-provider-config.cc => provider-config.cc} (50%) rename sherpa-onnx/csrc/{onnxrt-execution-provider-config.h => provider-config.h} (53%) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index ef4034ec3..5333afbb8 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -73,7 +73,7 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( SHERPA_ONNX_OR(config->model_config.tokens, ""); recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); - recognizer_config.model_config.provider = + recognizer_config.model_config.provider_config.provider = SHERPA_ONNX_OR(config->model_config.provider, "cpu"); recognizer_config.model_config.model_type = SHERPA_ONNX_OR(config->model_config.model_type, ""); @@ -570,7 +570,7 @@ SherpaOnnxKeywordSpotter *CreateKeywordSpotter( SHERPA_ONNX_OR(config->model_config.tokens, ""); spotter_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); - spotter_config.model_config.provider = + spotter_config.model_config.provider_config.provider = SHERPA_ONNX_OR(config->model_config.provider, "cpu"); spotter_config.model_config.model_type = SHERPA_ONNX_OR(config->model_config.model_type, ""); diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 16da143f1..82e07aa24 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -87,6 +87,7 @@ set(sources packed-sequence.cc pad-sequence.cc parse-options.cc + provider-config.cc provider.cc resample.cc session.cc diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index a8efa870d..19d14c3ff 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -16,6 +16,7 @@ void OnlineModelConfig::Register(ParseOptions *po) { wenet_ctc.Register(po); zipformer2_ctc.Register(po); nemo_ctc.Register(po); + provider_config.Register(po); po->Register("tokens", &tokens, "Path to tokens.txt"); @@ -29,9 +30,6 @@ void OnlineModelConfig::Register(ParseOptions *po) { po->Register("debug", &debug, "true to print model information while loading it."); - po->Register("provider", &provider, - "Specify a provider to use: cpu, cuda, coreml"); - po->Register("modeling-unit", &modeling_unit, "The modeling unit of the model, commonly used units are bpe, " "cjkchar, cjkchar+bpe, etc. 
Currently, it is needed only when " @@ -87,6 +85,10 @@ bool OnlineModelConfig::Validate() const { return nemo_ctc.Validate(); } + if (!provider_config.Validate()) { + return provider_config.Validate(); + } + return transducer.Validate(); } @@ -99,11 +101,11 @@ std::string OnlineModelConfig::ToString() const { os << "wenet_ctc=" << wenet_ctc.ToString() << ", "; os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", "; os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; + os << "provider_config=" << provider_config.ToString() << ", "; os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; os << "warm_up=" << warm_up << ", "; os << "debug=" << (debug ? "True" : "False") << ", "; - os << "provider=\"" << provider << "\", "; os << "model_type=\"" << model_type << "\", "; os << "modeling_unit=\"" << modeling_unit << "\", "; os << "bpe_vocab=\"" << bpe_vocab << "\")"; diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h index 1509bd5b0..6c46ce6fd 100644 --- a/sherpa-onnx/csrc/online-model-config.h +++ b/sherpa-onnx/csrc/online-model-config.h @@ -11,6 +11,7 @@ #include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h" #include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h" +#include "sherpa-onnx/csrc/provider-config.h" namespace sherpa_onnx { @@ -20,11 +21,11 @@ struct OnlineModelConfig { OnlineWenetCtcModelConfig wenet_ctc; OnlineZipformer2CtcModelConfig zipformer2_ctc; OnlineNeMoCtcModelConfig nemo_ctc; + ExecutionProviderConfig provider_config; std::string tokens; int32_t num_threads = 1; int32_t warm_up = 0; bool debug = false; - std::string provider = "cpu"; // Valid values: // - conformer, conformer transducer from icefall @@ -50,6 +51,7 @@ struct OnlineModelConfig { const OnlineWenetCtcModelConfig &wenet_ctc, const OnlineZipformer2CtcModelConfig &zipformer2_ctc, const OnlineNeMoCtcModelConfig &nemo_ctc, + const ExecutionProviderConfig &provider_config, const std::string &tokens, int32_t num_threads, int32_t warm_up, bool debug, const std::string &provider, const std::string &model_type, @@ -60,11 +62,11 @@ struct OnlineModelConfig { wenet_ctc(wenet_ctc), zipformer2_ctc(zipformer2_ctc), nemo_ctc(nemo_ctc), + provider_config(provider_config), tokens(tokens), num_threads(num_threads), warm_up(warm_up), debug(debug), - provider(provider), model_type(model_type), modeling_unit(modeling_unit), bpe_vocab(bpe_vocab) {} diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 599a0553d..436abb82d 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -113,6 +113,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { } bool OnlineRecognizerConfig::Validate() const { + SHERPA_ONNX_LOGE("Args recognizer : %s",ToString().c_str()); if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) { if (max_active_paths <= 0) { SHERPA_ONNX_LOGE("max_active_paths is less than 0! 
Given: %d", diff --git a/sherpa-onnx/csrc/online-websocket-server.cc b/sherpa-onnx/csrc/online-websocket-server.cc index 6ba7a1986..2c6602740 100644 --- a/sherpa-onnx/csrc/online-websocket-server.cc +++ b/sherpa-onnx/csrc/online-websocket-server.cc @@ -69,7 +69,6 @@ int32_t main(int32_t argc, char *argv[]) { } config.Validate(); - asio::io_context io_conn; // for network connections asio::io_context io_work; // for neural network and decoding diff --git a/sherpa-onnx/csrc/onnxrt-execution-provider-config.cc b/sherpa-onnx/csrc/provider-config.cc similarity index 50% rename from sherpa-onnx/csrc/onnxrt-execution-provider-config.cc rename to sherpa-onnx/csrc/provider-config.cc index 65cbcfe05..0f0252f40 100644 --- a/sherpa-onnx/csrc/onnxrt-execution-provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -1,7 +1,7 @@ // sherpa-onnx/csrc/online-transducer-model-config.cc // // Copyright (c) 2023 Xiaomi Corporation -#include "sherpa-onnx/csrc/onnxrt-execution-provider-config.h" +#include "sherpa-onnx/csrc/provider-config.h" #include @@ -10,20 +10,12 @@ namespace sherpa_onnx { -void OnnxrtCudaConfig::Register(ParseOptions *po) { - po->Register("cuda-device", &device, - "Onnxruntime CUDA device index." - "Set based on available CUDA device"); +void CudaConfig::Register(ParseOptions *po) { po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search, "CuDNN convolution algrorithm search"); } -bool OnnxrtCudaConfig::Validate() const { - - if(device > 0) { - SHERPA_ONNX_LOGE("device: '%d' is not valid.", device); - return false; - } +bool CudaConfig::Validate() const { if(cudnn_conv_algo_search > 0 && cudnn_conv_algo_search < 4) { SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not valid option." @@ -35,20 +27,16 @@ bool OnnxrtCudaConfig::Validate() const { return true; } -std::string OnnxrtCudaConfig::ToString() const { +std::string CudaConfig::ToString() const { std::ostringstream os; - os << "OnnxrtCudaConfig("; - os << "device=\"" << device << "\", "; + os << "CudaConfig("; os << "cudnn_conv_algo_search=\"" << cudnn_conv_algo_search << ")"; return os.str(); } -void OnnxrtTensorrtConfig::Register(ParseOptions *po) { - po->Register("device", &device, - "Onnxruntime CUDA device index." 
- "Set based on available CUDA device"); +void TensorrtConfig::Register(ParseOptions *po) { po->Register("trt-max-workspace-size",&trt_max_workspace_size, ""); po->Register("trt-max-partition-iterations",&trt_max_partition_iterations, @@ -56,20 +44,22 @@ void OnnxrtTensorrtConfig::Register(ParseOptions *po) { po->Register("trt-min-subgraph-size ",&trt_min_subgraph_size, ""); po->Register("trt-fp16-enable",&trt_fp16_enable, - ""); + "true to enable fp16"); po->Register("trt-detailed-build-log",&trt_detailed_build_log, - ""); + "true to print TensorRT build logs"); po->Register("trt-engine-cache-enable",&trt_engine_cache_enable, - ""); + "true to enable engine caching"); po->Register("trt-engine-cache-path",&trt_engine_cache_path, ""); po->Register("trt-timing-cache-enable",&trt_timing_cache_enable, - ""); + "true to enable timing cache"); po->Register("trt-timing-cache-path",&trt_timing_cache_path, ""); + po->Register("trt-dump-subgraphs",&trt_dump_subgraphs, + "true to dump subgraphs"); } -bool OnnxrtTensorrtConfig::Validate() const { +bool TensorrtConfig::Validate() const { if (trt_max_workspace_size > 0) { SHERPA_ONNX_LOGE("trt_max_workspace_size: '%d' is not valid.", @@ -86,83 +76,89 @@ bool OnnxrtTensorrtConfig::Validate() const { trt_min_subgraph_size); return false; } - if (trt_fp16_enable < 0 || trt_fp16_enable > 1) { + if (trt_fp16_enable != true || trt_fp16_enable != false) { SHERPA_ONNX_LOGE("trt_fp16_enable: '%d' is not valid.",trt_fp16_enable); return false; } - if (trt_detailed_build_log < 0 || trt_detailed_build_log > 1) { + if (trt_detailed_build_log != true || trt_detailed_build_log != false) { SHERPA_ONNX_LOGE("trt_detailed_build_log: '%d' is not valid.", trt_detailed_build_log); return false; } - if (trt_engine_cache_enable < 0 || trt_engine_cache_enable > 1) { + if (trt_engine_cache_enable != true || trt_engine_cache_enable != false) { SHERPA_ONNX_LOGE("trt_engine_cache_enable: '%d' is not valid.", trt_engine_cache_enable); return false; } - if (trt_timing_cache_enable < 0 || trt_timing_cache_enable > 1) { + if (trt_timing_cache_enable != true || trt_timing_cache_enable != false) { SHERPA_ONNX_LOGE("trt_timing_cache_enable: '%d' is not valid.", trt_timing_cache_enable); return false; } - if(trt_max_workspace_size > 0) { - SHERPA_ONNX_LOGE("trt_max_workspace_size: '%d' is not valid.",device); + if (trt_dump_subgraphs != true || trt_dump_subgraphs != false) { + SHERPA_ONNX_LOGE("trt_dump_subgraphs: '%d' is not valid.", + trt_dump_subgraphs); return false; } + // if(trt_max_workspace_size > 0) { + // SHERPA_ONNX_LOGE("trt_max_workspace_size: '%d' is not valid.",); + // return false; + // } + return true; } -std::string OnnxrtTensorrtConfig::ToString() const { +std::string TensorrtConfig::ToString() const { std::ostringstream os; - os << "OnnxrtTensorrtConfig("; - os << "device=\"" << device << "\", "; + os << "TensorrtConfig("; os << "trt_max_workspace_size=\"" << trt_max_workspace_size << "\", "; - os << "trt_max_partition_iterations=\"" << trt_max_partition_iterations << "\", "; + os << "trt_max_partition_iterations=\"" + << trt_max_partition_iterations << "\", "; os << "trt_min_subgraph_size=\"" << trt_min_subgraph_size << "\", "; - os << "trt_fp16_enable=\"" << trt_fp16_enable << "\", "; - os << "trt_detailed_build_log=\"" << trt_detailed_build_log << "\", "; - os << "trt_engine_cache_enable=\"" << trt_engine_cache_enable << "\", "; - os << "trt_engine_cache_path=\"" << trt_engine_cache_path.c_str() << "\", "; - os << "trt_timing_cache_enable=\"" << 
trt_timing_cache_enable << "\", "; - os << "trt_timing_cache_path=\"" << trt_timing_cache_path.c_str() << ")"; - + os << "trt_fp16_enable=\"" + << (trt_fp16_enable? "True" : "False") << "\", "; + os << "trt_detailed_build_log=\"" + << (trt_detailed_build_log? "True" : "False") << "\", "; + os << "trt_engine_cache_enable=\"" + << (trt_engine_cache_enable? "True" : "False") << "\", "; + os << "trt_engine_cache_path=\"" + << trt_engine_cache_path.c_str() << "\", "; + os << "trt_timing_cache_enable=\"" + << (trt_timing_cache_enable? "True" : "False") << "\", "; + os << "trt_timing_cache_path=\"" + << trt_timing_cache_path.c_str() << "\","; + os << "trt_dump_subgraphs=\"" + << (trt_dump_subgraphs? "True" : "False") << "\" )"; return os.str(); } -void OnnxrtExecutionProviderConfig::Register(ParseOptions *po) { - po->Register("device", &device, - "Onnxruntime CUDA device index." - "Set based on available CUDA device"); - po->Register("cudnn_conv_algo_search", &cudnn_conv_algo_search, "CuDNN convolution algrorithm search"); +void ExecutionProviderConfig::Register(ParseOptions *po) { + po->Register("device_id", &device_id, "GPU device_id for CUDA and Trt EP"); + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); } -bool OnnxrtExecutionProviderConfig::Validate() const { +bool ExecutionProviderConfig::Validate() const { - if(device > 0) { - SHERPA_ONNX_LOGE("device: '%d' is not valid.", device); - return false; - } - - if(cudnn_conv_algo_search > 0 && cudnn_conv_algo_search < 4) { - SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not valid option." - "Options : [1,3]. Check OnnxRT docs", - cudnn_conv_algo_search); + if(device_id < 0) { + SHERPA_ONNX_LOGE("device_id: '%d' is invalid.",device_id); return false; } return true; } -std::string OnnxrtExecutionProviderConfig::ToString() const { +std::string ExecutionProviderConfig::ToString() const { std::ostringstream os; - os << "OnnxrtCudaConfig("; - os << "device=\"" << device << "\", "; - os << "cudnn_conv_algo_search=\"" << cudnn_conv_algo_search << ")"; - + os << "ExecutionProviderConfig("; + os << "device_id=\"" << device_id << "\", "; + os << "provider=\"" << provider << "\", "; + os << "cuda_config=\"" << cuda_config.ToString() << "\", "; + os << "trt_config=\"" << trt_config.ToString() << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/onnxrt-execution-provider-config.h b/sherpa-onnx/csrc/provider-config.h similarity index 53% rename from sherpa-onnx/csrc/onnxrt-execution-provider-config.h rename to sherpa-onnx/csrc/provider-config.h index 3849946b5..77ef49e56 100644 --- a/sherpa-onnx/csrc/onnxrt-execution-provider-config.h +++ b/sherpa-onnx/csrc/provider-config.h @@ -7,18 +7,17 @@ #include #include "sherpa-onnx/csrc/parse-options.h" +#include "sherpa-onnx/csrc/macros.h" #include "onnxruntime_cxx_api.h" // NOLINT namespace sherpa_onnx { -struct OnnxrtCudaConfig { - uint32_t device = 0; - uint32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; +struct CudaConfig { + int32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; - OnnxrtCudaConfig() = default; - OnnxrtCudaConfig(const uint32_t &device, - const uint32_t &cudnn_conv_algo_search) - : device(device), cudnn_conv_algo_search(cudnn_conv_algo_search) {} + CudaConfig() = default; + CudaConfig(const uint32_t &cudnn_conv_algo_search) + : cudnn_conv_algo_search(cudnn_conv_algo_search) {} void Register(ParseOptions *po); bool Validate() const; @@ -26,30 +25,29 @@ struct OnnxrtCudaConfig { std::string ToString() const; }; -struct 
OnnxrtTensorrtConfig { - uint32_t device = 0; +struct TensorrtConfig { uint32_t trt_max_workspace_size = 2147483648; uint32_t trt_max_partition_iterations = 10; uint32_t trt_min_subgraph_size = 5; - uint32_t trt_fp16_enable = 1; - uint32_t trt_detailed_build_log = 0; - uint32_t trt_engine_cache_enable = 1; + bool trt_fp16_enable = 1; + bool trt_detailed_build_log = 0; + bool trt_engine_cache_enable = 1; std::string trt_engine_cache_path = "."; - uint32_t trt_timing_cache_enable = 1; + bool trt_timing_cache_enable = 1; std::string trt_timing_cache_path = "."; + bool trt_dump_subgraphs = 0; - OnnxrtTensorrtConfig() = default; - OnnxrtTensorrtConfig(const uint32_t &device, - const uint32_t &trt_max_workspace_size, + TensorrtConfig() = default; + TensorrtConfig(const uint32_t &trt_max_workspace_size, const uint32_t &trt_max_partition_iterations, const uint32_t &trt_min_subgraph_size, - const uint32_t &trt_fp16_enable, - const uint32_t &trt_detailed_build_log, - const uint32_t &trt_engine_cache_enable, + const bool &trt_fp16_enable, + const bool &trt_detailed_build_log, + const bool &trt_engine_cache_enable, const std::string &trt_engine_cache_path, - const uint32_t &trt_timing_cache_enable, + const bool &trt_timing_cache_enable, const std::string &trt_timing_cache_path) - : device(device), trt_max_workspace_size(trt_max_workspace_size), + : trt_max_workspace_size(trt_max_workspace_size), trt_max_partition_iterations(trt_max_partition_iterations), trt_min_subgraph_size(trt_min_subgraph_size), trt_fp16_enable(trt_fp16_enable), @@ -65,17 +63,20 @@ struct OnnxrtTensorrtConfig { std::string ToString() const; }; -struct OnnxrtExecutionProviderConfig { +struct ExecutionProviderConfig { + uint32_t device_id = 0; + // device_id only used for cuda and trt std::string provider = "cpu"; - OnnxrtCudaConfig onnxrtcuda; - OnnxrtTensorrtConfig onnxrttrtconfig; - - OnnxrtExecutionProviderConfig() = default; - OnnxrtExecutionProviderConfig(const std::string &provider, - const OnnxrtCudaConfig &onnxrtcuda, - const OnnxrtTensorrtConfig &onnxrttrtconfig) - : provider(provider), onnxrtcuda(onnxrtcuda), - onnxrttrtconfig(onnxrttrtconfig) {} + CudaConfig cuda_config; + TensorrtConfig trt_config; + + ExecutionProviderConfig() = default; + ExecutionProviderConfig(const uint32_t &device_id, + const std::string &provider, + CudaConfig &cuda_config, + TensorrtConfig &trt_config) + : device_id(device_id), provider(provider), cuda_config(cuda_config), + trt_config(trt_config) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/provider.h b/sherpa-onnx/csrc/provider.h index c104d401a..712006f2b 100644 --- a/sherpa-onnx/csrc/provider.h +++ b/sherpa-onnx/csrc/provider.h @@ -7,6 +7,7 @@ #include +#include "sherpa-onnx/csrc/provider-config.h" namespace sherpa_onnx { // Please refer to diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 0f6ed89db..276eee76b 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -32,11 +32,13 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) { } static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, - std::string provider_str) { + std::string provider_str, + const ExecutionProviderConfig *provider_config=nullptr) { Provider p = StringToProvider(std::move(provider_str)); Ort::SessionOptions sess_opts; sess_opts.SetIntraOpNumThreads(num_threads); + sess_opts.SetInterOpNumThreads(num_threads); std::vector available_providers = Ort::GetAvailableProviders(); @@ -64,26 +66,44 @@ static 
Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, break; } case Provider::kTRT: { + if(provider_config == nullptr) { + SHERPA_ONNX_LOGE("Tensorrt support for Online models ony," + "Must be extended for offline and others"); + exit(1); + } + auto trt_config = provider_config->trt_config; struct TrtPairs { const char *op_keys; const char *op_values; }; std::vector trt_options = { - {"device_id", "0"}, - {"trt_max_workspace_size", "2147483648"}, - {"trt_max_partition_iterations", "10"}, - {"trt_min_subgraph_size", "5"}, - {"trt_fp16_enable", "0"}, - {"trt_detailed_build_log", "0"}, - {"trt_engine_cache_enable", "1"}, - {"trt_engine_cache_path", "."}, - {"trt_timing_cache_enable", "1"}, - {"trt_timing_cache_path", "."}}; + {"device_id", + std::to_string(provider_config->device_id).c_str()}, + {"trt_max_workspace_size", + std::to_string(trt_config.trt_max_workspace_size).c_str()}, + {"trt_max_partition_iterations", + std::to_string(trt_config.trt_max_partition_iterations).c_str()}, + {"trt_min_subgraph_size", + std::to_string(trt_config.trt_min_subgraph_size).c_str()}, + {"trt_fp16_enable", + std::to_string(trt_config.trt_fp16_enable).c_str()}, + {"trt_detailed_build_log", + std::to_string(trt_config.trt_detailed_build_log).c_str()}, + {"trt_engine_cache_enable", + std::to_string(trt_config.trt_engine_cache_enable).c_str()}, + {"trt_engine_cache_path", + trt_config.trt_engine_cache_path.c_str()}, + {"trt_timing_cache_enable", + std::to_string(trt_config.trt_timing_cache_enable).c_str()}, + {"trt_timing_cache_path", + trt_config.trt_timing_cache_path.c_str()}, + {"trt_dump_subgraphs", + std::to_string(trt_config.trt_dump_subgraphs).c_str()} + }; // ToDo : Trt configs // "trt_int8_enable" // "trt_int8_use_native_calibration_table" - // "trt_dump_subgraphs" std::vector option_keys, option_values; for (const TrtPairs &pair : trt_options) { @@ -122,10 +142,18 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, "CUDAExecutionProvider") != available_providers.end()) { // The CUDA provider is available, proceed with setting the options OrtCUDAProviderOptions options; - options.device_id = 0; - // Default OrtCudnnConvAlgoSearchExhaustive is extremely slow - options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; - // set more options on need + + if(provider_config != nullptr) { + options.device_id = provider_config->device_id; + options.cudnn_conv_algo_search = + OrtCudnnConvAlgoSearch(provider_config->cuda_config + .cudnn_conv_algo_search); + } else { + options.device_id = 0; + // Default OrtCudnnConvAlgoSearchExhaustive is extremely slow + options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; + // set more options on need + } sess_opts.AppendExecutionProvider_CUDA(options); } else { SHERPA_ONNX_LOGE( @@ -184,11 +212,12 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, } Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); + return GetSessionOptionsImpl(config.num_threads, + config.provider_config.provider,&config.provider_config); } Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads, config.provider); + return GetSessionOptionsImpl(config.num_threads,config.provider); } Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) { From 47cdf993b9b030a8d342542d1abbd513223a82e7 Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Mon, 24 Jun 2024 
Subject: [PATCH 03/34] updating python api

Signed-off-by: manickavela1998@gmail.com
---
 sherpa-onnx/python/csrc/online-model-config.cc | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc
index d6db809bd..b250d7ef2 100644
--- a/sherpa-onnx/python/csrc/online-model-config.cc
+++ b/sherpa-onnx/python/csrc/online-model-config.cc
@@ -8,6 +8,7 @@
 #include <string>
 
 #include "sherpa-onnx/csrc/online-model-config.h"
+#include "sherpa-onnx/csrc/provider-config.h"
 #include "sherpa-onnx/csrc/online-transducer-model-config.h"
 #include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h"
 #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
@@ -23,6 +24,7 @@ void PybindOnlineModelConfig(py::module *m) {
 PybindOnlineWenetCtcModelConfig(m);
 PybindOnlineZipformer2CtcModelConfig(m);
 PybindOnlineNeMoCtcModelConfig(m);
+ PybindProviderconfig(m);
 
 using PyClass = OnlineModelConfig;
 py::class_<PyClass>(*m, "OnlineModelConfig")
 .def(py::init<const OnlineTransducerModelConfig &,
 const OnlineParaformerModelConfig &,
 const OnlineWenetCtcModelConfig &,
 const OnlineZipformer2CtcModelConfig &,
- const OnlineNeMoCtcModelConfig &, const std::string &,
- int32_t, int32_t, bool, const std::string &,
- const std::string &, const std::string &,
- const std::string &>(),
+ const OnlineNeMoCtcModelConfig &,
+ const ExecutionProviderConfig &,
+ const std::string &, int32_t, int32_t,
+ bool, const std::string &, const std::string &,
+ const std::string &, const std::string &>(),
 py::arg("transducer") = OnlineTransducerModelConfig(),
 py::arg("paraformer") = OnlineParaformerModelConfig(),
 py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
 py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
- py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("tokens"),
- py::arg("num_threads"), py::arg("warm_up") = 0,
+ py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(),
+ py::arg("provider_config") = ExecutionProviderConfig(),
+ py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
 py::arg("debug") = false, py::arg("provider") = "cpu",
 py::arg("model_type") = "", py::arg("modeling_unit") = "",
 py::arg("bpe_vocab") = "")
@@ -51,7 +55,6 @@ void PybindOnlineModelConfig(py::module *m) {
 .def_readwrite("tokens", &PyClass::tokens)
 .def_readwrite("num_threads", &PyClass::num_threads)
 .def_readwrite("debug", &PyClass::debug)
- .def_readwrite("provider", &PyClass::provider)
 .def_readwrite("model_type", &PyClass::model_type)
 .def_readwrite("modeling_unit", &PyClass::modeling_unit)
 .def_readwrite("bpe_vocab", &PyClass::bpe_vocab)

From e29ff24f7b5fdafb299ef6162e216742cc103945 Mon Sep 17 00:00:00 2001
From: "manickavela1998@gmail.com"
Date: Mon, 24 Jun 2024 18:00:30 +0000
Subject: [PATCH 04/34] attempting fixing pybind

Signed-off-by: manickavela1998@gmail.com
---
 sherpa-onnx/python/csrc/online-model-config.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc
index b250d7ef2..1ad5974a2 100644
--- a/sherpa-onnx/python/csrc/online-model-config.cc
+++ b/sherpa-onnx/python/csrc/online-model-config.cc
@@ -24,7 +24,7 @@ void PybindOnlineModelConfig(py::module *m) {
 PybindOnlineWenetCtcModelConfig(m);
 PybindOnlineZipformer2CtcModelConfig(m);
 PybindOnlineNeMoCtcModelConfig(m);
- PybindProviderconfig(m);
+ PybindExecutionProviderConfig(m);
 
 using PyClass = OnlineModelConfig;
 py::class_<PyClass>(*m, "OnlineModelConfig")

From d139c1db88580f13ed65464b05958aa9354cb1ec Mon Sep 17 00:00:00 2001
From: "manickavela1998@gmail.com"
Date: Mon, 24 Jun 2024 18:50:15 +0000
Subject: [PATCH 05/34] patch for pybind and clean

Signed-off-by: manickavela1998@gmail.com
---
 sherpa-onnx/csrc/online-model-config.h | 4 +-
 sherpa-onnx/csrc/provider-config.cc | 16 +++---
 sherpa-onnx/csrc/provider-config.h | 14 ++---
 sherpa-onnx/csrc/session.cc | 6 +--
 .../python/csrc/online-model-config.cc | 6 +--
 sherpa-onnx/python/csrc/provider-config.cc | 54 +++++++++++++++++++
 sherpa-onnx/python/csrc/provider-config.h | 16 ++++++
 7 files changed, 93 insertions(+), 23 deletions(-)
 create mode 100644 sherpa-onnx/python/csrc/provider-config.cc
 create mode 100644 sherpa-onnx/python/csrc/provider-config.h

diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h
index 6c46ce6fd..6a88376a1 100644
--- a/sherpa-onnx/csrc/online-model-config.h
+++ b/sherpa-onnx/csrc/online-model-config.h
@@ -21,7 +21,7 @@ struct OnlineModelConfig {
 OnlineWenetCtcModelConfig wenet_ctc;
 OnlineZipformer2CtcModelConfig zipformer2_ctc;
 OnlineNeMoCtcModelConfig nemo_ctc;
- ExecutionProviderConfig provider_config;
+ ProviderConfig provider_config;
 std::string tokens;
 int32_t num_threads = 1;
 int32_t warm_up = 0;
@@ -51,7 +51,7 @@ struct OnlineModelConfig {
 const OnlineWenetCtcModelConfig &wenet_ctc,
 const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
 const OnlineNeMoCtcModelConfig &nemo_ctc,
- const ExecutionProviderConfig &provider_config,
+ const ProviderConfig &provider_config,
 const std::string &tokens, int32_t num_threads,
 int32_t warm_up, bool debug, const std::string &provider,
 const std::string &model_type,
diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc
index 0f0252f40..292ec18f7 100644
--- a/sherpa-onnx/csrc/provider-config.cc
+++ b/sherpa-onnx/csrc/provider-config.cc
@@ -135,27 +135,27 @@ std::string TensorrtConfig::ToString() const {
 return os.str();
 }
 
-void ExecutionProviderConfig::Register(ParseOptions *po) {
- po->Register("device_id", &device_id, "GPU device_id for CUDA and Trt EP");
+void ProviderConfig::Register(ParseOptions *po) {
+ po->Register("device", &device, "GPU device index for CUDA and Trt EP");
 po->Register("provider", &provider,
 "Specify a provider to use: cpu, cuda, coreml");
 }
 
-bool ExecutionProviderConfig::Validate() const {
+bool ProviderConfig::Validate() const {
 
- if(device_id < 0) {
- SHERPA_ONNX_LOGE("device_id: '%d' is invalid.",device_id);
+ if(device < 0) {
+ SHERPA_ONNX_LOGE("device: '%d' is invalid.",device);
 return false;
 }
 
 return true;
 }
 
-std::string ExecutionProviderConfig::ToString() const {
+std::string ProviderConfig::ToString() const {
 std::ostringstream os;
 
- os << "ExecutionProviderConfig(";
- os << "device_id=\"" << device_id << "\", ";
+ os << "ProviderConfig(";
+ os << "device=\"" << device << "\", ";
 os << "provider=\"" << provider << "\", ";
 os << "cuda_config=\"" << cuda_config.ToString() << "\", ";
 os << "trt_config=\"" << trt_config.ToString() << ")";
diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h
index 77ef49e56..ecf7f644a 100644
--- a/sherpa-onnx/csrc/provider-config.h
+++ b/sherpa-onnx/csrc/provider-config.h
@@ -13,7 +13,7 @@ namespace sherpa_onnx {
 
 struct CudaConfig {
- int32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic;
+ uint32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic;
 
 CudaConfig() = default;
 CudaConfig(const uint32_t &cudnn_conv_algo_search)
 : cudnn_conv_algo_search(cudnn_conv_algo_search) {}
@@ -63,19 +63,19 @@ struct TensorrtConfig {
 std::string ToString() const;
 };
 
-struct ExecutionProviderConfig {
- uint32_t device_id = 0;
- // device_id only used for cuda and trt
+struct ProviderConfig {
+ uint32_t device = 0;
+ // device only used for cuda and trt
 std::string provider = "cpu";
 CudaConfig cuda_config;
 TensorrtConfig trt_config;
 
- ExecutionProviderConfig() = default;
- ExecutionProviderConfig(const uint32_t &device_id,
+ ProviderConfig() = default;
+ ProviderConfig(const uint32_t &device,
 const std::string &provider,
 CudaConfig &cuda_config,
 TensorrtConfig &trt_config)
- : device_id(device_id), provider(provider), cuda_config(cuda_config),
+ : device(device), provider(provider), cuda_config(cuda_config),
 trt_config(trt_config) {}
 
 void Register(ParseOptions *po);
diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc
index 276eee76b..620b27158 100644
--- a/sherpa-onnx/csrc/session.cc
+++ b/sherpa-onnx/csrc/session.cc
@@ -33,7 +33,7 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
 
 static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
 std::string provider_str,
- const ExecutionProviderConfig *provider_config=nullptr) {
+ const ProviderConfig *provider_config=nullptr) {
 Provider p = StringToProvider(std::move(provider_str));
 
 Ort::SessionOptions sess_opts;
@@ -79,7 +79,7 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
 std::vector<TrtPairs> trt_options = {
 {"device_id",
- std::to_string(provider_config->device_id).c_str()},
+ std::to_string(provider_config->device).c_str()},
 {"trt_max_workspace_size",
 std::to_string(trt_config.trt_max_workspace_size).c_str()},
 {"trt_max_partition_iterations",
@@ -144,7 +144,7 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
 OrtCUDAProviderOptions options;
 
 if(provider_config != nullptr) {
- options.device_id = provider_config->device_id;
+ options.device_id = provider_config->device;
 options.cudnn_conv_algo_search =
 OrtCudnnConvAlgoSearch(provider_config->cuda_config
 .cudnn_conv_algo_search);
diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc
index 1ad5974a2..982601805 100644
--- a/sherpa-onnx/python/csrc/online-model-config.cc
+++ b/sherpa-onnx/python/csrc/online-model-config.cc
@@ -24,7 +24,7 @@ void PybindOnlineModelConfig(py::module *m) {
 PybindOnlineWenetCtcModelConfig(m);
 PybindOnlineZipformer2CtcModelConfig(m);
 PybindOnlineNeMoCtcModelConfig(m);
- PybindExecutionProviderConfig(m);
+ PybindProviderConfig(m);
 
 using PyClass = OnlineModelConfig;
 py::class_<PyClass>(*m, "OnlineModelConfig")
@@ -33,7 +33,7 @@ void PybindOnlineModelConfig(py::module *m) {
 const OnlineWenetCtcModelConfig &,
 const OnlineZipformer2CtcModelConfig &,
 const OnlineNeMoCtcModelConfig &,
- const ExecutionProviderConfig &,
+ const ProviderConfig &,
 const std::string &, int32_t, int32_t,
 bool, const std::string &, const std::string &,
 const std::string &, const std::string &>(),
@@ -42,7 +42,7 @@ void PybindOnlineModelConfig(py::module *m) {
 py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
 py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
 py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(),
- py::arg("provider_config") = ExecutionProviderConfig(),
+ py::arg("provider_config") = ProviderConfig(),
 py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
 py::arg("debug") = false, py::arg("provider") = "cpu",
 py::arg("model_type") = "", py::arg("modeling_unit") = "",
 py::arg("bpe_vocab") = "")
diff --git a/sherpa-onnx/python/csrc/provider-config.cc b/sherpa-onnx/python/csrc/provider-config.cc
new file mode 100644
index 000000000..0e94ed292
--- /dev/null
+++ b/sherpa-onnx/python/csrc/provider-config.cc
@@ -0,0 +1,54 @@
+// sherpa-onnx/python/csrc/provider-config.h
+//
+// Copyright (c) 2024 Uniphore Pvt Ltd(github.com/manickavela29)
+
+
+#include "sherpa-onnx/csrc/provider-config.h"
+
+#include <string>
+
+#include "sherpa-onnx/python/csrc/provider-config.h"
+
+namespace sherpa_onnx {
+
+void PybindCudaConfig(py::module *m) {
+ using PyClass = PybindCudaConfig;
+ py::class_<PyClass>(*m, "PybindCudaConfig")
+ .def(py::init<const uint32_t &>(),
+ py::arg("cudnn_conv_algo_search") = 1)
+ .def_readwrite("cudnn_conv_algo_search", &PyClass::cudnn_conv_algo_search)
+ .def("__str__", &PyClass::ToString);
+}
+
+void PybindTensorrtConfig(py::module *m) {
+ using PyClass = PybindTensorrtConfig;
+ py::class_<PyClass>(*m, "PybindTensorrtConfig")
+ .def(py::init<const TensorrtConfig &, const CudaConfig &,
+ const std::string &, uint32_t>(),
+ py::arg("trt_config") = TensorrtConfig(),
+ py::arg("cuda_config") = CudaConfig(),
+ py::arg("provider") = "cpu",
+ py::arg("device") = 0)
+ .def_readwrite("trt_config", &PyClass::Ten)
+ .def_readwrite("decoder", &PyClass::decoder)
+ .def_readwrite("joiner", &PyClass::joiner)
+ .def("__str__", &PyClass::ToString);
+}
+
+
+void PybindProviderConfig(py::module *m) {
+ using PyClass = ProviderConfig;
+ py::class_<PyClass>(*m, "ProviderConfig")
+ .def(py::init<const TensorrtConfig &, const CudaConfig &,
+ const std::string &, uint32_t>(),
+ py::arg("trt_config") = TensorrtConfig(),
+ py::arg("cuda_config") = CudaConfig(),
+ py::arg("provider") = "cpu",
+ py::arg("device") = 0)
+ .def_readwrite("trt_config", &PyClass::Ten)
+ .def_readwrite("decoder", &PyClass::decoder)
+ .def_readwrite("joiner", &PyClass::joiner)
+ .def("__str__", &PyClass::ToString);
+}
+
+} // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/provider-config.h b/sherpa-onnx/python/csrc/provider-config.h
new file mode 100644
index 000000000..ea0d94e5a
--- /dev/null
+++ b/sherpa-onnx/python/csrc/provider-config.h
@@ -0,0 +1,16 @@
+// sherpa-onnx/python/csrc/provider-config.h
+//
+// Copyright (c) 2024 Uniphore Pvt Ltd(github.com/manickavela29)
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindProviderConfig(py::module *m);
+
+}
+
+#endif // SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_

From 79f45ad5732e0624886d44f9217a8d852a2e1a31 Mon Sep 17 00:00:00 2001
From: "manickavela1998@gmail.com"
Date: Tue, 25 Jun 2024 14:54:24 +0000
Subject: [PATCH 06/34] clean up

Signed-off-by: manickavela1998@gmail.com
---
 sherpa-onnx/csrc/provider-config.cc | 5 -----
 sherpa-onnx/csrc/provider-config.h | 15 ++++++++-------
 2 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc
index 292ec18f7..d3ae5d99d 100644
--- a/sherpa-onnx/csrc/provider-config.cc
+++ b/sherpa-onnx/csrc/provider-config.cc
@@ -102,11 +102,6 @@ bool TensorrtConfig::Validate() const {
 trt_timing_cache_enable);
 return false;
 }
 
- // if(trt_max_workspace_size > 0) {
- // SHERPA_ONNX_LOGE("trt_max_workspace_size: '%d' is not valid.",);
- // return false;
- // }
-
 return true;
 }
 
diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h
index ecf7f644a..014f518d3 100644
--- a/sherpa-onnx/csrc/provider-config.h
+++ b/sherpa-onnx/csrc/provider-config.h
@@ -32,8 +32,8 @@ struct TensorrtConfig {
 bool trt_fp16_enable = 1;
 bool trt_detailed_build_log = 0;
 bool trt_engine_cache_enable = 1;
- std::string trt_engine_cache_path = ".";
 bool trt_timing_cache_enable = 1;
+ std::string trt_engine_cache_path = ".";
 std::string trt_timing_cache_path = ".";
 bool trt_dump_subgraphs = 0;
 
@@ -44,8 +44,8 @@ struct TensorrtConfig {
 const bool &trt_fp16_enable,
 const bool &trt_detailed_build_log,
 const bool &trt_engine_cache_enable,
&trt_engine_cache_enable, - const std::string &trt_engine_cache_path, const bool &trt_timing_cache_enable, + const std::string &trt_engine_cache_path, const std::string &trt_timing_cache_path) : trt_max_workspace_size(trt_max_workspace_size), trt_max_partition_iterations(trt_max_partition_iterations), @@ -53,8 +53,8 @@ struct TensorrtConfig { trt_fp16_enable(trt_fp16_enable), trt_detailed_build_log(trt_detailed_build_log), trt_engine_cache_enable(trt_engine_cache_enable), - trt_engine_cache_path(trt_engine_cache_path), trt_timing_cache_enable(trt_timing_cache_enable), + trt_engine_cache_path(trt_engine_cache_path), trt_timing_cache_path(trt_timing_cache_path) {} void Register(ParseOptions *po); @@ -72,10 +72,11 @@ struct ProviderConfig { ProviderConfig() = default; ProviderConfig(const uint32_t &device, - const std::string &provider, - CudaConfig &cuda_config, - TensorrtConfig &trt_config) - : device(device), provider(provider), cuda_config(cuda_config), + const std::string &provider, + CudaConfig &cuda_config, + TensorrtConfig &trt_config) + : device(device), provider(provider), + cuda_config(cuda_config), trt_config(trt_config) {} void Register(ParseOptions *po); From 207f4a16b8e7e506589e67e55f6baf7b2ec1bcbb Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Tue, 25 Jun 2024 15:39:25 +0000 Subject: [PATCH 07/34] python api complete Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/csrc/online-websocket-server.cc | 1 + sherpa-onnx/csrc/provider-config.cc | 21 +++---- sherpa-onnx/csrc/provider-config.h | 8 ++- sherpa-onnx/python/csrc/provider-config.cc | 66 ++++++++++++++++----- 4 files changed, 67 insertions(+), 29 deletions(-) diff --git a/sherpa-onnx/csrc/online-websocket-server.cc b/sherpa-onnx/csrc/online-websocket-server.cc index 2c6602740..6ba7a1986 100644 --- a/sherpa-onnx/csrc/online-websocket-server.cc +++ b/sherpa-onnx/csrc/online-websocket-server.cc @@ -69,6 +69,7 @@ int32_t main(int32_t argc, char *argv[]) { } config.Validate(); + asio::io_context io_conn; // for network connections asio::io_context io_work; // for neural network and decoding diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index d3ae5d99d..89c86f9ef 100644 --- a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -1,6 +1,7 @@ -// sherpa-onnx/csrc/online-transducer-model-config.cc +// sherpa-onnx/csrc/provider-config.cc // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2024 Uniphore (Author: Manickavela) + #include "sherpa-onnx/csrc/provider-config.h" #include @@ -110,23 +111,23 @@ std::string TensorrtConfig::ToString() const { os << "TensorrtConfig("; os << "trt_max_workspace_size=\"" << trt_max_workspace_size << "\", "; - os << "trt_max_partition_iterations=\"" + os << "trt_max_partition_iterations=\"" << trt_max_partition_iterations << "\", "; os << "trt_min_subgraph_size=\"" << trt_min_subgraph_size << "\", "; - os << "trt_fp16_enable=\"" + os << "trt_fp16_enable=\"" << (trt_fp16_enable? "True" : "False") << "\", "; - os << "trt_detailed_build_log=\"" + os << "trt_detailed_build_log=\"" << (trt_detailed_build_log? "True" : "False") << "\", "; - os << "trt_engine_cache_enable=\"" + os << "trt_engine_cache_enable=\"" << (trt_engine_cache_enable? "True" : "False") << "\", "; - os << "trt_engine_cache_path=\"" + os << "trt_engine_cache_path=\"" << trt_engine_cache_path.c_str() << "\", "; - os << "trt_timing_cache_enable=\"" + os << "trt_timing_cache_enable=\"" << (trt_timing_cache_enable? 
"True" : "False") << "\", "; os << "trt_timing_cache_path=\"" << trt_timing_cache_path.c_str() << "\","; - os << "trt_dump_subgraphs=\"" - << (trt_dump_subgraphs? "True" : "False") << "\" )"; + os << "trt_dump_subgraphs=\"" + << (trt_dump_subgraphs? "True" : "False") << "\" )"; return os.str(); } diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h index 014f518d3..9343e775d 100644 --- a/sherpa-onnx/csrc/provider-config.h +++ b/sherpa-onnx/csrc/provider-config.h @@ -1,6 +1,7 @@ -// sherpa-onnx/csrc/online-transducer-model-config.h +// sherpa-onnx/csrc/provider-config.h // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2024 Uniphore (Author: Manickavela) + #ifndef SHERPA_ONNX_CSRC_ONNXRT_EXECUTION_PROVIDER_CONFIG_H_ #define SHERPA_ONNX_CSRC_ONNXRT_EXECUTION_PROVIDER_CONFIG_H_ @@ -55,7 +56,8 @@ struct TensorrtConfig { trt_engine_cache_enable(trt_engine_cache_enable), trt_timing_cache_enable(trt_timing_cache_enable), trt_engine_cache_path(trt_engine_cache_path), - trt_timing_cache_path(trt_timing_cache_path) {} + trt_timing_cache_path(trt_timing_cache_path, + trt_dump_subgraphs(trt_dump_subgraphs)) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/python/csrc/provider-config.cc b/sherpa-onnx/python/csrc/provider-config.cc index 0e94ed292..051856d81 100644 --- a/sherpa-onnx/python/csrc/provider-config.cc +++ b/sherpa-onnx/python/csrc/provider-config.cc @@ -1,6 +1,6 @@ // sherpa-onnx/python/csrc/provider-config.h // -// Copyright (c) 2024 Uniphore Pvt Ltd(github.com/manickavela29) +// Copyright (c) 2024 Uniphore (Author: Manickavela) #include "sherpa-onnx/csrc/provider-config.h" @@ -14,25 +14,59 @@ namespace sherpa_onnx { void PybindCudaConfig(py::module *m) { using PyClass = PybindCudaConfig; py::class_(*m, "PybindCudaConfig") - .def(py::init(), py::arg("cudnn_conv_algo_search") = 1) - .def_readwrite("cudnn_conv_algo_search", &PyClass::cudnn_conv_algo_search) - .def("__str__", &PyClass::ToString); + .def_readwrite("cudnn_conv_algo_search", + &PyClass::cudnn_conv_algo_search) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); } void PybindTensorrtConfig(py::module *m) { using PyClass = PybindTensorrtConfig; py::class_(*m, "PybindTensorrtConfig") - .def(py::init(), - py::arg("trt_config") = TensorrtConfig(), - py::arg("cuda_config") = CudaConfig(), - py::arg("provider") = "cpu", - py::arg("device") = 0) - .def_readwrite("trt_config", &PyClass::Ten) - .def_readwrite("decoder", &PyClass::decoder) - .def_readwrite("joiner", &PyClass::joiner) - .def("__str__", &PyClass::ToString); + .def(py::init(), + py::arg("trt_max_workspace_size") = 2147483648, + py::arg("trt_max_partition_iterations") = 10, + py::arg("trt_min_subgraph_size") = 5, + py::arg("trt_fp16_enable") = 1, + py::arg("trt_detailed_build_log") = 0, + py::arg("trt_engine_cache_enable") = 1, + py::arg("trt_timing_cache_enable") = 1, + py::arg("trt_engine_cache_path") = ".", + py::arg("trt_timing_cache_path") = ".", + py::arg("trt_dump_subgraphs") = 0) + .def_readwrite("trt_max_workspace_size", + &PyClass::trt_max_workspace_size) + .def_readwrite("trt_max_partition_iterations", + &PyClass::trt_max_partition_iterations) + .def_readwrite("trt_min_subgraph_size", + &PyClass::trt_min_subgraph_size) + .def_readwrite("trt_fp16_enable", + &PyClass::trt_fp16_enable) + .def_readwrite("trt_detailed_build_log", + &PyClass::trt_detailed_build_log) + .def_readwrite("trt_engine_cache_enable", + &PyClass::trt_engine_cache_enable) + 
.def_readwrite("trt_timing_cache_enable", + &PyClass::trt_timing_cache_enable) + .def_readwrite("trt_engine_cache_path", + &PyClass::trt_engine_cache_path) + .def_readwrite("trt_timing_cache_path", + &PyClass::trt_timing_cache_path) + .def_readwrite("trt_dump_subgraphs", + &PyClass::trt_dump_subgraphs) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); } @@ -48,7 +82,7 @@ void PybindProviderConfig(py::module *m) { .def_readwrite("trt_config", &PyClass::Ten) .def_readwrite("decoder", &PyClass::decoder) .def_readwrite("joiner", &PyClass::joiner) - .def("__str__", &PyClass::ToString); -} + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate);} } // namespace sherpa_onnx From df2e2e11e21460fcc38c40bd1d1c3ced8d312692 Mon Sep 17 00:00:00 2001 From: Manix <50542248+manickavela29@users.noreply.github.com> Date: Wed, 26 Jun 2024 09:00:46 +0530 Subject: [PATCH 08/34] Apply suggestions from code review Co-authored-by: Fangjun Kuang --- sherpa-onnx/csrc/provider-config.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index 89c86f9ef..4d61eac9d 100644 --- a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -19,7 +19,7 @@ void CudaConfig::Register(ParseOptions *po) { bool CudaConfig::Validate() const { if(cudnn_conv_algo_search > 0 && cudnn_conv_algo_search < 4) { - SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not valid option." + SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not a valid option." "Options : [1,3]. Check OnnxRT docs", cudnn_conv_algo_search); return false; From c477339a7808f184983db0aaac04609afae8a034 Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Wed, 26 Jun 2024 03:38:28 +0000 Subject: [PATCH 09/34] fix and update Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/csrc/online-model-config.cc | 2 +- sherpa-onnx/csrc/online-model-config.h | 2 +- sherpa-onnx/csrc/online-recognizer.cc | 1 - sherpa-onnx/csrc/provider-config.cc | 5 +-- sherpa-onnx/csrc/provider-config.h | 25 +++++------ sherpa-onnx/csrc/session.cc | 49 ++++++++++++---------- sherpa-onnx/python/csrc/provider-config.cc | 27 ++++++------ sherpa-onnx/python/csrc/provider-config.h | 2 +- 8 files changed, 60 insertions(+), 53 deletions(-) diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index 19d14c3ff..9913fa9ed 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -86,7 +86,7 @@ bool OnlineModelConfig::Validate() const { } if (!provider_config.Validate()) { - return provider_config.Validate(); + return false; } return transducer.Validate(); diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h index 6a88376a1..0b64e06de 100644 --- a/sherpa-onnx/csrc/online-model-config.h +++ b/sherpa-onnx/csrc/online-model-config.h @@ -53,7 +53,7 @@ struct OnlineModelConfig { const OnlineNeMoCtcModelConfig &nemo_ctc, const ProviderConfig &provider_config, const std::string &tokens, int32_t num_threads, - int32_t warm_up, bool debug, const std::string &provider, + int32_t warm_up, bool debug, const std::string &model_type, const std::string &modeling_unit, const std::string &bpe_vocab) diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 436abb82d..599a0553d 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -113,7 +113,6 
 }
 
 bool OnlineRecognizerConfig::Validate() const {
- SHERPA_ONNX_LOGE("Args recognizer : %s",ToString().c_str());
 if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) {
 if (max_active_paths <= 0) {
 SHERPA_ONNX_LOGE("max_active_paths is less than 0! Given: %d",
diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc
index 4d61eac9d..cf4a78a24 100644
--- a/sherpa-onnx/csrc/provider-config.cc
+++ b/sherpa-onnx/csrc/provider-config.cc
@@ -18,8 +18,8 @@ void CudaConfig::Register(ParseOptions *po) {
 }
 
 bool CudaConfig::Validate() const {
 
- if(cudnn_conv_algo_search > 0 && cudnn_conv_algo_search < 4) {
- SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not a valid option."
+ if(cudnn_conv_algo_search < 1 && cudnn_conv_algo_search > 3) {
+ SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not valid option."
 "Options : [1,3]. Check OnnxRT docs",
 cudnn_conv_algo_search);
 return false;
@@ -96,7 +96,6 @@ bool TensorrtConfig::Validate() const {
 trt_timing_cache_enable);
 return false;
 }
-
 if (trt_dump_subgraphs != true || trt_dump_subgraphs != false) {
 SHERPA_ONNX_LOGE("trt_dump_subgraphs: '%d' is not valid.",
 trt_dump_subgraphs);
diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h
index 9343e775d..093a44dda 100644
--- a/sherpa-onnx/csrc/provider-config.h
+++ b/sherpa-onnx/csrc/provider-config.h
@@ -17,7 +17,7 @@ struct CudaConfig {
 uint32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic;
 
 CudaConfig() = default;
- CudaConfig(const uint32_t &cudnn_conv_algo_search)
+ CudaConfig(uint32_t &cudnn_conv_algo_search)
 : cudnn_conv_algo_search(cudnn_conv_algo_search) {}
 
 void Register(ParseOptions *po);
@@ -39,15 +39,16 @@ struct TensorrtConfig {
 bool trt_dump_subgraphs = 0;
 
 TensorrtConfig() = default;
- TensorrtConfig(const uint32_t &trt_max_workspace_size,
- const uint32_t &trt_max_partition_iterations,
- const uint32_t &trt_min_subgraph_size,
- const bool &trt_fp16_enable,
- const bool &trt_detailed_build_log,
- const bool &trt_engine_cache_enable,
- const bool &trt_timing_cache_enable,
+ TensorrtConfig(uint32_t trt_max_workspace_size,
+ uint32_t trt_max_partition_iterations,
+ uint32_t trt_min_subgraph_size,
+ bool trt_fp16_enable,
+ bool trt_detailed_build_log,
+ bool trt_engine_cache_enable,
+ bool trt_timing_cache_enable,
 const std::string &trt_engine_cache_path,
- const std::string &trt_timing_cache_path)
+ const std::string &trt_timing_cache_path,
+ bool trt_dump_subgraphs)
 : trt_max_workspace_size(trt_max_workspace_size),
 trt_max_partition_iterations(trt_max_partition_iterations),
 trt_min_subgraph_size(trt_min_subgraph_size),
@@ -56,8 +57,8 @@ struct TensorrtConfig {
 trt_engine_cache_enable(trt_engine_cache_enable),
 trt_timing_cache_enable(trt_timing_cache_enable),
 trt_engine_cache_path(trt_engine_cache_path),
- trt_timing_cache_path(trt_timing_cache_path,
- trt_dump_subgraphs(trt_dump_subgraphs)) {}
+ trt_timing_cache_path(trt_timing_cache_path),
+ trt_dump_subgraphs(trt_dump_subgraphs) {}
 
 void Register(ParseOptions *po);
 bool Validate() const;
@@ -73,7 +74,7 @@ struct ProviderConfig {
 TensorrtConfig trt_config;
 
 ProviderConfig() = default;
- ProviderConfig(const uint32_t &device,
+ ProviderConfig(uint32_t device,
 const std::string &provider,
 CudaConfig &cuda_config,
 TensorrtConfig &trt_config)
diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc
index 620b27158..615c29718 100644
--- a/sherpa-onnx/csrc/session.cc
+++ b/sherpa-onnx/csrc/session.cc
@@ -32,7 +32,7 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
 }
 
 static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
- std::string provider_str,
+ const std::string &provider_str,
 const ProviderConfig *provider_config=nullptr) {
 Provider p = StringToProvider(std::move(provider_str));
 
@@ -77,29 +77,36 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
 const char *op_values;
 };
 
+ auto trt_max_workspace_size =
+ std::to_string(trt_config.trt_max_workspace_size);
+ auto trt_max_partition_iterations =
+ std::to_string(trt_config.trt_max_partition_iterations);
+ auto trt_min_subgraph_size =
+ std::to_string(trt_config.trt_min_subgraph_size);
+ auto trt_fp16_enable =
+ std::to_string(trt_config.trt_fp16_enable);
+ auto trt_detailed_build_log =
+ std::to_string(trt_config.trt_detailed_build_log);
+ auto trt_engine_cache_enable =
+ std::to_string(trt_config.trt_engine_cache_enable);
+ auto trt_timing_cache_enable =
+ std::to_string(trt_config.trt_timing_cache_enable);
+ auto trt_dump_subgraphs =
+ std::to_string(trt_config.trt_dump_subgraphs);
+
 std::vector<TrtPairs> trt_options = {
 {"device_id",
 std::to_string(provider_config->device).c_str()},
- {"trt_max_workspace_size",
- std::to_string(trt_config.trt_max_workspace_size).c_str()},
- {"trt_max_partition_iterations",
- std::to_string(trt_config.trt_max_partition_iterations).c_str()},
- {"trt_min_subgraph_size",
- std::to_string(trt_config.trt_min_subgraph_size).c_str()},
- {"trt_fp16_enable",
- std::to_string(trt_config.trt_fp16_enable).c_str()},
- {"trt_detailed_build_log",
- std::to_string(trt_config.trt_detailed_build_log).c_str()},
- {"trt_engine_cache_enable",
- std::to_string(trt_config.trt_engine_cache_enable).c_str()},
- {"trt_engine_cache_path",
- trt_config.trt_engine_cache_path.c_str()},
- {"trt_timing_cache_enable",
- std::to_string(trt_config.trt_timing_cache_enable).c_str()},
- {"trt_timing_cache_path",
- trt_config.trt_timing_cache_path.c_str()},
- {"trt_dump_subgraphs",
- std::to_string(trt_config.trt_dump_subgraphs).c_str()}
+ {"trt_max_workspace_size", trt_max_workspace_size.c_str()},
+ {"trt_max_partition_iterations", trt_max_partition_iterations.c_str()},
+ {"trt_min_subgraph_size", trt_min_subgraph_size.c_str()},
+ {"trt_fp16_enable", trt_fp16_enable.c_str()},
+ {"trt_detailed_build_log",trt_detailed_build_log.c_str()},
+ {"trt_engine_cache_enable",trt_engine_cache_enable.c_str()},
+ {"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()},
+ {"trt_timing_cache_enable", trt_timing_cache_enable.c_str()},
+ {"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()},
+ {"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}
 };
 // ToDo : Trt configs
 // "trt_int8_enable"
diff --git a/sherpa-onnx/python/csrc/provider-config.cc b/sherpa-onnx/python/csrc/provider-config.cc
index 051856d81..45d79f2b6 100644
--- a/sherpa-onnx/python/csrc/provider-config.cc
+++ b/sherpa-onnx/python/csrc/provider-config.cc
@@ -14,7 +14,7 @@ namespace sherpa_onnx {
 void PybindCudaConfig(py::module *m) {
 using PyClass = PybindCudaConfig;
 py::class_<PyClass>(*m, "PybindCudaConfig")
- .def(py::init<const uint32_t &>(),
+ .def(py::init<uint32_t>(),
 py::arg("cudnn_conv_algo_search") = 1)
 .def_readwrite("cudnn_conv_algo_search",
 &PyClass::cudnn_conv_algo_search)
@@ -25,16 +25,16 @@ void PybindCudaConfig(py::module *m) {
 void PybindTensorrtConfig(py::module *m) {
 using PyClass = PybindTensorrtConfig;
 py::class_<PyClass>(*m, "PybindTensorrtConfig")
- .def(py::init<const uint32_t &, const uint32_t &, const uint32_t &,
- const bool &, const bool &, const bool &, const bool &,
- const std::string &, const std::string &,
- const bool &>(),
+ .def(py::init<uint32_t, uint32_t, uint32_t,
+ bool, bool, bool, bool,
+ const std::string &, const std::string &,
+ bool>(),
 py::arg("trt_max_workspace_size") = 2147483648,
 py::arg("trt_max_partition_iterations") = 10,
 py::arg("trt_min_subgraph_size") = 5,
@@ -74,14 +74,15 @@ void PybindProviderConfig(py::module *m) {
 using PyClass = ProviderConfig;
 py::class_<PyClass>(*m, "ProviderConfig")
- .def(py::init<const TensorrtConfig &, const CudaConfig &,
- const std::string &, uint32_t>(),
+ .def(py::init<const TensorrtConfig &, const CudaConfig &,
+ const std::string &, uint32_t>(),
 py::arg("trt_config") = TensorrtConfig(),
 py::arg("cuda_config") = CudaConfig(),
 py::arg("provider") = "cpu",
 py::arg("device") = 0)
- .def_readwrite("trt_config", &PyClass::Ten)
- .def_readwrite("decoder", &PyClass::decoder)
- .def_readwrite("joiner", &PyClass::joiner)
+ .def_readwrite("cuda_config", &PyClass::cuda_config)
+ .def_readwrite("trt_config", &PyClass::trt_config)
+ .def_readwrite("provider", &PyClass::provider)
+ .def_readwrite("device", &PyClass::device)
 .def("__str__", &PyClass::ToString)
 .def("validate", &PyClass::Validate);}
 
diff --git a/sherpa-onnx/python/csrc/provider-config.h b/sherpa-onnx/python/csrc/provider-config.h
index ea0d94e5a..75db543a5 100644
--- a/sherpa-onnx/python/csrc/provider-config.h
+++ b/sherpa-onnx/python/csrc/provider-config.h
@@ -1,6 +1,6 @@
 // sherpa-onnx/python/csrc/provider-config.h
 //
-// Copyright (c) 2024 Uniphore Pvt Ltd(github.com/manickavela29)
+// Copyright (c) 2024 Uniphore (Author: Manickavela)
 
 #ifndef SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_
 #define SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_

From f20157c59d27ab0100a6c81f1f5bd7410550d475 Mon Sep 17 00:00:00 2001
From: "manickavela1998@gmail.com"
Date: Wed, 26 Jun 2024 04:07:38 +0000
Subject: [PATCH 10/34] lint fix

Signed-off-by: manickavela1998@gmail.com
---
 sherpa-onnx/csrc/provider-config.cc | 35 +++++++++++++----------------
 sherpa-onnx/csrc/provider-config.h | 12 +++++-----
 sherpa-onnx/csrc/session.cc | 20 ++++++++---------
 3 files changed, 31 insertions(+), 36 deletions(-)

diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc
index cf4a78a24..acb862f45 100644
--- a/sherpa-onnx/csrc/provider-config.cc
+++ b/sherpa-onnx/csrc/provider-config.cc
@@ -17,14 +17,12 @@ void CudaConfig::Register(ParseOptions *po) {
 }
 
 bool CudaConfig::Validate() const {
-
- if (cudnn_conv_algo_search < 1 && cudnn_conv_algo_search > 3) {
+ if (cudnn_conv_algo_search < 1 && cudnn_conv_algo_search > 3) {
 SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not valid option."
 "Options : [1,3]. 
Check OnnxRT docs", cudnn_conv_algo_search); return false; } - return true; } @@ -38,30 +36,29 @@ std::string CudaConfig::ToString() const { } void TensorrtConfig::Register(ParseOptions *po) { - po->Register("trt-max-workspace-size",&trt_max_workspace_size, + po->Register("trt-max-workspace-size", &trt_max_workspace_size, ""); - po->Register("trt-max-partition-iterations",&trt_max_partition_iterations, + po->Register("trt-max-partition-iterations", &trt_max_partition_iterations, ""); - po->Register("trt-min-subgraph-size ",&trt_min_subgraph_size, + po->Register("trt-min-subgraph-size ", &trt_min_subgraph_size, ""); - po->Register("trt-fp16-enable",&trt_fp16_enable, + po->Register("trt-fp16-enable", &trt_fp16_enable, "true to enable fp16"); - po->Register("trt-detailed-build-log",&trt_detailed_build_log, + po->Register("trt-detailed-build-log", &trt_detailed_build_log, "true to print TensorRT build logs"); - po->Register("trt-engine-cache-enable",&trt_engine_cache_enable, + po->Register("trt-engine-cache-enable", &trt_engine_cache_enable, "true to enable engine caching"); - po->Register("trt-engine-cache-path",&trt_engine_cache_path, + po->Register("trt-engine-cache-path", &trt_engine_cache_path, ""); - po->Register("trt-timing-cache-enable",&trt_timing_cache_enable, + po->Register("trt-timing-cache-enable", &trt_timing_cache_enable, "true to enable timing cache"); - po->Register("trt-timing-cache-path",&trt_timing_cache_path, + po->Register("trt-timing-cache-path", &trt_timing_cache_path, ""); - po->Register("trt-dump-subgraphs",&trt_dump_subgraphs, + po->Register("trt-dump-subgraphs", &trt_dump_subgraphs, "true to dump subgraphs"); } bool TensorrtConfig::Validate() const { - if (trt_max_workspace_size > 0) { SHERPA_ONNX_LOGE("trt_max_workspace_size: '%d' is not valid.", trt_max_workspace_size); @@ -78,7 +75,7 @@ bool TensorrtConfig::Validate() const { return false; } if (trt_fp16_enable != true || trt_fp16_enable != false) { - SHERPA_ONNX_LOGE("trt_fp16_enable: '%d' is not valid.",trt_fp16_enable); + SHERPA_ONNX_LOGE("trt_fp16_enable: '%d' is not valid.", trt_fp16_enable); return false; } if (trt_detailed_build_log != true || trt_detailed_build_log != false) { @@ -123,7 +120,7 @@ std::string TensorrtConfig::ToString() const { << trt_engine_cache_path.c_str() << "\", "; os << "trt_timing_cache_enable=\"" << (trt_timing_cache_enable? "True" : "False") << "\", "; - os << "trt_timing_cache_path=\"" + os << "trt_timing_cache_path=\"" << trt_timing_cache_path.c_str() << "\","; os << "trt_dump_subgraphs=\"" << (trt_dump_subgraphs? 
"True" : "False") << "\" )"; @@ -137,12 +134,10 @@ void ProviderConfig::Register(ParseOptions *po) { } bool ProviderConfig::Validate() const { - - if(device < 0) { + if (device < 0) { SHERPA_ONNX_LOGE("device: '%d' is invalid.",device); return false; } - return true; } @@ -151,7 +146,7 @@ std::string ProviderConfig::ToString() const { os << "ProviderConfig("; os << "device=\"" << device << "\", "; - os << "provider=\"" << provider << "\", "; + os << "provider=\"" << provider << "\", "; os << "cuda_config=\"" << cuda_config.ToString() << "\", "; os << "trt_config=\"" << trt_config.ToString() << ")"; return os.str(); diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h index 093a44dda..ef887fff1 100644 --- a/sherpa-onnx/csrc/provider-config.h +++ b/sherpa-onnx/csrc/provider-config.h @@ -2,8 +2,8 @@ // // Copyright (c) 2024 Uniphore (Author: Manickavela) -#ifndef SHERPA_ONNX_CSRC_ONNXRT_EXECUTION_PROVIDER_CONFIG_H_ -#define SHERPA_ONNX_CSRC_ONNXRT_EXECUTION_PROVIDER_CONFIG_H_ +#ifndef SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_ +#define SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_ #include @@ -17,7 +17,7 @@ struct CudaConfig { uint32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; CudaConfig() = default; - CudaConfig(uint32_t &cudnn_conv_algo_search) + CudaConfig(uint32_t cudnn_conv_algo_search) : cudnn_conv_algo_search(cudnn_conv_algo_search) {} void Register(ParseOptions *po); @@ -76,8 +76,8 @@ struct ProviderConfig { ProviderConfig() = default; ProviderConfig(uint32_t device, const std::string &provider, - CudaConfig &cuda_config, - TensorrtConfig &trt_config) + const CudaConfig &cuda_config, + const TensorrtConfig &trt_config) : device(device), provider(provider), cuda_config(cuda_config), trt_config(trt_config) {} @@ -90,4 +90,4 @@ struct ProviderConfig { } // namespace sherpa_onnx -#endif // SHERPA_ONNX_CSRC_ONNXRT_EXECUTION_PROVIDER_CONFIG_H_ +#endif // SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_ diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 615c29718..395fe2ec3 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -33,7 +33,7 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) { static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, const std::string &provider_str, - const ProviderConfig *provider_config=nullptr) { + const ProviderConfig *provider_config = nullptr) { Provider p = StringToProvider(std::move(provider_str)); Ort::SessionOptions sess_opts; @@ -66,7 +66,7 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, break; } case Provider::kTRT: { - if(provider_config == nullptr) { + if (provider_config == nullptr) { SHERPA_ONNX_LOGE("Tensorrt support for Online models ony," "Must be extended for offline and others"); exit(1); @@ -83,12 +83,12 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, std::to_string(trt_config.trt_max_partition_iterations); auto trt_min_subgraph_size = std::to_string(trt_config.trt_min_subgraph_size); - auto trt_fp16_enable = + auto trt_fp16_enable = std::to_string(trt_config.trt_fp16_enable); auto trt_detailed_build_log = std::to_string(trt_config.trt_detailed_build_log); auto trt_engine_cache_enable = - std::to_string(trt_config.trt_engine_cache_enable); + std::to_string(trt_config.trt_engine_cache_enable); auto trt_timing_cache_enable = std::to_string(trt_config.trt_timing_cache_enable); auto trt_dump_subgraphs = @@ -101,8 +101,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, 
{"trt_max_partition_iterations", trt_max_partition_iterations.c_str()}, {"trt_min_subgraph_size", trt_min_subgraph_size.c_str()}, {"trt_fp16_enable", trt_fp16_enable.c_str()}, - {"trt_detailed_build_log",trt_detailed_build_log.c_str()}, - {"trt_engine_cache_enable",trt_engine_cache_enable.c_str()}, + {"trt_detailed_build_log", trt_detailed_build_log.c_str()}, + {"trt_engine_cache_enable", trt_engine_cache_enable.c_str()}, {"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()}, {"trt_timing_cache_enable", trt_timing_cache_enable.c_str()}, {"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()}, @@ -150,9 +150,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, // The CUDA provider is available, proceed with setting the options OrtCUDAProviderOptions options; - if(provider_config != nullptr) { + if (provider_config != nullptr) { options.device_id = provider_config->device; - options.cudnn_conv_algo_search = + options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch(provider_config->cuda_config .cudnn_conv_algo_search); } else { @@ -220,11 +220,11 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) { return GetSessionOptionsImpl(config.num_threads, - config.provider_config.provider,&config.provider_config); + config.provider_config.provider, &config.provider_config); } Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { - return GetSessionOptionsImpl(config.num_threads,config.provider); + return GetSessionOptionsImpl(config.num_threads, config.provider); } Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) { From 22d73e593b1eb3ed09c77593bc26b35d9326583d Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Wed, 26 Jun 2024 04:13:41 +0000 Subject: [PATCH 11/34] device_id fix Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/csrc/session.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 395fe2ec3..4b3f9fd26 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -77,6 +77,7 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, const char *op_values; }; + auto device_id = std::to_string(provider_config->device); auto trt_max_workspace_size = std::to_string(trt_config.trt_max_workspace_size); auto trt_max_partition_iterations = @@ -95,8 +96,7 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, std::to_string(trt_config.trt_dump_subgraphs); std::vector trt_options = { - {"device_id", - std::to_string(provider_config->device).c_str()}, + {"device_id", device_id.c_str()}, {"trt_max_workspace_size", trt_max_workspace_size.c_str()}, {"trt_max_partition_iterations", trt_max_partition_iterations.c_str()}, {"trt_min_subgraph_size", trt_min_subgraph_size.c_str()}, From e63dc19f49724910689c8f2c0e30047ccef752b7 Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Wed, 26 Jun 2024 05:42:31 +0000 Subject: [PATCH 12/34] tidy-clang and lint Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/csrc/provider-config.cc | 2 +- sherpa-onnx/csrc/provider-config.h | 2 +- sherpa-onnx/csrc/session.cc | 4 ++-- sherpa-onnx/python/csrc/provider-config.cc | 20 +++++++++++--------- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index acb862f45..53742231b 100644 --- 
a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -135,7 +135,7 @@ void ProviderConfig::Register(ParseOptions *po) { bool ProviderConfig::Validate() const { if (device < 0) { - SHERPA_ONNX_LOGE("device: '%d' is invalid.",device); + SHERPA_ONNX_LOGE("device: '%d' is invalid.", device); return false; } return true; diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h index ef887fff1..00098b3c8 100644 --- a/sherpa-onnx/csrc/provider-config.h +++ b/sherpa-onnx/csrc/provider-config.h @@ -17,7 +17,7 @@ struct CudaConfig { uint32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; CudaConfig() = default; - CudaConfig(uint32_t cudnn_conv_algo_search) + explicit CudaConfig(uint32_t cudnn_conv_algo_search) : cudnn_conv_algo_search(cudnn_conv_algo_search) {} void Register(ParseOptions *po); diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 4b3f9fd26..b6fdaaa84 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -34,7 +34,7 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) { static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, const std::string &provider_str, const ProviderConfig *provider_config = nullptr) { - Provider p = StringToProvider(std::move(provider_str)); + Provider p = StringToProvider(provider_str); Ort::SessionOptions sess_opts; sess_opts.SetIntraOpNumThreads(num_threads); @@ -86,7 +86,7 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, std::to_string(trt_config.trt_min_subgraph_size); auto trt_fp16_enable = std::to_string(trt_config.trt_fp16_enable); - auto trt_detailed_build_log = + auto trt_detailed_build_log = std::to_string(trt_config.trt_detailed_build_log); auto trt_engine_cache_enable = std::to_string(trt_config.trt_engine_cache_enable); diff --git a/sherpa-onnx/python/csrc/provider-config.cc b/sherpa-onnx/python/csrc/provider-config.cc index 45d79f2b6..7d3c7483a 100644 --- a/sherpa-onnx/python/csrc/provider-config.cc +++ b/sherpa-onnx/python/csrc/provider-config.cc @@ -3,17 +3,17 @@ // Copyright (c) 2024 Uniphore (Author: Manickavela) -#include "sherpa-onnx/csrc/provider-config.h" +#include "sherpa-onnx/python/csrc/provider-config.h" #include -#include "sherpa-onnx/python/csrc/provider-config.h" +#include "sherpa-onnx/csrc/provider-config.h" namespace sherpa_onnx { -void PybindCudaConfig(py::module *m) { - using PyClass = PybindCudaConfig; - py::class_(*m, "PybindCudaConfig") +static void PybindCudaConfig(py::module *m) { + using PyClass = CudaConfig; + py::class_(*m, "CudaConfig") .def(py::init(), py::arg("cudnn_conv_algo_search") = 1) .def_readwrite("cudnn_conv_algo_search", @@ -22,9 +22,9 @@ void PybindCudaConfig(py::module *m) { .def("validate", &PyClass::Validate); } -void PybindTensorrtConfig(py::module *m) { - using PyClass = PybindTensorrtConfig; - py::class_(*m, "PybindTensorrtConfig") +static void PybindTensorrtConfig(py::module *m) { + using PyClass = TensorrtConfig; + py::class_(*m, "TensorrtConfig") .def(py::init(*m, "ProviderConfig") .def(py::init Date: Wed, 26 Jun 2024 11:16:40 +0530 Subject: [PATCH 13/34] Update sherpa-onnx/python/csrc/provider-config.cc Co-authored-by: Fangjun Kuang --- sherpa-onnx/python/csrc/provider-config.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sherpa-onnx/python/csrc/provider-config.cc b/sherpa-onnx/python/csrc/provider-config.cc index 7d3c7483a..0a84b3526 100644 --- a/sherpa-onnx/python/csrc/provider-config.cc +++ 
b/sherpa-onnx/python/csrc/provider-config.cc @@ -1,4 +1,4 @@ -// sherpa-onnx/python/csrc/provider-config.h +// sherpa-onnx/python/csrc/provider-config.cc // // Copyright (c) 2024 Uniphore (Author: Manickavela) From 9c6be4f43137da977207164ddada2aa70bce5f9e Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Wed, 26 Jun 2024 05:50:05 +0000 Subject: [PATCH 14/34] updating python api cmake Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/python/csrc/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 6d61d11dd..428d8e001 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -30,6 +30,7 @@ set(srcs online-transducer-model-config.cc online-wenet-ctc-model-config.cc online-zipformer2-ctc-model-config.cc + provider-config.cc sherpa-onnx.cc silero-vad-model-config.cc speaker-embedding-extractor.cc From 0022ae450af89024e996492485c03617d2d9a008 Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Wed, 26 Jun 2024 11:21:46 +0000 Subject: [PATCH 15/34] pybind fix Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/csrc/provider-config.h | 12 ++-- .../python/csrc/online-model-config.cc | 11 +-- sherpa-onnx/python/csrc/provider-config.cc | 68 ++----------------- sherpa-onnx/python/csrc/provider-config.h | 2 +- 4 files changed, 17 insertions(+), 76 deletions(-) diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h index 00098b3c8..f230d7887 100644 --- a/sherpa-onnx/csrc/provider-config.h +++ b/sherpa-onnx/csrc/provider-config.h @@ -67,17 +67,17 @@ struct TensorrtConfig { }; struct ProviderConfig { + TensorrtConfig trt_config; + CudaConfig cuda_config; + std::string provider = "cpu"; uint32_t device = 0; // device only used for cuda and trt - std::string provider = "cpu"; - CudaConfig cuda_config; - TensorrtConfig trt_config; ProviderConfig() = default; - ProviderConfig(uint32_t device, - const std::string &provider, + ProviderConfig(const TensorrtConfig &trt_config, const CudaConfig &cuda_config, - const TensorrtConfig &trt_config) + const std::string &provider, + uint32_t device) : device(device), provider(provider), cuda_config(cuda_config), trt_config(trt_config) {} diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc index 982601805..5ce9bb92f 100644 --- a/sherpa-onnx/python/csrc/online-model-config.cc +++ b/sherpa-onnx/python/csrc/online-model-config.cc @@ -8,13 +8,13 @@ #include #include "sherpa-onnx/csrc/online-model-config.h" -#include "sherpa-onnx/csrc/provider-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h" #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" #include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h" #include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h" +#include "sherpa-onnx/python/csrc/provider-config.h" namespace sherpa_onnx { @@ -36,7 +36,7 @@ void PybindOnlineModelConfig(py::module *m) { const ProviderConfig &, const std::string &, int32_t, int32_t, bool, const std::string &, const std::string &, - const std::string &, const std::string &>(), + const std::string &>(), py::arg("transducer") = OnlineTransducerModelConfig(), py::arg("paraformer") = OnlineParaformerModelConfig(), py::arg("wenet_ctc") = 
OnlineWenetCtcModelConfig(), @@ -44,16 +44,17 @@ void PybindOnlineModelConfig(py::module *m) { py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), py::arg("provider_config") = ProviderConfig(), py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0, - py::arg("debug") = false, py::arg("provider") = "cpu", - py::arg("model_type") = "", py::arg("modeling_unit") = "", - py::arg("bpe_vocab") = "") + py::arg("debug") = false, py::arg("model_type") = "", + py::arg("modeling_unit") = "", py::arg("bpe_vocab") = "") .def_readwrite("transducer", &PyClass::transducer) .def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) .def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc) .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) + .def_readwrite("provider_config", &PyClass::provider_config) .def_readwrite("tokens", &PyClass::tokens) .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("warm_up", &PyClass::warm_up) .def_readwrite("debug", &PyClass::debug) .def_readwrite("model_type", &PyClass::model_type) .def_readwrite("modeling_unit", &PyClass::modeling_unit) diff --git a/sherpa-onnx/python/csrc/provider-config.cc b/sherpa-onnx/python/csrc/provider-config.cc index 0a84b3526..307686377 100644 --- a/sherpa-onnx/python/csrc/provider-config.cc +++ b/sherpa-onnx/python/csrc/provider-config.cc @@ -1,6 +1,6 @@ // sherpa-onnx/python/csrc/provider-config.cc // -// Copyright (c) 2024 Uniphore (Author: Manickavela) +// Copyright (c) 2024 Uniphore (Author: Manickavela A) #include "sherpa-onnx/python/csrc/provider-config.h" @@ -11,72 +11,12 @@ namespace sherpa_onnx { -static void PybindCudaConfig(py::module *m) { - using PyClass = CudaConfig; - py::class_(*m, "CudaConfig") - .def(py::init(), - py::arg("cudnn_conv_algo_search") = 1) - .def_readwrite("cudnn_conv_algo_search", - &PyClass::cudnn_conv_algo_search) - .def("__str__", &PyClass::ToString) - .def("validate", &PyClass::Validate); -} - -static void PybindTensorrtConfig(py::module *m) { - using PyClass = TensorrtConfig; - py::class_(*m, "TensorrtConfig") - .def(py::init(), - py::arg("trt_max_workspace_size") = 2147483648, - py::arg("trt_max_partition_iterations") = 10, - py::arg("trt_min_subgraph_size") = 5, - py::arg("trt_fp16_enable") = 1, - py::arg("trt_detailed_build_log") = 0, - py::arg("trt_engine_cache_enable") = 1, - py::arg("trt_timing_cache_enable") = 1, - py::arg("trt_engine_cache_path") = ".", - py::arg("trt_timing_cache_path") = ".", - py::arg("trt_dump_subgraphs") = 0) - .def_readwrite("trt_max_workspace_size", - &PyClass::trt_max_workspace_size) - .def_readwrite("trt_max_partition_iterations", - &PyClass::trt_max_partition_iterations) - .def_readwrite("trt_min_subgraph_size", - &PyClass::trt_min_subgraph_size) - .def_readwrite("trt_fp16_enable", - &PyClass::trt_fp16_enable) - .def_readwrite("trt_detailed_build_log", - &PyClass::trt_detailed_build_log) - .def_readwrite("trt_engine_cache_enable", - &PyClass::trt_engine_cache_enable) - .def_readwrite("trt_timing_cache_enable", - &PyClass::trt_timing_cache_enable) - .def_readwrite("trt_engine_cache_path", - &PyClass::trt_engine_cache_path) - .def_readwrite("trt_timing_cache_path", - &PyClass::trt_timing_cache_path) - .def_readwrite("trt_dump_subgraphs", - &PyClass::trt_dump_subgraphs) - .def("__str__", &PyClass::ToString) - .def("validate", &PyClass::Validate); -} - void PybindProviderConfig(py::module *m) { - PybindCudaConfig(m); - PybindTensorrtConfig(m); - using PyClass = ProviderConfig; py::class_(*m, "ProviderConfig") - 
.def(py::init(), + .def(py::init(), py::arg("trt_config") = TensorrtConfig(), py::arg("cuda_config") = CudaConfig(), py::arg("provider") = "cpu", diff --git a/sherpa-onnx/python/csrc/provider-config.h b/sherpa-onnx/python/csrc/provider-config.h index 75db543a5..76377dde2 100644 --- a/sherpa-onnx/python/csrc/provider-config.h +++ b/sherpa-onnx/python/csrc/provider-config.h @@ -1,6 +1,6 @@ // sherpa-onnx/python/csrc/provider-config.h // -// Copyright (c) 2024 Uniphore (Author: Manickavela) +// Copyright (c) 2024 Uniphore (Author: Manickavela A) #ifndef SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_ #define SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_ From b4f1985f4045abf39e8d7bb1a075bd9bacce309e Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Wed, 26 Jun 2024 11:50:30 +0000 Subject: [PATCH 16/34] fix int32_t Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/csrc/provider-config.h | 20 ++++++++++---------- sherpa-onnx/python/csrc/provider-config.cc | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h index f230d7887..a9d0ed9fe 100644 --- a/sherpa-onnx/csrc/provider-config.h +++ b/sherpa-onnx/csrc/provider-config.h @@ -14,10 +14,10 @@ namespace sherpa_onnx { struct CudaConfig { - uint32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; + int32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; CudaConfig() = default; - explicit CudaConfig(uint32_t cudnn_conv_algo_search) + explicit CudaConfig(int32_t cudnn_conv_algo_search) : cudnn_conv_algo_search(cudnn_conv_algo_search) {} void Register(ParseOptions *po); @@ -27,9 +27,9 @@ struct CudaConfig { }; struct TensorrtConfig { - uint32_t trt_max_workspace_size = 2147483648; - uint32_t trt_max_partition_iterations = 10; - uint32_t trt_min_subgraph_size = 5; + int32_t trt_max_workspace_size = 2147483648; + int32_t trt_max_partition_iterations = 10; + int32_t trt_min_subgraph_size = 5; bool trt_fp16_enable = 1; bool trt_detailed_build_log = 0; bool trt_engine_cache_enable = 1; @@ -39,9 +39,9 @@ struct TensorrtConfig { bool trt_dump_subgraphs = 0; TensorrtConfig() = default; - TensorrtConfig(uint32_t trt_max_workspace_size, - uint32_t trt_max_partition_iterations, - uint32_t trt_min_subgraph_size, + TensorrtConfig(int32_t trt_max_workspace_size, + int32_t trt_max_partition_iterations, + int32_t trt_min_subgraph_size, bool trt_fp16_enable, bool trt_detailed_build_log, bool trt_engine_cache_enable, @@ -70,14 +70,14 @@ struct ProviderConfig { TensorrtConfig trt_config; CudaConfig cuda_config; std::string provider = "cpu"; - uint32_t device = 0; + int32_t device = 0; // device only used for cuda and trt ProviderConfig() = default; ProviderConfig(const TensorrtConfig &trt_config, const CudaConfig &cuda_config, const std::string &provider, - uint32_t device) + int32_t device) : device(device), provider(provider), cuda_config(cuda_config), trt_config(trt_config) {} diff --git a/sherpa-onnx/python/csrc/provider-config.cc b/sherpa-onnx/python/csrc/provider-config.cc index 307686377..8593b4117 100644 --- a/sherpa-onnx/python/csrc/provider-config.cc +++ b/sherpa-onnx/python/csrc/provider-config.cc @@ -16,7 +16,7 @@ void PybindProviderConfig(py::module *m) { py::class_(*m, "ProviderConfig") .def(py::init(), + int32_t>(), py::arg("trt_config") = TensorrtConfig(), py::arg("cuda_config") = CudaConfig(), py::arg("provider") = "cpu", From 40119487eeef8c4a82d640051982b0533a86c6c6 Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: 
Wed, 26 Jun 2024 12:19:06 +0000
Subject: [PATCH 17/34] uint32_t back and jni fix

Signed-off-by: manickavela1998@gmail.com
---
 sherpa-onnx/csrc/provider-config.h   | 12 ++++++------
 sherpa-onnx/jni/online-recognizer.cc |  2 +-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h
index a9d0ed9fe..77037dd04 100644
--- a/sherpa-onnx/csrc/provider-config.h
+++ b/sherpa-onnx/csrc/provider-config.h
@@ -27,9 +27,9 @@ struct CudaConfig {
 };
 
 struct TensorrtConfig {
-  int32_t trt_max_workspace_size = 2147483648;
-  int32_t trt_max_partition_iterations = 10;
-  int32_t trt_min_subgraph_size = 5;
+  uint32_t trt_max_workspace_size = 2147483648;
+  uint32_t trt_max_partition_iterations = 10;
+  uint32_t trt_min_subgraph_size = 5;
   bool trt_fp16_enable = 1;
   bool trt_detailed_build_log = 0;
   bool trt_engine_cache_enable = 1;
@@ -39,9 +39,9 @@ struct TensorrtConfig {
   bool trt_dump_subgraphs = 0;
 
   TensorrtConfig() = default;
-  TensorrtConfig(int32_t trt_max_workspace_size,
-                 int32_t trt_max_partition_iterations,
-                 int32_t trt_min_subgraph_size,
+  TensorrtConfig(uint32_t trt_max_workspace_size,
+                 uint32_t trt_max_partition_iterations,
+                 uint32_t trt_min_subgraph_size,
                  bool trt_fp16_enable,
                  bool trt_detailed_build_log,
                  bool trt_engine_cache_enable,
diff --git a/sherpa-onnx/jni/online-recognizer.cc b/sherpa-onnx/jni/online-recognizer.cc
index d8acd0fed..643b037b3 100644
--- a/sherpa-onnx/jni/online-recognizer.cc
+++ b/sherpa-onnx/jni/online-recognizer.cc
@@ -198,7 +198,7 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
   fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
   s = (jstring)env->GetObjectField(model_config, fid);
   p = env->GetStringUTFChars(s, nullptr);
-  ans.model_config.provider = p;
+  ans.model_config.provider_config.provider = p;
   env->ReleaseStringUTFChars(s, p);
 
   fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");

From 1cfd4690f3002694ddd3dbe5a0a72cc9cde05d51 Mon Sep 17 00:00:00 2001
From: "manickavela1998@gmail.com"
Date: Wed, 26 Jun 2024 12:41:36 +0000
Subject: [PATCH 18/34] JNI and python fix

Signed-off-by: manickavela1998@gmail.com
---
 sherpa-onnx/jni/keyword-spotter.cc     | 2 +-
 sherpa-onnx/jni/offline-recognizer.cc  | 2 +-
 sherpa-onnx/python/csrc/sherpa-onnx.cc | 2 ++
 3 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/sherpa-onnx/jni/keyword-spotter.cc b/sherpa-onnx/jni/keyword-spotter.cc
index 7a05b4855..ca0c229c2 100644
--- a/sherpa-onnx/jni/keyword-spotter.cc
+++ b/sherpa-onnx/jni/keyword-spotter.cc
@@ -94,7 +94,7 @@ static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
   fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
   s = (jstring)env->GetObjectField(model_config, fid);
   p = env->GetStringUTFChars(s, nullptr);
-  ans.model_config.provider = p;
+  ans.model_config.provider_config.provider = p;
   env->ReleaseStringUTFChars(s, p);
 
   fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
diff --git a/sherpa-onnx/jni/offline-recognizer.cc b/sherpa-onnx/jni/offline-recognizer.cc
index 070d46f08..3c5ce5fb1 100644
--- a/sherpa-onnx/jni/offline-recognizer.cc
+++ b/sherpa-onnx/jni/offline-recognizer.cc
@@ -79,7 +79,7 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
   fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
   s = (jstring)env->GetObjectField(model_config, fid);
   p = env->GetStringUTFChars(s, nullptr);
-  ans.model_config.provider = p;
+
ans.model_config.provider_config.provider = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 242d85974..8dd0a38f8 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -29,6 +29,7 @@ #include "sherpa-onnx/python/csrc/vad-model.h" #include "sherpa-onnx/python/csrc/voice-activity-detector.h" #include "sherpa-onnx/python/csrc/wave-writer.h" +#include "sherpa-onnx/python/csrc/provider-config.h" #if SHERPA_ONNX_ENABLE_TTS == 1 #include "sherpa-onnx/python/csrc/offline-tts.h" @@ -44,6 +45,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindOfflinePunctuation(&m); PybindFeatures(&m); + PybindProviderConfig(&m); PybindOnlineCtcFstDecoderConfig(&m); PybindOnlineModelConfig(&m); PybindOnlineLMConfig(&m); From 2ef7a7c84f27344461b8c1e950b44dd6d2d5110d Mon Sep 17 00:00:00 2001 From: Manix <50542248+manickavela29@users.noreply.github.com> Date: Wed, 26 Jun 2024 21:07:54 +0530 Subject: [PATCH 19/34] Update sherpa-onnx/csrc/provider-config.cc Co-authored-by: Fangjun Kuang --- sherpa-onnx/csrc/provider-config.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index 53742231b..3911be8a2 100644 --- a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -18,7 +18,7 @@ void CudaConfig::Register(ParseOptions *po) { bool CudaConfig::Validate() const { if (cudnn_conv_algo_search < 1 && cudnn_conv_algo_search > 3) { - SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not valid option." + SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not a valid option." "Options : [1,3]. 
Check OnnxRT docs", cudnn_conv_algo_search); return false; From fa001b7aad8a4a8178297d09b4f119ccedc2930a Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Wed, 26 Jun 2024 18:07:13 +0000 Subject: [PATCH 20/34] removing from offline Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/jni/offline-recognizer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sherpa-onnx/jni/offline-recognizer.cc b/sherpa-onnx/jni/offline-recognizer.cc index 3c5ce5fb1..070d46f08 100644 --- a/sherpa-onnx/jni/offline-recognizer.cc +++ b/sherpa-onnx/jni/offline-recognizer.cc @@ -79,7 +79,7 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(model_config, fid); p = env->GetStringUTFChars(s, nullptr); - ans.model_config.provider_config.provider = p; + ans.model_config.provider = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); From be3aa27e26edf6fb733e9ee01a644e7594bac264 Mon Sep 17 00:00:00 2001 From: Manix <50542248+manickavela29@users.noreply.github.com> Date: Thu, 27 Jun 2024 07:30:50 +0530 Subject: [PATCH 21/34] Apply suggestions from code review Co-authored-by: Fangjun Kuang --- sherpa-onnx/csrc/provider-config.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index 3911be8a2..b07151d8e 100644 --- a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -17,7 +17,7 @@ void CudaConfig::Register(ParseOptions *po) { } bool CudaConfig::Validate() const { - if (cudnn_conv_algo_search < 1 && cudnn_conv_algo_search > 3) { + if (cudnn_conv_algo_search < 1 || cudnn_conv_algo_search > 3) { SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not a valid option." "Options : [1,3]. 
Check OnnxRT docs", cudnn_conv_algo_search); @@ -30,7 +30,7 @@ std::string CudaConfig::ToString() const { std::ostringstream os; os << "CudaConfig("; - os << "cudnn_conv_algo_search=\"" << cudnn_conv_algo_search << ")"; + os << "cudnn_conv_algo_search=" << cudnn_conv_algo_search << ")"; return os.str(); } @@ -60,7 +60,7 @@ void TensorrtConfig::Register(ParseOptions *po) { bool TensorrtConfig::Validate() const { if (trt_max_workspace_size > 0) { - SHERPA_ONNX_LOGE("trt_max_workspace_size: '%d' is not valid.", + SHERPA_ONNX_LOGE("trt_max_workspace_size: '%u' is not valid.", trt_max_workspace_size); return false; } From e2dc60cdde9977db08dea90e9f41c747502f7bf5 Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Thu, 27 Jun 2024 02:30:45 +0000 Subject: [PATCH 22/34] bug fix Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/csrc/provider-config.cc | 35 ++++++++++++++++------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index b07151d8e..0a56ddb87 100644 --- a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -37,40 +37,40 @@ std::string CudaConfig::ToString() const { void TensorrtConfig::Register(ParseOptions *po) { po->Register("trt-max-workspace-size", &trt_max_workspace_size, - ""); + "Set TensorRT EP GPU memory usage limit."); po->Register("trt-max-partition-iterations", &trt_max_partition_iterations, - ""); - po->Register("trt-min-subgraph-size ", &trt_min_subgraph_size, - ""); + "Limit partitioning iterations for model conversion."); + po->Register("trt-min-subgraph-size", &trt_min_subgraph_size, + "Set minimum size for subgraphs in partitioning."); po->Register("trt-fp16-enable", &trt_fp16_enable, - "true to enable fp16"); + "Enable FP16 precision for faster performance."); po->Register("trt-detailed-build-log", &trt_detailed_build_log, - "true to print TensorRT build logs"); + "Enable detailed logging of build steps."); po->Register("trt-engine-cache-enable", &trt_engine_cache_enable, - "true to enable engine caching"); + "Enable caching of TensorRT engines."); po->Register("trt-engine-cache-path", &trt_engine_cache_path, - ""); + "Set path to store cached TensorRT engines."); po->Register("trt-timing-cache-enable", &trt_timing_cache_enable, - "true to enable timing cache"); + "Enable use of timing cache to speed up builds."); po->Register("trt-timing-cache-path", &trt_timing_cache_path, - ""); + "Set path for storing timing cache."); po->Register("trt-dump-subgraphs", &trt_dump_subgraphs, - "true to dump subgraphs"); + "Dump optimized subgraphs for debugging."); } bool TensorrtConfig::Validate() const { - if (trt_max_workspace_size > 0) { + if (trt_max_workspace_size < 0) { SHERPA_ONNX_LOGE("trt_max_workspace_size: '%u' is not valid.", trt_max_workspace_size); return false; } - if (trt_max_partition_iterations > 0) { - SHERPA_ONNX_LOGE("trt_max_partition_iterations: '%d' is not valid.", + if (trt_max_partition_iterations < 0) { + SHERPA_ONNX_LOGE("trt_max_partition_iterations: '%u' is not valid.", trt_max_partition_iterations); return false; } - if (trt_min_subgraph_size > 0) { - SHERPA_ONNX_LOGE("trt_min_subgraph_size: '%d' is not valid.", + if (trt_min_subgraph_size < 0) { + SHERPA_ONNX_LOGE("trt_min_subgraph_size: '%u' is not valid.", trt_min_subgraph_size); return false; } @@ -128,6 +128,9 @@ std::string TensorrtConfig::ToString() const { } void ProviderConfig::Register(ParseOptions *po) { + cuda_config.Register(po); + 
trt_config.Register(po); + po->Register("device", &device, "GPU device index for CUDA and Trt EP"); po->Register("provider", &provider, "Specify a provider to use: cpu, cuda, coreml"); From 4f5a58f35c1051dc0c079dc4be1368dc00c30989 Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Fri, 28 Jun 2024 08:54:30 +0000 Subject: [PATCH 23/34] handling uint and attempting python fix --- sherpa-onnx/csrc/provider-config.cc | 48 +++++++------------------- sherpa-onnx/csrc/provider-config.h | 27 +++++++-------- sherpa-onnx/python/csrc/sherpa-onnx.cc | 2 -- 3 files changed, 25 insertions(+), 52 deletions(-) diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index 0a56ddb87..2d4109040 100644 --- a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -48,10 +48,10 @@ void TensorrtConfig::Register(ParseOptions *po) { "Enable detailed logging of build steps."); po->Register("trt-engine-cache-enable", &trt_engine_cache_enable, "Enable caching of TensorRT engines."); - po->Register("trt-engine-cache-path", &trt_engine_cache_path, - "Set path to store cached TensorRT engines."); po->Register("trt-timing-cache-enable", &trt_timing_cache_enable, "Enable use of timing cache to speed up builds."); + po->Register("trt-engine-cache-path", &trt_engine_cache_path, + "Set path to store cached TensorRT engines."); po->Register("trt-timing-cache-path", &trt_timing_cache_path, "Set path for storing timing cache."); po->Register("trt-dump-subgraphs", &trt_dump_subgraphs, @@ -60,44 +60,20 @@ void TensorrtConfig::Register(ParseOptions *po) { bool TensorrtConfig::Validate() const { if (trt_max_workspace_size < 0) { - SHERPA_ONNX_LOGE("trt_max_workspace_size: '%u' is not valid.", + SHERPA_ONNX_LOGE("trt_max_workspace_size: %d is not valid.", trt_max_workspace_size); return false; } if (trt_max_partition_iterations < 0) { - SHERPA_ONNX_LOGE("trt_max_partition_iterations: '%u' is not valid.", + SHERPA_ONNX_LOGE("trt_max_partition_iterations: %d is not valid.", trt_max_partition_iterations); return false; } if (trt_min_subgraph_size < 0) { - SHERPA_ONNX_LOGE("trt_min_subgraph_size: '%u' is not valid.", + SHERPA_ONNX_LOGE("trt_min_subgraph_size: %d is not valid.", trt_min_subgraph_size); return false; } - if (trt_fp16_enable != true || trt_fp16_enable != false) { - SHERPA_ONNX_LOGE("trt_fp16_enable: '%d' is not valid.", trt_fp16_enable); - return false; - } - if (trt_detailed_build_log != true || trt_detailed_build_log != false) { - SHERPA_ONNX_LOGE("trt_detailed_build_log: '%d' is not valid.", - trt_detailed_build_log); - return false; - } - if (trt_engine_cache_enable != true || trt_engine_cache_enable != false) { - SHERPA_ONNX_LOGE("trt_engine_cache_enable: '%d' is not valid.", - trt_engine_cache_enable); - return false; - } - if (trt_timing_cache_enable != true || trt_timing_cache_enable != false) { - SHERPA_ONNX_LOGE("trt_timing_cache_enable: '%d' is not valid.", - trt_timing_cache_enable); - return false; - } - if (trt_dump_subgraphs != true || trt_dump_subgraphs != false) { - SHERPA_ONNX_LOGE("trt_dump_subgraphs: '%d' is not valid.", - trt_dump_subgraphs); - return false; - } return true; } @@ -106,10 +82,10 @@ std::string TensorrtConfig::ToString() const { std::ostringstream os; os << "TensorrtConfig("; - os << "trt_max_workspace_size=\"" << trt_max_workspace_size << "\", "; - os << "trt_max_partition_iterations=\"" - << trt_max_partition_iterations << "\", "; - os << "trt_min_subgraph_size=\"" << trt_min_subgraph_size << "\", "; + os << 
"trt_max_workspace_size=" << trt_max_workspace_size << ", "; + os << "trt_max_partition_iterations=" + << trt_max_partition_iterations << ", "; + os << "trt_min_subgraph_size=" << trt_min_subgraph_size << ", "; os << "trt_fp16_enable=\"" << (trt_fp16_enable? "True" : "False") << "\", "; os << "trt_detailed_build_log=\"" @@ -148,10 +124,10 @@ std::string ProviderConfig::ToString() const { std::ostringstream os; os << "ProviderConfig("; - os << "device=\"" << device << "\", "; + os << "device=" << device << ", "; os << "provider=\"" << provider << "\", "; - os << "cuda_config=\"" << cuda_config.ToString() << "\", "; - os << "trt_config=\"" << trt_config.ToString() << ")"; + os << "cuda_config=" << cuda_config.ToString() << ", "; + os << "trt_config=" << trt_config.ToString() << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h index 77037dd04..cb5f70664 100644 --- a/sherpa-onnx/csrc/provider-config.h +++ b/sherpa-onnx/csrc/provider-config.h @@ -27,21 +27,21 @@ struct CudaConfig { }; struct TensorrtConfig { - uint32_t trt_max_workspace_size = 2147483648; - uint32_t trt_max_partition_iterations = 10; - uint32_t trt_min_subgraph_size = 5; - bool trt_fp16_enable = 1; - bool trt_detailed_build_log = 0; - bool trt_engine_cache_enable = 1; - bool trt_timing_cache_enable = 1; + int32_t trt_max_workspace_size = 2147483648; + int32_t trt_max_partition_iterations = 10; + int32_t trt_min_subgraph_size = 5; + bool trt_fp16_enable = true; + bool trt_detailed_build_log = false; + bool trt_engine_cache_enable = true; + bool trt_timing_cache_enable = true; std::string trt_engine_cache_path = "."; std::string trt_timing_cache_path = "."; - bool trt_dump_subgraphs = 0; + bool trt_dump_subgraphs = false; TensorrtConfig() = default; - TensorrtConfig(uint32_t trt_max_workspace_size, - uint32_t trt_max_partition_iterations, - uint32_t trt_min_subgraph_size, + TensorrtConfig(int32_t trt_max_workspace_size, + int32_t trt_max_partition_iterations, + int32_t trt_min_subgraph_size, bool trt_fp16_enable, bool trt_detailed_build_log, bool trt_engine_cache_enable, @@ -78,9 +78,8 @@ struct ProviderConfig { const CudaConfig &cuda_config, const std::string &provider, int32_t device) - : device(device), provider(provider), - cuda_config(cuda_config), - trt_config(trt_config) {} + : trt_config(trt_config), cuda_config(cuda_config), + provider(provider), device(device) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 8dd0a38f8..242d85974 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -29,7 +29,6 @@ #include "sherpa-onnx/python/csrc/vad-model.h" #include "sherpa-onnx/python/csrc/voice-activity-detector.h" #include "sherpa-onnx/python/csrc/wave-writer.h" -#include "sherpa-onnx/python/csrc/provider-config.h" #if SHERPA_ONNX_ENABLE_TTS == 1 #include "sherpa-onnx/python/csrc/offline-tts.h" @@ -45,7 +44,6 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindOfflinePunctuation(&m); PybindFeatures(&m); - PybindProviderConfig(&m); PybindOnlineCtcFstDecoderConfig(&m); PybindOnlineModelConfig(&m); PybindOnlineLMConfig(&m); From 389cbf515d69b129f50c04215d91d055cee07c16 Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Fri, 28 Jun 2024 12:48:59 +0000 Subject: [PATCH 24/34] pybind attempt Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/python/csrc/provider-config.cc | 49 ++++++++++++++++++++-- 1 file changed, 
46 insertions(+), 3 deletions(-) diff --git a/sherpa-onnx/python/csrc/provider-config.cc b/sherpa-onnx/python/csrc/provider-config.cc index 8593b4117..8f5b71479 100644 --- a/sherpa-onnx/python/csrc/provider-config.cc +++ b/sherpa-onnx/python/csrc/provider-config.cc @@ -11,21 +11,64 @@ namespace sherpa_onnx { +static void PybindCudaConfig(py::module *m) { + using PyClass = CudaConfig; + py::class_(*m, "CudaConfig") + .def(py::init(), + py::arg("cudnn_conv_algo_search") = 1) + .def_readwrite("cudnn_conv_algo_search", &PyClass::cudnn_conv_algo_search) + .def("__str__", &PyClass::ToString); +} + +static void PybindTensorrtConfig(py::module *m) { + using PyClass = TensorrtConfig; + py::class_(*m, "TensorrtConfig") + .def(py::init(), + py::arg("trt_max_workspace_size") = 2147483648, + py::arg("trt_max_partition_iterations") = 10, + py::arg("trt_min_subgraph_size") = 5, + py::arg("trt_fp16_enable") = true, + py::arg("trt_detailed_build_log") = false, + py::arg("trt_engine_cache_enable") = true, + py::arg("trt_timing_cache_enable") = true, + py::arg("trt_engine_cache_path") = ".", + py::arg("trt_timing_cache_path") = ".", + py::arg("trt_dump_subgraphs") = false) + .def_readwrite("trt_max_workspace_size", &PyClass::trt_max_workspace_size) + .def_readwrite("trt_max_partition_iterations", &PyClass::trt_max_partition_iterations) + .def_readwrite("trt_min_subgraph_size", &PyClass::trt_min_subgraph_size) + .def_readwrite("trt_fp16_enable", &PyClass::trt_fp16_enable) + .def_readwrite("trt_detailed_build_log", &PyClass::trt_detailed_build_log) + .def_readwrite("trt_engine_cache_enable", &PyClass::trt_engine_cache_enable) + .def_readwrite("trt_timing_cache_enable", &PyClass::trt_timing_cache_enable) + .def_readwrite("trt_engine_cache_path", &PyClass::trt_engine_cache_path) + .def_readwrite("trt_timing_cache_path", &PyClass::trt_timing_cache_path) + .def_readwrite("trt_dump_subgraphs", &PyClass::trt_dump_subgraphs) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + void PybindProviderConfig(py::module *m) { + PybindCudaConfig(m); + PybindTensorrtConfig(m); + using PyClass = ProviderConfig; py::class_(*m, "ProviderConfig") .def(py::init(), py::arg("trt_config") = TensorrtConfig(), py::arg("cuda_config") = CudaConfig(), py::arg("provider") = "cpu", py::arg("device") = 0) - .def_readwrite("cuda_config", &PyClass::cuda_config) .def_readwrite("trt_config", &PyClass::trt_config) + .def_readwrite("cuda_config", &PyClass::cuda_config) .def_readwrite("provider", &PyClass::provider) .def_readwrite("device", &PyClass::device) .def("__str__", &PyClass::ToString) - .def("validate", &PyClass::Validate);} + .def("validate", &PyClass::Validate); +} } // namespace sherpa_onnx From bbfea38160774e39f8f665f91e89df7f22079601 Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Sat, 29 Jun 2024 02:23:57 +0000 Subject: [PATCH 25/34] python-dump Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/python/csrc/CMakeLists.txt | 2 + sherpa-onnx/python/csrc/cuda-config.cc | 23 ++++ sherpa-onnx/python/csrc/cuda-config.h | 16 +++ .../python/csrc/online-model-config.cc | 1 - sherpa-onnx/python/csrc/provider-config.cc | 55 ++------ sherpa-onnx/python/csrc/sherpa-onnx.cc | 5 +- sherpa-onnx/python/csrc/sherpa-onnx.h | 2 + sherpa-onnx/python/csrc/tensorrt-config.cc | 117 ++++++++++++++++++ sherpa-onnx/python/csrc/tensorrt-config.h | 16 +++ 9 files changed, 188 insertions(+), 49 deletions(-) create mode 100644 sherpa-onnx/python/csrc/cuda-config.cc create mode 100644 
sherpa-onnx/python/csrc/cuda-config.h
 create mode 100644 sherpa-onnx/python/csrc/tensorrt-config.cc
 create mode 100644 sherpa-onnx/python/csrc/tensorrt-config.h

diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt
index 428d8e001..5e74a1721 100644
--- a/sherpa-onnx/python/csrc/CMakeLists.txt
+++ b/sherpa-onnx/python/csrc/CMakeLists.txt
@@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR})
 set(srcs
   audio-tagging.cc
   circular-buffer.cc
+  cuda-config.cc
   display.cc
   endpoint.cc
   features.cc
@@ -36,6 +37,7 @@ set(srcs
   speaker-embedding-extractor.cc
   speaker-embedding-manager.cc
   spoken-language-identification.cc
+  tensorrt-config.cc
   vad-model-config.cc
   vad-model.cc
   voice-activity-detector.cc
diff --git a/sherpa-onnx/python/csrc/cuda-config.cc b/sherpa-onnx/python/csrc/cuda-config.cc
new file mode 100644
index 000000000..4669783d8
--- /dev/null
+++ b/sherpa-onnx/python/csrc/cuda-config.cc
@@ -0,0 +1,23 @@
+// sherpa-onnx/python/csrc/cuda-config.cc
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#include "sherpa-onnx/python/csrc/cuda-config.h"
+
+#include
+#include
+
+#include "sherpa-onnx/csrc/provider-config.h"
+
+namespace sherpa_onnx {
+
+void PybindCudaConfig(py::module *m) {
+  using PyClass = CudaConfig;
+  py::class_<PyClass>(*m, "CudaConfig")
+      .def(py::init<int32_t>(),
+           py::arg("cudnn_conv_algo_search") = 1)
+      .def_readwrite("cudnn_conv_algo_search", &PyClass::cudnn_conv_algo_search)
+      .def("__str__", &PyClass::ToString);
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/cuda-config.h b/sherpa-onnx/python/csrc/cuda-config.h
new file mode 100644
index 000000000..4af175800
--- /dev/null
+++ b/sherpa-onnx/python/csrc/cuda-config.h
@@ -0,0 +1,16 @@
+// sherpa-onnx/python/csrc/cuda-config.h
+//
+// Copyright (c) 2024 Uniphore (Author: Manickavela A)
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindCudaConfig(py::module *m);
+
+}
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_
diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc
index 5ce9bb92f..b008f424e 100644
--- a/sherpa-onnx/python/csrc/online-model-config.cc
+++ b/sherpa-onnx/python/csrc/online-model-config.cc
@@ -62,5 +62,4 @@ void PybindOnlineModelConfig(py::module *m) {
       .def("validate", &PyClass::Validate)
       .def("__str__", &PyClass::ToString);
 }
-
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/provider-config.cc b/sherpa-onnx/python/csrc/provider-config.cc
index 8f5b71479..d0967497d 100644
--- a/sherpa-onnx/python/csrc/provider-config.cc
+++ b/sherpa-onnx/python/csrc/provider-config.cc
@@ -8,59 +8,21 @@
 #include
 
 #include "sherpa-onnx/csrc/provider-config.h"
+// #include "sherpa-onnx/python/csrc/cuda-config.h"
+// #include "sherpa-onnx/python/csrc/tensorrt-config.h"
 
 namespace sherpa_onnx {
 
-static void PybindCudaConfig(py::module *m) {
-  using PyClass = CudaConfig;
-  py::class_(*m, "CudaConfig")
-      .def(py::init(),
-           py::arg("cudnn_conv_algo_search") = 1)
-      .def_readwrite("cudnn_conv_algo_search", &PyClass::cudnn_conv_algo_search)
-      .def("__str__", &PyClass::ToString);
-}
-
-static void PybindTensorrtConfig(py::module *m) {
-  using PyClass = TensorrtConfig;
-  py::class_(*m, "TensorrtConfig")
-      .def(py::init(),
-           py::arg("trt_max_workspace_size") = 2147483648,
-           py::arg("trt_max_partition_iterations") = 10,
-           py::arg("trt_min_subgraph_size") = 5,
-
py::arg("trt_fp16_enable") = true, - py::arg("trt_detailed_build_log") = false, - py::arg("trt_engine_cache_enable") = true, - py::arg("trt_timing_cache_enable") = true, - py::arg("trt_engine_cache_path") = ".", - py::arg("trt_timing_cache_path") = ".", - py::arg("trt_dump_subgraphs") = false) - .def_readwrite("trt_max_workspace_size", &PyClass::trt_max_workspace_size) - .def_readwrite("trt_max_partition_iterations", &PyClass::trt_max_partition_iterations) - .def_readwrite("trt_min_subgraph_size", &PyClass::trt_min_subgraph_size) - .def_readwrite("trt_fp16_enable", &PyClass::trt_fp16_enable) - .def_readwrite("trt_detailed_build_log", &PyClass::trt_detailed_build_log) - .def_readwrite("trt_engine_cache_enable", &PyClass::trt_engine_cache_enable) - .def_readwrite("trt_timing_cache_enable", &PyClass::trt_timing_cache_enable) - .def_readwrite("trt_engine_cache_path", &PyClass::trt_engine_cache_path) - .def_readwrite("trt_timing_cache_path", &PyClass::trt_timing_cache_path) - .def_readwrite("trt_dump_subgraphs", &PyClass::trt_dump_subgraphs) - .def("__str__", &PyClass::ToString) - .def("validate", &PyClass::Validate); -} - void PybindProviderConfig(py::module *m) { - PybindCudaConfig(m); - PybindTensorrtConfig(m); + // PybindCudaConfig(m); + // PybindTensorrtConfig(m); using PyClass = ProviderConfig; py::class_(*m, "ProviderConfig") - .def(py::init(), - py::arg("trt_config") = TensorrtConfig(), - py::arg("cuda_config") = CudaConfig(), + .def(py::init<>()) + .def(py::init(), + py::arg("trt_config"), py::arg("cuda_config"), py::arg("provider") = "cpu", py::arg("device") = 0) .def_readwrite("trt_config", &PyClass::trt_config) @@ -70,5 +32,4 @@ void PybindProviderConfig(py::module *m) { .def("__str__", &PyClass::ToString) .def("validate", &PyClass::Validate); } - } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 242d85974..05396c461 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -7,6 +7,7 @@ #include "sherpa-onnx/python/csrc/alsa.h" #include "sherpa-onnx/python/csrc/audio-tagging.h" #include "sherpa-onnx/python/csrc/circular-buffer.h" +// #include "sherpa-onnx/python/csrc/cuda-config.h" #include "sherpa-onnx/python/csrc/display.h" #include "sherpa-onnx/python/csrc/endpoint.h" #include "sherpa-onnx/python/csrc/features.h" @@ -19,12 +20,14 @@ #include "sherpa-onnx/python/csrc/offline-stream.h" #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" #include "sherpa-onnx/python/csrc/online-lm-config.h" +// #include "sherpa-onnx/python/csrc/provider-config.h" #include "sherpa-onnx/python/csrc/online-model-config.h" #include "sherpa-onnx/python/csrc/online-recognizer.h" #include "sherpa-onnx/python/csrc/online-stream.h" #include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h" #include "sherpa-onnx/python/csrc/speaker-embedding-manager.h" #include "sherpa-onnx/python/csrc/spoken-language-identification.h" +// #include "sherpa-onnx/python/csrc/tensorrt-config.h" #include "sherpa-onnx/python/csrc/vad-model-config.h" #include "sherpa-onnx/python/csrc/vad-model.h" #include "sherpa-onnx/python/csrc/voice-activity-detector.h" @@ -51,7 +54,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindEndpoint(&m); PybindOnlineRecognizer(&m); PybindKeywordSpotter(&m); - + // PybindProviderConfig(&m); PybindDisplay(&m); PybindOfflineStream(&m); diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.h b/sherpa-onnx/python/csrc/sherpa-onnx.h index 7bce9f49b..54fd4e88f 100644 --- 
a/sherpa-onnx/python/csrc/sherpa-onnx.h +++ b/sherpa-onnx/python/csrc/sherpa-onnx.h @@ -5,6 +5,8 @@ #ifndef SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_ #define SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_ +#define PYBIND11_DETAILED_ERROR_MESSAGES + #include "pybind11/functional.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" diff --git a/sherpa-onnx/python/csrc/tensorrt-config.cc b/sherpa-onnx/python/csrc/tensorrt-config.cc new file mode 100644 index 000000000..4fcfb3dae --- /dev/null +++ b/sherpa-onnx/python/csrc/tensorrt-config.cc @@ -0,0 +1,117 @@ +// sherpa-onnx/python/csrc/tensorrt-config.cc +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + +#include "sherpa-onnx/python/csrc/tensorrt-config.h" + +#include + +#include "sherpa-onnx/csrc/provider-config.h" + +namespace sherpa_onnx { + +void PybindTensorrtConfig(py::module *m) { + using PyClass = TensorrtConfig; + py::class_(*m, "TensorrtConfig") + .def(py::init<>()) + .def(py::init([](int32_t trt_max_workspace_size, + int32_t trt_max_partition_iterations, + int32_t trt_min_subgraph_size, + bool trt_fp16_enable, + bool trt_detailed_build_log, + bool trt_engine_cache_enable, + bool trt_timing_cache_enable, + const std::string &trt_engine_cache_path, + const std::string &trt_timing_cache_path, + bool trt_dump_subgraphs) -> std::unique_ptr { + auto ans = std::make_unique(); + + ans->trt_max_workspace_size = trt_max_workspace_size; + ans->trt_max_partition_iterations = trt_max_partition_iterations; + ans->trt_min_subgraph_size = trt_min_subgraph_size; + ans->trt_fp16_enable = trt_fp16_enable; + ans->trt_detailed_build_log = trt_detailed_build_log; + ans->trt_engine_cache_enable = trt_engine_cache_enable; + ans->trt_timing_cache_enable = trt_timing_cache_enable; + ans->trt_engine_cache_path = trt_engine_cache_path; + ans->trt_timing_cache_path = trt_timing_cache_path; + ans->trt_dump_subgraphs = trt_dump_subgraphs; + + return ans; + }), + py::arg("trt_max_workspace_size") = 2147483648, + py::arg("trt_max_partition_iterations") = 10, + py::arg("trt_min_subgraph_size") = 5, + py::arg("trt_fp16_enable") = true, + py::arg("trt_detailed_build_log") = false, + py::arg("trt_engine_cache_enable") = true, + py::arg("trt_timing_cache_enable") = true, + py::arg("trt_engine_cache_path") = ".", + py::arg("trt_timing_cache_path") = ".", + py::arg("trt_dump_subgraphs") = false) + + .def_readwrite("trt_max_workspace_size", &PyClass::trt_max_workspace_size) + .def_readwrite("trt_max_partition_iterations", &PyClass::trt_max_partition_iterations) + .def_readwrite("trt_min_subgraph_size", &PyClass::trt_min_subgraph_size) + .def_readwrite("trt_fp16_enable", &PyClass::trt_fp16_enable) + .def_readwrite("trt_detailed_build_log", &PyClass::trt_detailed_build_log) + .def_readwrite("trt_engine_cache_enable", &PyClass::trt_engine_cache_enable) + .def_readwrite("trt_timing_cache_enable", &PyClass::trt_timing_cache_enable) + .def_readwrite("trt_engine_cache_path", &PyClass::trt_engine_cache_path) + .def_readwrite("trt_timing_cache_path", &PyClass::trt_timing_cache_path) + .def_readwrite("trt_dump_subgraphs", &PyClass::trt_dump_subgraphs) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +} // namespace sherpa_onnx + + // .def(py::init([](int32_t trt_max_workspace_size, + // int32_t trt_max_partition_iterations, + // int32_t trt_min_subgraph_size, + // bool trt_fp16_enable, + // bool trt_detailed_build_log, + // bool trt_engine_cache_enable, + // bool trt_timing_cache_enable, + // const std::string &trt_engine_cache_path, + // 
const std::string &trt_timing_cache_path, + // bool trt_dump_subgraphs) -> std::unique_ptr { + // auto ans = std::make_unique(); + + // ans->trt_max_workspace_size = trt_max_workspace_size; + // ans->trt_max_partition_iterations = trt_max_partition_iterations; + // ans->trt_min_subgraph_size = trt_min_subgraph_size; + // ans->trt_fp16_enable = trt_fp16_enable; + // ans->trt_detailed_build_log = trt_detailed_build_log; + // ans->trt_engine_cache_enable = trt_engine_cache_enable; + // ans->trt_timing_cache_enable = trt_timing_cache_enable; + // ans->trt_engine_cache_path = trt_engine_cache_path; + // ans->trt_timing_cache_path = trt_timing_cache_path; + // ans->trt_dump_subgraphs = trt_dump_subgraphs; + + // return ans; + // }), + // py::arg("trt_max_workspace_size") = 2147483648, + // py::arg("trt_max_partition_iterations") = 10, + // py::arg("trt_min_subgraph_size") = 5, + // py::arg("trt_fp16_enable") = true, + // py::arg("trt_detailed_build_log") = false, + // py::arg("trt_engine_cache_enable") = true, + // py::arg("trt_timing_cache_enable") = true, + // py::arg("trt_engine_cache_path") = ".", + // py::arg("trt_timing_cache_path") = ".", + // py::arg("trt_dump_subgraphs") = false) + + // .def(py::init(), + // py::arg("trt_max_workspace_size"), + // py::arg("trt_max_partition_iterations"), + // py::arg("trt_min_subgraph_size"), + // py::arg("trt_fp16_enable"), + // py::arg("trt_detailed_build_log"), + // py::arg("trt_engine_cache_enable"), + // py::arg("trt_timing_cache_enable"), + // py::arg("trt_engine_cache_path"), + // py::arg("trt_timing_cache_path"), + // py::arg("trt_dump_subgraphs")) \ No newline at end of file diff --git a/sherpa-onnx/python/csrc/tensorrt-config.h b/sherpa-onnx/python/csrc/tensorrt-config.h new file mode 100644 index 000000000..b68ae2324 --- /dev/null +++ b/sherpa-onnx/python/csrc/tensorrt-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/tensorrt-config.h +// +// Copyright (c) 2024 Uniphore (Author: Manickavela A) + +#ifndef SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindTensorrtConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_ From 2f188bbf8de79c9155661287f2de1643d0807dfe Mon Sep 17 00:00:00 2001 From: Manix Date: Wed, 3 Jul 2024 20:42:33 +0000 Subject: [PATCH 26/34] python interface Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/csrc/provider-config.h | 5 +- sherpa-onnx/python/csrc/cuda-config.cc | 1 + .../python/csrc/online-model-config.cc | 1 + sherpa-onnx/python/csrc/provider-config.cc | 14 +- sherpa-onnx/python/csrc/sherpa-onnx.cc | 4 - sherpa-onnx/python/sherpa_onnx/__init__.py | 1 + .../python/sherpa_onnx/online_recognizer.py | 111 ++- .../sherpa_onnx/online_recognizer_ori.py | 746 ++++++++++++++++++ 8 files changed, 865 insertions(+), 18 deletions(-) create mode 100644 sherpa-onnx/python/sherpa_onnx/online_recognizer_ori.py diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h index cb5f70664..c3837e3b7 100644 --- a/sherpa-onnx/csrc/provider-config.h +++ b/sherpa-onnx/csrc/provider-config.h @@ -17,7 +17,7 @@ struct CudaConfig { int32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic; CudaConfig() = default; - explicit CudaConfig(int32_t cudnn_conv_algo_search) + CudaConfig(int32_t cudnn_conv_algo_search) : cudnn_conv_algo_search(cudnn_conv_algo_search) {} void Register(ParseOptions *po); @@ -74,6 +74,9 @@ struct ProviderConfig { 
// device only used for cuda and trt ProviderConfig() = default; + ProviderConfig(const std::string &provider, + int32_t device) + : provider(provider), device(device) {} ProviderConfig(const TensorrtConfig &trt_config, const CudaConfig &cuda_config, const std::string &provider, diff --git a/sherpa-onnx/python/csrc/cuda-config.cc b/sherpa-onnx/python/csrc/cuda-config.cc index 4669783d8..43627d3a1 100644 --- a/sherpa-onnx/python/csrc/cuda-config.cc +++ b/sherpa-onnx/python/csrc/cuda-config.cc @@ -14,6 +14,7 @@ namespace sherpa_onnx { void PybindCudaConfig(py::module *m) { using PyClass = CudaConfig; py::class_(*m, "CudaConfig") + .def(py::init<>()) .def(py::init(), py::arg("cudnn_conv_algo_search") = 1) .def_readwrite("cudnn_conv_algo_search", &PyClass::cudnn_conv_algo_search) diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc index b008f424e..4ea13fd60 100644 --- a/sherpa-onnx/python/csrc/online-model-config.cc +++ b/sherpa-onnx/python/csrc/online-model-config.cc @@ -9,6 +9,7 @@ #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" +#include "sherpa-onnx/csrc/provider-config.h" #include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h" #include "sherpa-onnx/python/csrc/online-paraformer-model-config.h" #include "sherpa-onnx/python/csrc/online-transducer-model-config.h" diff --git a/sherpa-onnx/python/csrc/provider-config.cc b/sherpa-onnx/python/csrc/provider-config.cc index d0967497d..c29d48ab4 100644 --- a/sherpa-onnx/python/csrc/provider-config.cc +++ b/sherpa-onnx/python/csrc/provider-config.cc @@ -8,21 +8,25 @@ #include #include "sherpa-onnx/csrc/provider-config.h" -// #include "sherpa-onnx/python/csrc/cuda-config.h" -// #include "sherpa-onnx/python/csrc/tensorrt-config.h" +#include "sherpa-onnx/python/csrc/cuda-config.h" +#include "sherpa-onnx/python/csrc/tensorrt-config.h" namespace sherpa_onnx { void PybindProviderConfig(py::module *m) { - // PybindCudaConfig(m); - // PybindTensorrtConfig(m); + PybindCudaConfig(m); + PybindTensorrtConfig(m); using PyClass = ProviderConfig; py::class_(*m, "ProviderConfig") .def(py::init<>()) + .def(py::init(), + py::arg("provider") = "cpu", + py::arg("device") = 0) .def(py::init(), - py::arg("trt_config"), py::arg("cuda_config"), + py::arg("trt_config") = TensorrtConfig{}, + py::arg("cuda_config") = CudaConfig{}, py::arg("provider") = "cpu", py::arg("device") = 0) .def_readwrite("trt_config", &PyClass::trt_config) diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 05396c461..5b369ed84 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -7,7 +7,6 @@ #include "sherpa-onnx/python/csrc/alsa.h" #include "sherpa-onnx/python/csrc/audio-tagging.h" #include "sherpa-onnx/python/csrc/circular-buffer.h" -// #include "sherpa-onnx/python/csrc/cuda-config.h" #include "sherpa-onnx/python/csrc/display.h" #include "sherpa-onnx/python/csrc/endpoint.h" #include "sherpa-onnx/python/csrc/features.h" @@ -20,14 +19,12 @@ #include "sherpa-onnx/python/csrc/offline-stream.h" #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" #include "sherpa-onnx/python/csrc/online-lm-config.h" -// #include "sherpa-onnx/python/csrc/provider-config.h" #include "sherpa-onnx/python/csrc/online-model-config.h" #include "sherpa-onnx/python/csrc/online-recognizer.h" #include "sherpa-onnx/python/csrc/online-stream.h" #include 
"sherpa-onnx/python/csrc/speaker-embedding-extractor.h" #include "sherpa-onnx/python/csrc/speaker-embedding-manager.h" #include "sherpa-onnx/python/csrc/spoken-language-identification.h" -// #include "sherpa-onnx/python/csrc/tensorrt-config.h" #include "sherpa-onnx/python/csrc/vad-model-config.h" #include "sherpa-onnx/python/csrc/vad-model.h" #include "sherpa-onnx/python/csrc/voice-activity-detector.h" @@ -54,7 +51,6 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindEndpoint(&m); PybindOnlineRecognizer(&m); PybindKeywordSpotter(&m); - // PybindProviderConfig(&m); PybindDisplay(&m); PybindOfflineStream(&m); diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index 7a832ba06..faccfe3f5 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -16,6 +16,7 @@ OfflineTtsVitsModelConfig, OfflineZipformerAudioTaggingModelConfig, OnlineStream, + ProviderConfig, SileroVadModelConfig, SpeakerEmbeddingExtractor, SpeakerEmbeddingExtractorConfig, diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 82b2e3b42..f73592382 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -11,6 +11,9 @@ ) from _sherpa_onnx import OnlineRecognizer as _Recognizer from _sherpa_onnx import ( + CudaConfig, + TensorrtConfig, + ProviderConfig, OnlineRecognizerConfig, OnlineRecognizerResult, OnlineStream, @@ -56,7 +59,6 @@ def from_transducer( hotwords_score: float = 1.5, blank_penalty: float = 0.0, hotwords_file: str = "", - provider: str = "cpu", model_type: str = "", modeling_unit: str = "cjkchar", bpe_vocab: str = "", @@ -66,6 +68,19 @@ def from_transducer( debug: bool = False, rule_fsts: str = "", rule_fars: str = "", + provider: str = "cpu", + device: int = 0, + cudnn_conv_algo_search: int = 1, + trt_max_workspace_size: int = 2147483648, + trt_max_partition_iterations: int = 10, + trt_min_subgraph_size: int = 5, + trt_fp16_enable: bool = True, + trt_detailed_build_log: bool = False, + trt_engine_cache_enable: bool = True, + trt_timing_cache_enable: bool = True, + trt_engine_cache_path: str ="", + trt_timing_cache_path: str ="", + trt_dump_subgraphs: bool = False, ): """ Please refer to @@ -135,8 +150,6 @@ def from_transducer( Temperature scaling for output symbol confidence estiamation. It affects only confidence values, the decoding uses the original logits without temperature. - provider: - onnxruntime execution providers. Valid values are: cpu, cuda, coreml. model_type: Online transducer model type. Valid values are: conformer, lstm, zipformer, zipformer2. All other values lead to loading the model twice. @@ -156,6 +169,32 @@ def from_transducer( rule_fars: If not empty, it specifies fst archives for inverse text normalization. If there are multiple archives, they are separated by a comma. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + device: + onnxruntime cuda device index. + cudnn_conv_algo_search: + onxrt CuDNN convolution search algorithm selection. CUDA EP + trt_max_workspace_size: + Set TensorRT EP GPU memory usage limit. TensorRT EP + trt_max_partition_iterations: + Limit partitioning iterations for model conversion. TensorRT EP + trt_min_subgraph_size: + Set minimum size for subgraphs in partitioning. TensorRT EP + trt_fp16_enable: bool = True, + Enable FP16 precision for faster performance. 
TensorRT EP + trt_detailed_build_log: bool = False, + Enable detailed logging of build steps. TensorRT EP + trt_engine_cache_enable: bool = True, + Enable caching of TensorRT engines. TensorRT EP + trt_timing_cache_enable: bool = True, + "Enable use of timing cache to speed up builds." TensorRT EP + trt_engine_cache_path: str ="", + "Set path to store cached TensorRT engines." TensorRT EP + trt_timing_cache_path: str ="", + "Set path for storing timing cache." TensorRT EP + trt_dump_subgraphs: bool = False, + "Dump optimized subgraphs for debugging." TensorRT EP """ self = cls.__new__(cls) _assert_file_exists(tokens) @@ -171,11 +210,35 @@ def from_transducer( joiner=joiner, ) + cuda_config = CudaConfig( + cudnn_conv_algo_search=cudnn_conv_algo_search, + ) + + trt_config = TensorrtConfig( + trt_max_workspace_size=trt_max_workspace_size, + trt_max_partition_iterations=trt_max_partition_iterations, + trt_min_subgraph_size=trt_min_subgraph_size, + trt_fp16_enable=trt_fp16_enable, + trt_detailed_build_log=trt_detailed_build_log, + trt_engine_cache_enable=trt_engine_cache_enable, + trt_timing_cache_enable=trt_timing_cache_enable, + trt_engine_cache_path=trt_engine_cache_path, + trt_timing_cache_path=trt_timing_cache_path, + trt_dump_subgraphs=trt_dump_subgraphs, + ) + + provider_config = ProviderConfig( + trt_config=trt_config, + cuda_config=cuda_config, + provider=provider, + device=device, + ) + model_config = OnlineModelConfig( transducer=transducer_config, tokens=tokens, num_threads=num_threads, - provider=provider, + provider_config=provider_config, model_type=model_type, modeling_unit=modeling_unit, bpe_vocab=bpe_vocab, @@ -251,6 +314,7 @@ def from_paraformer( debug: bool = False, rule_fsts: str = "", rule_fars: str = "", + device: int = 0, ): """ Please refer to @@ -301,6 +365,8 @@ def from_paraformer( rule_fars: If not empty, it specifies fst archives for inverse text normalization. If there are multiple archives, they are separated by a comma. + device: + onnxruntime cuda device index. """ self = cls.__new__(cls) _assert_file_exists(tokens) @@ -314,11 +380,16 @@ def from_paraformer( decoder=decoder, ) + provider_config = ProviderConfig( + provider=provider, + device=device, + ) + model_config = OnlineModelConfig( paraformer=paraformer_config, tokens=tokens, num_threads=num_threads, - provider=provider, + provider_config=provider_config, model_type="paraformer", debug=debug, ) @@ -367,6 +438,7 @@ def from_zipformer2_ctc( debug: bool = False, rule_fsts: str = "", rule_fars: str = "", + device: int = 0, ): """ Please refer to @@ -421,6 +493,8 @@ def from_zipformer2_ctc( rule_fars: If not empty, it specifies fst archives for inverse text normalization. If there are multiple archives, they are separated by a comma. + device: + onnxruntime cuda device index. """ self = cls.__new__(cls) _assert_file_exists(tokens) @@ -430,11 +504,16 @@ def from_zipformer2_ctc( zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model) + provider_config = ProviderConfig( + provider=provider, + device=device, + ) + model_config = OnlineModelConfig( zipformer2_ctc=zipformer2_ctc_config, tokens=tokens, num_threads=num_threads, - provider=provider, + provider_config=provider_config, debug=debug, ) @@ -486,6 +565,7 @@ def from_nemo_ctc( debug: bool = False, rule_fsts: str = "", rule_fars: str = "", + device: int = 0, ): """ Please refer to @@ -535,6 +615,8 @@ def from_nemo_ctc( rule_fars: If not empty, it specifies fst archives for inverse text normalization. 
If there are multiple archives, they are separated by a comma. + device: + onnxruntime cuda device index. """ self = cls.__new__(cls) _assert_file_exists(tokens) @@ -546,11 +628,16 @@ def from_nemo_ctc( model=model, ) + provider_config = ProviderConfig( + provider=provider, + device=device, + ) + model_config = OnlineModelConfig( nemo_ctc=nemo_ctc_config, tokens=tokens, num_threads=num_threads, - provider=provider, + provider_config=provider_config, debug=debug, ) @@ -598,6 +685,7 @@ def from_wenet_ctc( debug: bool = False, rule_fsts: str = "", rule_fars: str = "", + device: int = 0, ): """ Please refer to @@ -650,6 +738,8 @@ def from_wenet_ctc( rule_fars: If not empty, it specifies fst archives for inverse text normalization. If there are multiple archives, they are separated by a comma. + device: + onnxruntime cuda device index. """ self = cls.__new__(cls) _assert_file_exists(tokens) @@ -663,11 +753,16 @@ def from_wenet_ctc( num_left_chunks=num_left_chunks, ) + provider_config = ProviderConfig( + provider=provider, + device=device, + ) + model_config = OnlineModelConfig( wenet_ctc=wenet_ctc_config, tokens=tokens, num_threads=num_threads, - provider=provider, + provider_config=provider_config, debug=debug, ) diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer_ori.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer_ori.py new file mode 100644 index 000000000..cc71689d5 --- /dev/null +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer_ori.py @@ -0,0 +1,746 @@ +# Copyright (c) 2023 Xiaomi Corporation +from pathlib import Path +from typing import List, Optional + +from _sherpa_onnx import ( + EndpointConfig, + FeatureExtractorConfig, + OnlineLMConfig, + + OnlineModelConfig, + OnlineParaformerModelConfig, +) +from _sherpa_onnx import OnlineRecognizer as _Recognizer +from _sherpa_onnx import ( + OnlineRecognizerConfig, + OnlineRecognizerResult, + OnlineStream, + OnlineTransducerModelConfig, + OnlineWenetCtcModelConfig, + OnlineNeMoCtcModelConfig, + OnlineZipformer2CtcModelConfig, + OnlineCtcFstDecoderConfig, +) + + +def _assert_file_exists(f: str): + assert Path(f).is_file(), f"{f} does not exist" + + +class OnlineRecognizer(object): + """A class for streaming speech recognition. + + Please refer to the following files for usages + - https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/python/tests/test_online_recognizer.py + - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/online-decode-files.py + """ + + @classmethod + def from_transducer( + cls, + tokens: str, + encoder: str, + decoder: str, + joiner: str, + num_threads: int = 2, + sample_rate: float = 16000, + feature_dim: int = 80, + low_freq: float = 20.0, + high_freq: float = -400.0, + dither: float = 0.0, + enable_endpoint_detection: bool = False, + rule1_min_trailing_silence: float = 2.4, + rule2_min_trailing_silence: float = 1.2, + rule3_min_utterance_length: float = 20.0, + decoding_method: str = "greedy_search", + max_active_paths: int = 4, + hotwords_score: float = 1.5, + blank_penalty: float = 0.0, + hotwords_file: str = "", + provider: str = "cpu", + model_type: str = "", + modeling_unit: str = "cjkchar", + bpe_vocab: str = "", + lm: str = "", + lm_scale: float = 0.1, + temperature_scale: float = 2.0, + debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. 
Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + encoder: + Path to ``encoder.onnx``. + decoder: + Path to ``decoder.onnx``. + joiner: + Path to ``joiner.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + low_freq: + Low cutoff frequency for mel bins in feature extraction. + high_freq: + High cutoff frequency for mel bins in feature extraction + (if <= 0, offset from Nyquist) + dither: + Dithering constant (0.0 means no dither). + By default the audio samples are in range [-1,+1], + so dithering constant 0.00003 is a good value, + equivalent to the default 1.0 from kaldi + enable_endpoint_detection: + True to enable endpoint detection. False to disable endpoint + detection. + rule1_min_trailing_silence: + Used only when enable_endpoint_detection is True. If the duration + of trailing silence in seconds is larger than this value, we assume + an endpoint is detected. + rule2_min_trailing_silence: + Used only when enable_endpoint_detection is True. If we have decoded + something that is nonsilence and if the duration of trailing silence + in seconds is larger than this value, we assume an endpoint is + detected. + rule3_min_utterance_length: + Used only when enable_endpoint_detection is True. If the utterance + length in seconds is larger than this value, we assume an endpoint + is detected. + decoding_method: + Valid values are greedy_search, modified_beam_search. + max_active_paths: + Use only when decoding_method is modified_beam_search. It specifies + the maximum number of active paths during beam search. + blank_penalty: + The penalty applied on blank symbol during decoding. + hotwords_file: + The file containing hotwords, one words/phrases per line, and for each + phrase the bpe/cjkchar are separated by a space. + hotwords_score: + The hotword score of each token for biasing word/phrase. Used only if + hotwords_file is given with modified_beam_search as decoding method. + temperature_scale: + Temperature scaling for output symbol confidence estiamation. + It affects only confidence values, the decoding uses the original + logits without temperature. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + model_type: + Online transducer model type. Valid values are: conformer, lstm, + zipformer, zipformer2. All other values lead to loading the model twice. + modeling_unit: + The modeling unit of the model, commonly used units are bpe, cjkchar, + cjkchar+bpe, etc. Currently, it is needed only when hotwords are + provided, we need it to encode the hotwords into token sequence. + bpe_vocab: + The vocabulary generated by google's sentencepiece program. + It is a file has two columns, one is the token, the other is + the log probability, you can get it from the directory where + your bpe model is generated. Only used when hotwords provided + and the modeling unit is bpe or cjkchar+bpe. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. 
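As a usage sketch for the new keyword arguments introduced in this patch series (the model paths below are placeholders, and a CUDA-enabled build of sherpa-onnx is assumed):

    from sherpa_onnx import OnlineRecognizer

    # Placeholder files; substitute any streaming transducer model.
    recognizer = OnlineRecognizer.from_transducer(
        tokens="tokens.txt",
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        provider="cuda",           # select the CUDA execution provider
        device=0,                  # CUDA device index
        cudnn_conv_algo_search=1,  # CuDNN convolution algorithm search mode
    )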
+ """ + self = cls.__new__(cls) + _assert_file_exists(tokens) + _assert_file_exists(encoder) + _assert_file_exists(decoder) + _assert_file_exists(joiner) + + assert num_threads > 0, num_threads + + transducer_config = OnlineTransducerModelConfig( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + + model_config = OnlineModelConfig( + transducer=transducer_config, + tokens=tokens, + num_threads=num_threads, + provider=provider, + model_type=model_type, + modeling_unit=modeling_unit, + bpe_vocab=bpe_vocab, + debug=debug, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + low_freq=low_freq, + high_freq=high_freq, + dither=dither, + ) + + endpoint_config = EndpointConfig( + rule1_min_trailing_silence=rule1_min_trailing_silence, + rule2_min_trailing_silence=rule2_min_trailing_silence, + rule3_min_utterance_length=rule3_min_utterance_length, + ) + + if len(hotwords_file) > 0 and decoding_method != "modified_beam_search": + raise ValueError( + "Please use --decoding-method=modified_beam_search when using " + f"--hotwords-file. Currently given: {decoding_method}" + ) + + if lm and decoding_method != "modified_beam_search": + raise ValueError( + "Please use --decoding-method=modified_beam_search when using " + f"--lm. Currently given: {decoding_method}" + ) + + lm_config = OnlineLMConfig( + model=lm, + scale=lm_scale, + ) + + recognizer_config = OnlineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + lm_config=lm_config, + endpoint_config=endpoint_config, + enable_endpoint=enable_endpoint_detection, + decoding_method=decoding_method, + max_active_paths=max_active_paths, + hotwords_score=hotwords_score, + hotwords_file=hotwords_file, + blank_penalty=blank_penalty, + temperature_scale=temperature_scale, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_paraformer( + cls, + tokens: str, + encoder: str, + decoder: str, + num_threads: int = 2, + sample_rate: float = 16000, + feature_dim: int = 80, + enable_endpoint_detection: bool = False, + rule1_min_trailing_silence: float = 2.4, + rule2_min_trailing_silence: float = 1.2, + rule3_min_utterance_length: float = 20.0, + decoding_method: str = "greedy_search", + provider: str = "cpu", + debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + encoder: + Path to ``encoder.onnx``. + decoder: + Path to ``decoder.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + enable_endpoint_detection: + True to enable endpoint detection. False to disable endpoint + detection. + rule1_min_trailing_silence: + Used only when enable_endpoint_detection is True. If the duration + of trailing silence in seconds is larger than this value, we assume + an endpoint is detected. + rule2_min_trailing_silence: + Used only when enable_endpoint_detection is True. If we have decoded + something that is nonsilence and if the duration of trailing silence + in seconds is larger than this value, we assume an endpoint is + detected. 
+ rule3_min_utterance_length: + Used only when enable_endpoint_detection is True. If the utterance + length in seconds is larger than this value, we assume an endpoint + is detected. + decoding_method: + The only valid value is greedy_search. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + """ + self = cls.__new__(cls) + _assert_file_exists(tokens) + _assert_file_exists(encoder) + _assert_file_exists(decoder) + + assert num_threads > 0, num_threads + + paraformer_config = OnlineParaformerModelConfig( + encoder=encoder, + decoder=decoder, + ) + + model_config = OnlineModelConfig( + paraformer=paraformer_config, + tokens=tokens, + num_threads=num_threads, + provider=provider, + model_type="paraformer", + debug=debug, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + endpoint_config = EndpointConfig( + rule1_min_trailing_silence=rule1_min_trailing_silence, + rule2_min_trailing_silence=rule2_min_trailing_silence, + rule3_min_utterance_length=rule3_min_utterance_length, + ) + + recognizer_config = OnlineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + endpoint_config=endpoint_config, + enable_endpoint=enable_endpoint_detection, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_zipformer2_ctc( + cls, + tokens: str, + model: str, + num_threads: int = 2, + sample_rate: float = 16000, + feature_dim: int = 80, + enable_endpoint_detection: bool = False, + rule1_min_trailing_silence: float = 2.4, + rule2_min_trailing_silence: float = 1.2, + rule3_min_utterance_length: float = 20.0, + decoding_method: str = "greedy_search", + ctc_graph: str = "", + ctc_max_active: int = 3000, + provider: str = "cpu", + debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + model: + Path to ``model.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + enable_endpoint_detection: + True to enable endpoint detection. False to disable endpoint + detection. + rule1_min_trailing_silence: + Used only when enable_endpoint_detection is True. If the duration + of trailing silence in seconds is larger than this value, we assume + an endpoint is detected. + rule2_min_trailing_silence: + Used only when enable_endpoint_detection is True. If we have decoded + something that is nonsilence and if the duration of trailing silence + in seconds is larger than this value, we assume an endpoint is + detected. + rule3_min_utterance_length: + Used only when enable_endpoint_detection is True. If the utterance + length in seconds is larger than this value, we assume an endpoint + is detected. + decoding_method: + The only valid value is greedy_search. 
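The configs can also be assembled by hand through the bindings registered in provider-config.cc above, mirroring what from_transducer now does internally. A minimal sketch, importing from _sherpa_onnx exactly as the wrapper does:

    from _sherpa_onnx import CudaConfig, ProviderConfig, TensorrtConfig

    cuda_config = CudaConfig(cudnn_conv_algo_search=1)
    trt_config = TensorrtConfig(
        trt_max_workspace_size=2147483647,  # fits in int32_t, per the workspace update below
        trt_fp16_enable=True,
        trt_engine_cache_enable=True,
    )
    provider_config = ProviderConfig(
        trt_config=trt_config,
        cuda_config=cuda_config,
        provider="cuda",
        device=0,
    )
    print(provider_config)      # __str__ is bound to ToString()
    provider_config.validate()  # runs the C++ Validate() checks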
+ ctc_graph: + If not empty, decoding_method is ignored. It contains the path to + H.fst, HL.fst, or HLG.fst + ctc_max_active: + Used only when ctc_graph is not empty. It specifies the maximum + active paths at a time. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + """ + self = cls.__new__(cls) + _assert_file_exists(tokens) + _assert_file_exists(model) + + assert num_threads > 0, num_threads + + zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model) + + model_config = OnlineModelConfig( + zipformer2_ctc=zipformer2_ctc_config, + tokens=tokens, + num_threads=num_threads, + provider=provider, + debug=debug, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + endpoint_config = EndpointConfig( + rule1_min_trailing_silence=rule1_min_trailing_silence, + rule2_min_trailing_silence=rule2_min_trailing_silence, + rule3_min_utterance_length=rule3_min_utterance_length, + ) + + ctc_fst_decoder_config = OnlineCtcFstDecoderConfig( + graph=ctc_graph, + max_active=ctc_max_active, + ) + + recognizer_config = OnlineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + endpoint_config=endpoint_config, + ctc_fst_decoder_config=ctc_fst_decoder_config, + enable_endpoint=enable_endpoint_detection, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_nemo_ctc( + cls, + tokens: str, + model: str, + num_threads: int = 2, + sample_rate: float = 16000, + feature_dim: int = 80, + enable_endpoint_detection: bool = False, + rule1_min_trailing_silence: float = 2.4, + rule2_min_trailing_silence: float = 1.2, + rule3_min_utterance_length: float = 20.0, + decoding_method: str = "greedy_search", + provider: str = "cpu", + debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + model: + Path to ``model.onnx``. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + enable_endpoint_detection: + True to enable endpoint detection. False to disable endpoint + detection. + rule1_min_trailing_silence: + Used only when enable_endpoint_detection is True. If the duration + of trailing silence in seconds is larger than this value, we assume + an endpoint is detected. + rule2_min_trailing_silence: + Used only when enable_endpoint_detection is True. If we have decoded + something that is nonsilence and if the duration of trailing silence + in seconds is larger than this value, we assume an endpoint is + detected. + rule3_min_utterance_length: + Used only when enable_endpoint_detection is True. If the utterance + length in seconds is larger than this value, we assume an endpoint + is detected. + decoding_method: + The only valid value is greedy_search. + provider: + onnxruntime execution providers. 
Valid values are: cpu, cuda, coreml. + debug: + True to show meta data in the model. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. + rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + """ + self = cls.__new__(cls) + _assert_file_exists(tokens) + _assert_file_exists(model) + + assert num_threads > 0, num_threads + + nemo_ctc_config = OnlineNeMoCtcModelConfig( + model=model, + ) + + model_config = OnlineModelConfig( + nemo_ctc=nemo_ctc_config, + tokens=tokens, + num_threads=num_threads, + provider=provider, + debug=debug, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + endpoint_config = EndpointConfig( + rule1_min_trailing_silence=rule1_min_trailing_silence, + rule2_min_trailing_silence=rule2_min_trailing_silence, + rule3_min_utterance_length=rule3_min_utterance_length, + ) + + recognizer_config = OnlineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + endpoint_config=endpoint_config, + enable_endpoint=enable_endpoint_detection, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + @classmethod + def from_wenet_ctc( + cls, + tokens: str, + model: str, + chunk_size: int = 16, + num_left_chunks: int = 4, + num_threads: int = 2, + sample_rate: float = 16000, + feature_dim: int = 80, + enable_endpoint_detection: bool = False, + rule1_min_trailing_silence: float = 2.4, + rule2_min_trailing_silence: float = 1.2, + rule3_min_utterance_length: float = 20.0, + decoding_method: str = "greedy_search", + provider: str = "cpu", + debug: bool = False, + rule_fsts: str = "", + rule_fars: str = "", + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + model: + Path to ``model.onnx``. + chunk_size: + The --chunk-size parameter from WeNet. + num_left_chunks: + The --num-left-chunks parameter from WeNet. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + enable_endpoint_detection: + True to enable endpoint detection. False to disable endpoint + detection. + rule1_min_trailing_silence: + Used only when enable_endpoint_detection is True. If the duration + of trailing silence in seconds is larger than this value, we assume + an endpoint is detected. + rule2_min_trailing_silence: + Used only when enable_endpoint_detection is True. If we have decoded + something that is nonsilence and if the duration of trailing silence + in seconds is larger than this value, we assume an endpoint is + detected. + rule3_min_utterance_length: + Used only when enable_endpoint_detection is True. If the utterance + length in seconds is larger than this value, we assume an endpoint + is detected. + decoding_method: + The only valid value is greedy_search. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + rule_fsts: + If not empty, it specifies fsts for inverse text normalization. + If there are multiple fsts, they are separated by a comma. 
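The non-transducer factories take only provider and device and forward them through the two-argument ProviderConfig constructor added above; a sketch with placeholder paths:

    from sherpa_onnx import OnlineRecognizer

    recognizer = OnlineRecognizer.from_nemo_ctc(
        tokens="tokens.txt",  # placeholder
        model="model.onnx",   # placeholder
        provider="cuda",
        device=1,             # run on the second GPU
    )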
+ rule_fars: + If not empty, it specifies fst archives for inverse text normalization. + If there are multiple archives, they are separated by a comma. + """ + self = cls.__new__(cls) + _assert_file_exists(tokens) + _assert_file_exists(model) + + assert num_threads > 0, num_threads + + wenet_ctc_config = OnlineWenetCtcModelConfig( + model=model, + chunk_size=chunk_size, + num_left_chunks=num_left_chunks, + ) + + model_config = OnlineModelConfig( + wenet_ctc=wenet_ctc_config, + tokens=tokens, + num_threads=num_threads, + provider=provider, + debug=debug, + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + endpoint_config = EndpointConfig( + rule1_min_trailing_silence=rule1_min_trailing_silence, + rule2_min_trailing_silence=rule2_min_trailing_silence, + rule3_min_utterance_length=rule3_min_utterance_length, + ) + + recognizer_config = OnlineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + endpoint_config=endpoint_config, + enable_endpoint=enable_endpoint_detection, + decoding_method=decoding_method, + rule_fsts=rule_fsts, + rule_fars=rule_fars, + ) + + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + + def create_stream(self, hotwords: Optional[str] = None): + if hotwords is None: + return self.recognizer.create_stream() + else: + return self.recognizer.create_stream(hotwords) + + def decode_stream(self, s: OnlineStream): + self.recognizer.decode_stream(s) + + def decode_streams(self, ss: List[OnlineStream]): + self.recognizer.decode_streams(ss) + + def is_ready(self, s: OnlineStream) -> bool: + return self.recognizer.is_ready(s) + + def get_result_all(self, s: OnlineStream) -> OnlineRecognizerResult: + return self.recognizer.get_result(s) + + def get_result(self, s: OnlineStream) -> str: + return self.recognizer.get_result(s).text.strip() + + def get_result_as_json_string(self, s: OnlineStream) -> str: + return self.recognizer.get_result(s).as_json_string() + + def tokens(self, s: OnlineStream) -> List[str]: + return self.recognizer.get_result(s).tokens + + def timestamps(self, s: OnlineStream) -> List[float]: + return self.recognizer.get_result(s).timestamps + + def start_time(self, s: OnlineStream) -> float: + return self.recognizer.get_result(s).start_time + + def ys_probs(self, s: OnlineStream) -> List[float]: + return self.recognizer.get_result(s).ys_probs + + def lm_probs(self, s: OnlineStream) -> List[float]: + return self.recognizer.get_result(s).lm_probs + + def context_scores(self, s: OnlineStream) -> List[float]: + return self.recognizer.get_result(s).context_scores + + def is_endpoint(self, s: OnlineStream) -> bool: + return self.recognizer.is_endpoint(s) + + def reset(self, s: OnlineStream) -> bool: + return self.recognizer.reset(s) From dc8bfb26c749fe2f21eb0ac7b33edf628ea380d9 Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Wed, 3 Jul 2024 21:01:37 +0000 Subject: [PATCH 27/34] lint fix --- sherpa-onnx/csrc/provider-config.h | 2 +- sherpa-onnx/python/csrc/cuda-config.h | 2 +- sherpa-onnx/python/csrc/tensorrt-config.cc | 67 ++++------------------ sherpa-onnx/python/csrc/tensorrt-config.h | 2 +- 4 files changed, 14 insertions(+), 59 deletions(-) diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h index c3837e3b7..7ec154bf4 100644 --- a/sherpa-onnx/csrc/provider-config.h +++ b/sherpa-onnx/csrc/provider-config.h @@ -17,7 +17,7 @@ struct CudaConfig { int32_t cudnn_conv_algo_search = 
OrtCudnnConvAlgoSearchHeuristic; CudaConfig() = default; - CudaConfig(int32_t cudnn_conv_algo_search) + explicit CudaConfig(int32_t cudnn_conv_algo_search) : cudnn_conv_algo_search(cudnn_conv_algo_search) {} void Register(ParseOptions *po); diff --git a/sherpa-onnx/python/csrc/cuda-config.h b/sherpa-onnx/python/csrc/cuda-config.h index 4af175800..012fb29eb 100644 --- a/sherpa-onnx/python/csrc/cuda-config.h +++ b/sherpa-onnx/python/csrc/cuda-config.h @@ -13,4 +13,4 @@ void PybindCudaConfig(py::module *m); } -#endif // SHERPA_ONNX_PYTHON_CSRC_SILERO_VAD_MODEL_CONFIG_H_ +#endif // SHERPA_ONNX_PYTHON_CSRC_CUDA_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/tensorrt-config.cc b/sherpa-onnx/python/csrc/tensorrt-config.cc index 4fcfb3dae..37b3cad2c 100644 --- a/sherpa-onnx/python/csrc/tensorrt-config.cc +++ b/sherpa-onnx/python/csrc/tensorrt-config.cc @@ -5,7 +5,7 @@ #include "sherpa-onnx/python/csrc/tensorrt-config.h" #include - +#include #include "sherpa-onnx/csrc/provider-config.h" namespace sherpa_onnx { @@ -50,13 +50,18 @@ void PybindTensorrtConfig(py::module *m) { py::arg("trt_timing_cache_path") = ".", py::arg("trt_dump_subgraphs") = false) - .def_readwrite("trt_max_workspace_size", &PyClass::trt_max_workspace_size) - .def_readwrite("trt_max_partition_iterations", &PyClass::trt_max_partition_iterations) + .def_readwrite("trt_max_workspace_size", + &PyClass::trt_max_workspace_size) + .def_readwrite("trt_max_partition_iterations", + &PyClass::trt_max_partition_iterations) .def_readwrite("trt_min_subgraph_size", &PyClass::trt_min_subgraph_size) .def_readwrite("trt_fp16_enable", &PyClass::trt_fp16_enable) - .def_readwrite("trt_detailed_build_log", &PyClass::trt_detailed_build_log) - .def_readwrite("trt_engine_cache_enable", &PyClass::trt_engine_cache_enable) - .def_readwrite("trt_timing_cache_enable", &PyClass::trt_timing_cache_enable) + .def_readwrite("trt_detailed_build_log", + &PyClass::trt_detailed_build_log) + .def_readwrite("trt_engine_cache_enable", + &PyClass::trt_engine_cache_enable) + .def_readwrite("trt_timing_cache_enable", + &PyClass::trt_timing_cache_enable) .def_readwrite("trt_engine_cache_path", &PyClass::trt_engine_cache_path) .def_readwrite("trt_timing_cache_path", &PyClass::trt_timing_cache_path) .def_readwrite("trt_dump_subgraphs", &PyClass::trt_dump_subgraphs) @@ -65,53 +70,3 @@ void PybindTensorrtConfig(py::module *m) { } } // namespace sherpa_onnx - - // .def(py::init([](int32_t trt_max_workspace_size, - // int32_t trt_max_partition_iterations, - // int32_t trt_min_subgraph_size, - // bool trt_fp16_enable, - // bool trt_detailed_build_log, - // bool trt_engine_cache_enable, - // bool trt_timing_cache_enable, - // const std::string &trt_engine_cache_path, - // const std::string &trt_timing_cache_path, - // bool trt_dump_subgraphs) -> std::unique_ptr { - // auto ans = std::make_unique(); - - // ans->trt_max_workspace_size = trt_max_workspace_size; - // ans->trt_max_partition_iterations = trt_max_partition_iterations; - // ans->trt_min_subgraph_size = trt_min_subgraph_size; - // ans->trt_fp16_enable = trt_fp16_enable; - // ans->trt_detailed_build_log = trt_detailed_build_log; - // ans->trt_engine_cache_enable = trt_engine_cache_enable; - // ans->trt_timing_cache_enable = trt_timing_cache_enable; - // ans->trt_engine_cache_path = trt_engine_cache_path; - // ans->trt_timing_cache_path = trt_timing_cache_path; - // ans->trt_dump_subgraphs = trt_dump_subgraphs; - - // return ans; - // }), - // py::arg("trt_max_workspace_size") = 2147483648, - // 
py::arg("trt_max_partition_iterations") = 10, - // py::arg("trt_min_subgraph_size") = 5, - // py::arg("trt_fp16_enable") = true, - // py::arg("trt_detailed_build_log") = false, - // py::arg("trt_engine_cache_enable") = true, - // py::arg("trt_timing_cache_enable") = true, - // py::arg("trt_engine_cache_path") = ".", - // py::arg("trt_timing_cache_path") = ".", - // py::arg("trt_dump_subgraphs") = false) - - // .def(py::init(), - // py::arg("trt_max_workspace_size"), - // py::arg("trt_max_partition_iterations"), - // py::arg("trt_min_subgraph_size"), - // py::arg("trt_fp16_enable"), - // py::arg("trt_detailed_build_log"), - // py::arg("trt_engine_cache_enable"), - // py::arg("trt_timing_cache_enable"), - // py::arg("trt_engine_cache_path"), - // py::arg("trt_timing_cache_path"), - // py::arg("trt_dump_subgraphs")) \ No newline at end of file diff --git a/sherpa-onnx/python/csrc/tensorrt-config.h b/sherpa-onnx/python/csrc/tensorrt-config.h index b68ae2324..d8eea7000 100644 --- a/sherpa-onnx/python/csrc/tensorrt-config.h +++ b/sherpa-onnx/python/csrc/tensorrt-config.h @@ -13,4 +13,4 @@ void PybindTensorrtConfig(py::module *m); } -#endif // SHERPA_ONNX_PYTHON_CSRC_PROVIDER_CONFIG_H_ +#endif // SHERPA_ONNX_PYTHON_CSRC_TENSORRT_CONFIG_H_ From 0982ee7356b94f905616f94c428ae9c357a64b44 Mon Sep 17 00:00:00 2001 From: Manix <50542248+manickavela29@users.noreply.github.com> Date: Thu, 4 Jul 2024 08:05:56 +0530 Subject: [PATCH 28/34] Update sherpa-onnx/python/sherpa_onnx/online_recognizer.py trt workspace update Co-authored-by: Fangjun Kuang --- sherpa-onnx/python/sherpa_onnx/online_recognizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index f73592382..779ba6e77 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -71,7 +71,7 @@ def from_transducer( provider: str = "cpu", device: int = 0, cudnn_conv_algo_search: int = 1, - trt_max_workspace_size: int = 2147483648, + trt_max_workspace_size: int = 2147483647, trt_max_partition_iterations: int = 10, trt_min_subgraph_size: int = 5, trt_fp16_enable: bool = True, From 4e2ede0a2507cc98b5d12f8ffdf73399776df2ba Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Thu, 4 Jul 2024 03:03:48 +0000 Subject: [PATCH 29/34] clean up Signed-off-by: manickavela1998@gmail.com --- .../sherpa_onnx/online_recognizer_ori.py | 746 ------------------ 1 file changed, 746 deletions(-) delete mode 100644 sherpa-onnx/python/sherpa_onnx/online_recognizer_ori.py diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer_ori.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer_ori.py deleted file mode 100644 index cc71689d5..000000000 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer_ori.py +++ /dev/null @@ -1,746 +0,0 @@ -# Copyright (c) 2023 Xiaomi Corporation -from pathlib import Path -from typing import List, Optional - -from _sherpa_onnx import ( - EndpointConfig, - FeatureExtractorConfig, - OnlineLMConfig, - - OnlineModelConfig, - OnlineParaformerModelConfig, -) -from _sherpa_onnx import OnlineRecognizer as _Recognizer -from _sherpa_onnx import ( - OnlineRecognizerConfig, - OnlineRecognizerResult, - OnlineStream, - OnlineTransducerModelConfig, - OnlineWenetCtcModelConfig, - OnlineNeMoCtcModelConfig, - OnlineZipformer2CtcModelConfig, - OnlineCtcFstDecoderConfig, -) - - -def _assert_file_exists(f: str): - assert Path(f).is_file(), f"{f} does not exist" - - 
-class OnlineRecognizer(object): - """A class for streaming speech recognition. - - Please refer to the following files for usages - - https://github.com/k2-fsa/sherpa-onnx/blob/master/sherpa-onnx/python/tests/test_online_recognizer.py - - https://github.com/k2-fsa/sherpa-onnx/blob/master/python-api-examples/online-decode-files.py - """ - - @classmethod - def from_transducer( - cls, - tokens: str, - encoder: str, - decoder: str, - joiner: str, - num_threads: int = 2, - sample_rate: float = 16000, - feature_dim: int = 80, - low_freq: float = 20.0, - high_freq: float = -400.0, - dither: float = 0.0, - enable_endpoint_detection: bool = False, - rule1_min_trailing_silence: float = 2.4, - rule2_min_trailing_silence: float = 1.2, - rule3_min_utterance_length: float = 20.0, - decoding_method: str = "greedy_search", - max_active_paths: int = 4, - hotwords_score: float = 1.5, - blank_penalty: float = 0.0, - hotwords_file: str = "", - provider: str = "cpu", - model_type: str = "", - modeling_unit: str = "cjkchar", - bpe_vocab: str = "", - lm: str = "", - lm_scale: float = 0.1, - temperature_scale: float = 2.0, - debug: bool = False, - rule_fsts: str = "", - rule_fars: str = "", - ): - """ - Please refer to - ``_ - to download pre-trained models for different languages, e.g., Chinese, - English, etc. - - Args: - tokens: - Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two - columns:: - - symbol integer_id - - encoder: - Path to ``encoder.onnx``. - decoder: - Path to ``decoder.onnx``. - joiner: - Path to ``joiner.onnx``. - num_threads: - Number of threads for neural network computation. - sample_rate: - Sample rate of the training data used to train the model. - feature_dim: - Dimension of the feature used to train the model. - low_freq: - Low cutoff frequency for mel bins in feature extraction. - high_freq: - High cutoff frequency for mel bins in feature extraction - (if <= 0, offset from Nyquist) - dither: - Dithering constant (0.0 means no dither). - By default the audio samples are in range [-1,+1], - so dithering constant 0.00003 is a good value, - equivalent to the default 1.0 from kaldi - enable_endpoint_detection: - True to enable endpoint detection. False to disable endpoint - detection. - rule1_min_trailing_silence: - Used only when enable_endpoint_detection is True. If the duration - of trailing silence in seconds is larger than this value, we assume - an endpoint is detected. - rule2_min_trailing_silence: - Used only when enable_endpoint_detection is True. If we have decoded - something that is nonsilence and if the duration of trailing silence - in seconds is larger than this value, we assume an endpoint is - detected. - rule3_min_utterance_length: - Used only when enable_endpoint_detection is True. If the utterance - length in seconds is larger than this value, we assume an endpoint - is detected. - decoding_method: - Valid values are greedy_search, modified_beam_search. - max_active_paths: - Use only when decoding_method is modified_beam_search. It specifies - the maximum number of active paths during beam search. - blank_penalty: - The penalty applied on blank symbol during decoding. - hotwords_file: - The file containing hotwords, one words/phrases per line, and for each - phrase the bpe/cjkchar are separated by a space. - hotwords_score: - The hotword score of each token for biasing word/phrase. Used only if - hotwords_file is given with modified_beam_search as decoding method. - temperature_scale: - Temperature scaling for output symbol confidence estiamation. 
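Decoding itself is unchanged by this refactor; the wrapper methods at the bottom of the file are driven the same way as before. A sketch of the usual loop (accept_waveform and 16 kHz float32 samples are assumed, as in the project's decode-files examples):

    import numpy as np

    s = recognizer.create_stream()
    samples = np.zeros(16000, dtype=np.float32)  # stand-in for real audio
    s.accept_waveform(16000, samples)
    while recognizer.is_ready(s):
        recognizer.decode_stream(s)
    print(recognizer.get_result(s))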
- It affects only confidence values, the decoding uses the original - logits without temperature. - provider: - onnxruntime execution providers. Valid values are: cpu, cuda, coreml. - model_type: - Online transducer model type. Valid values are: conformer, lstm, - zipformer, zipformer2. All other values lead to loading the model twice. - modeling_unit: - The modeling unit of the model, commonly used units are bpe, cjkchar, - cjkchar+bpe, etc. Currently, it is needed only when hotwords are - provided, we need it to encode the hotwords into token sequence. - bpe_vocab: - The vocabulary generated by google's sentencepiece program. - It is a file has two columns, one is the token, the other is - the log probability, you can get it from the directory where - your bpe model is generated. Only used when hotwords provided - and the modeling unit is bpe or cjkchar+bpe. - rule_fsts: - If not empty, it specifies fsts for inverse text normalization. - If there are multiple fsts, they are separated by a comma. - rule_fars: - If not empty, it specifies fst archives for inverse text normalization. - If there are multiple archives, they are separated by a comma. - """ - self = cls.__new__(cls) - _assert_file_exists(tokens) - _assert_file_exists(encoder) - _assert_file_exists(decoder) - _assert_file_exists(joiner) - - assert num_threads > 0, num_threads - - transducer_config = OnlineTransducerModelConfig( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - - model_config = OnlineModelConfig( - transducer=transducer_config, - tokens=tokens, - num_threads=num_threads, - provider=provider, - model_type=model_type, - modeling_unit=modeling_unit, - bpe_vocab=bpe_vocab, - debug=debug, - ) - - feat_config = FeatureExtractorConfig( - sampling_rate=sample_rate, - feature_dim=feature_dim, - low_freq=low_freq, - high_freq=high_freq, - dither=dither, - ) - - endpoint_config = EndpointConfig( - rule1_min_trailing_silence=rule1_min_trailing_silence, - rule2_min_trailing_silence=rule2_min_trailing_silence, - rule3_min_utterance_length=rule3_min_utterance_length, - ) - - if len(hotwords_file) > 0 and decoding_method != "modified_beam_search": - raise ValueError( - "Please use --decoding-method=modified_beam_search when using " - f"--hotwords-file. Currently given: {decoding_method}" - ) - - if lm and decoding_method != "modified_beam_search": - raise ValueError( - "Please use --decoding-method=modified_beam_search when using " - f"--lm. 
Currently given: {decoding_method}" - ) - - lm_config = OnlineLMConfig( - model=lm, - scale=lm_scale, - ) - - recognizer_config = OnlineRecognizerConfig( - feat_config=feat_config, - model_config=model_config, - lm_config=lm_config, - endpoint_config=endpoint_config, - enable_endpoint=enable_endpoint_detection, - decoding_method=decoding_method, - max_active_paths=max_active_paths, - hotwords_score=hotwords_score, - hotwords_file=hotwords_file, - blank_penalty=blank_penalty, - temperature_scale=temperature_scale, - rule_fsts=rule_fsts, - rule_fars=rule_fars, - ) - - self.recognizer = _Recognizer(recognizer_config) - self.config = recognizer_config - return self - - @classmethod - def from_paraformer( - cls, - tokens: str, - encoder: str, - decoder: str, - num_threads: int = 2, - sample_rate: float = 16000, - feature_dim: int = 80, - enable_endpoint_detection: bool = False, - rule1_min_trailing_silence: float = 2.4, - rule2_min_trailing_silence: float = 1.2, - rule3_min_utterance_length: float = 20.0, - decoding_method: str = "greedy_search", - provider: str = "cpu", - debug: bool = False, - rule_fsts: str = "", - rule_fars: str = "", - ): - """ - Please refer to - ``_ - to download pre-trained models for different languages, e.g., Chinese, - English, etc. - - Args: - tokens: - Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two - columns:: - - symbol integer_id - - encoder: - Path to ``encoder.onnx``. - decoder: - Path to ``decoder.onnx``. - num_threads: - Number of threads for neural network computation. - sample_rate: - Sample rate of the training data used to train the model. - feature_dim: - Dimension of the feature used to train the model. - enable_endpoint_detection: - True to enable endpoint detection. False to disable endpoint - detection. - rule1_min_trailing_silence: - Used only when enable_endpoint_detection is True. If the duration - of trailing silence in seconds is larger than this value, we assume - an endpoint is detected. - rule2_min_trailing_silence: - Used only when enable_endpoint_detection is True. If we have decoded - something that is nonsilence and if the duration of trailing silence - in seconds is larger than this value, we assume an endpoint is - detected. - rule3_min_utterance_length: - Used only when enable_endpoint_detection is True. If the utterance - length in seconds is larger than this value, we assume an endpoint - is detected. - decoding_method: - The only valid value is greedy_search. - provider: - onnxruntime execution providers. Valid values are: cpu, cuda, coreml. - rule_fsts: - If not empty, it specifies fsts for inverse text normalization. - If there are multiple fsts, they are separated by a comma. - rule_fars: - If not empty, it specifies fst archives for inverse text normalization. - If there are multiple archives, they are separated by a comma. 
- """ - self = cls.__new__(cls) - _assert_file_exists(tokens) - _assert_file_exists(encoder) - _assert_file_exists(decoder) - - assert num_threads > 0, num_threads - - paraformer_config = OnlineParaformerModelConfig( - encoder=encoder, - decoder=decoder, - ) - - model_config = OnlineModelConfig( - paraformer=paraformer_config, - tokens=tokens, - num_threads=num_threads, - provider=provider, - model_type="paraformer", - debug=debug, - ) - - feat_config = FeatureExtractorConfig( - sampling_rate=sample_rate, - feature_dim=feature_dim, - ) - - endpoint_config = EndpointConfig( - rule1_min_trailing_silence=rule1_min_trailing_silence, - rule2_min_trailing_silence=rule2_min_trailing_silence, - rule3_min_utterance_length=rule3_min_utterance_length, - ) - - recognizer_config = OnlineRecognizerConfig( - feat_config=feat_config, - model_config=model_config, - endpoint_config=endpoint_config, - enable_endpoint=enable_endpoint_detection, - decoding_method=decoding_method, - rule_fsts=rule_fsts, - rule_fars=rule_fars, - ) - - self.recognizer = _Recognizer(recognizer_config) - self.config = recognizer_config - return self - - @classmethod - def from_zipformer2_ctc( - cls, - tokens: str, - model: str, - num_threads: int = 2, - sample_rate: float = 16000, - feature_dim: int = 80, - enable_endpoint_detection: bool = False, - rule1_min_trailing_silence: float = 2.4, - rule2_min_trailing_silence: float = 1.2, - rule3_min_utterance_length: float = 20.0, - decoding_method: str = "greedy_search", - ctc_graph: str = "", - ctc_max_active: int = 3000, - provider: str = "cpu", - debug: bool = False, - rule_fsts: str = "", - rule_fars: str = "", - ): - """ - Please refer to - ``_ - to download pre-trained models for different languages, e.g., Chinese, - English, etc. - - Args: - tokens: - Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two - columns:: - - symbol integer_id - - model: - Path to ``model.onnx``. - num_threads: - Number of threads for neural network computation. - sample_rate: - Sample rate of the training data used to train the model. - feature_dim: - Dimension of the feature used to train the model. - enable_endpoint_detection: - True to enable endpoint detection. False to disable endpoint - detection. - rule1_min_trailing_silence: - Used only when enable_endpoint_detection is True. If the duration - of trailing silence in seconds is larger than this value, we assume - an endpoint is detected. - rule2_min_trailing_silence: - Used only when enable_endpoint_detection is True. If we have decoded - something that is nonsilence and if the duration of trailing silence - in seconds is larger than this value, we assume an endpoint is - detected. - rule3_min_utterance_length: - Used only when enable_endpoint_detection is True. If the utterance - length in seconds is larger than this value, we assume an endpoint - is detected. - decoding_method: - The only valid value is greedy_search. - ctc_graph: - If not empty, decoding_method is ignored. It contains the path to - H.fst, HL.fst, or HLG.fst - ctc_max_active: - Used only when ctc_graph is not empty. It specifies the maximum - active paths at a time. - provider: - onnxruntime execution providers. Valid values are: cpu, cuda, coreml. - rule_fsts: - If not empty, it specifies fsts for inverse text normalization. - If there are multiple fsts, they are separated by a comma. - rule_fars: - If not empty, it specifies fst archives for inverse text normalization. - If there are multiple archives, they are separated by a comma. 
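-
- A minimal usage sketch (file names are placeholders; as noted
- above, passing a non-empty ``ctc_graph`` switches decoding to the
- FST, so ``decoding_method`` is then ignored)::
-
-     recognizer = OnlineRecognizer.from_zipformer2_ctc(
-         tokens="tokens.txt",
-         model="model.onnx",
-         ctc_graph="HLG.fst",
-         ctc_max_active=3000,
-     )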
- """ - self = cls.__new__(cls) - _assert_file_exists(tokens) - _assert_file_exists(model) - - assert num_threads > 0, num_threads - - zipformer2_ctc_config = OnlineZipformer2CtcModelConfig(model=model) - - model_config = OnlineModelConfig( - zipformer2_ctc=zipformer2_ctc_config, - tokens=tokens, - num_threads=num_threads, - provider=provider, - debug=debug, - ) - - feat_config = FeatureExtractorConfig( - sampling_rate=sample_rate, - feature_dim=feature_dim, - ) - - endpoint_config = EndpointConfig( - rule1_min_trailing_silence=rule1_min_trailing_silence, - rule2_min_trailing_silence=rule2_min_trailing_silence, - rule3_min_utterance_length=rule3_min_utterance_length, - ) - - ctc_fst_decoder_config = OnlineCtcFstDecoderConfig( - graph=ctc_graph, - max_active=ctc_max_active, - ) - - recognizer_config = OnlineRecognizerConfig( - feat_config=feat_config, - model_config=model_config, - endpoint_config=endpoint_config, - ctc_fst_decoder_config=ctc_fst_decoder_config, - enable_endpoint=enable_endpoint_detection, - decoding_method=decoding_method, - rule_fsts=rule_fsts, - rule_fars=rule_fars, - ) - - self.recognizer = _Recognizer(recognizer_config) - self.config = recognizer_config - return self - - @classmethod - def from_nemo_ctc( - cls, - tokens: str, - model: str, - num_threads: int = 2, - sample_rate: float = 16000, - feature_dim: int = 80, - enable_endpoint_detection: bool = False, - rule1_min_trailing_silence: float = 2.4, - rule2_min_trailing_silence: float = 1.2, - rule3_min_utterance_length: float = 20.0, - decoding_method: str = "greedy_search", - provider: str = "cpu", - debug: bool = False, - rule_fsts: str = "", - rule_fars: str = "", - ): - """ - Please refer to - ``_ - to download pre-trained models. - - Args: - tokens: - Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two - columns:: - - symbol integer_id - - model: - Path to ``model.onnx``. - num_threads: - Number of threads for neural network computation. - sample_rate: - Sample rate of the training data used to train the model. - feature_dim: - Dimension of the feature used to train the model. - enable_endpoint_detection: - True to enable endpoint detection. False to disable endpoint - detection. - rule1_min_trailing_silence: - Used only when enable_endpoint_detection is True. If the duration - of trailing silence in seconds is larger than this value, we assume - an endpoint is detected. - rule2_min_trailing_silence: - Used only when enable_endpoint_detection is True. If we have decoded - something that is nonsilence and if the duration of trailing silence - in seconds is larger than this value, we assume an endpoint is - detected. - rule3_min_utterance_length: - Used only when enable_endpoint_detection is True. If the utterance - length in seconds is larger than this value, we assume an endpoint - is detected. - decoding_method: - The only valid value is greedy_search. - provider: - onnxruntime execution providers. Valid values are: cpu, cuda, coreml. - debug: - True to show meta data in the model. - rule_fsts: - If not empty, it specifies fsts for inverse text normalization. - If there are multiple fsts, they are separated by a comma. - rule_fars: - If not empty, it specifies fst archives for inverse text normalization. - If there are multiple archives, they are separated by a comma. 
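-
- A minimal usage sketch (file names are placeholders), with
- endpoint detection enabled so that ``is_endpoint()`` can be
- polled while streaming::
-
-     recognizer = OnlineRecognizer.from_nemo_ctc(
-         tokens="tokens.txt",
-         model="model.onnx",
-         enable_endpoint_detection=True,
-     )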
- """ - self = cls.__new__(cls) - _assert_file_exists(tokens) - _assert_file_exists(model) - - assert num_threads > 0, num_threads - - nemo_ctc_config = OnlineNeMoCtcModelConfig( - model=model, - ) - - model_config = OnlineModelConfig( - nemo_ctc=nemo_ctc_config, - tokens=tokens, - num_threads=num_threads, - provider=provider, - debug=debug, - ) - - feat_config = FeatureExtractorConfig( - sampling_rate=sample_rate, - feature_dim=feature_dim, - ) - - endpoint_config = EndpointConfig( - rule1_min_trailing_silence=rule1_min_trailing_silence, - rule2_min_trailing_silence=rule2_min_trailing_silence, - rule3_min_utterance_length=rule3_min_utterance_length, - ) - - recognizer_config = OnlineRecognizerConfig( - feat_config=feat_config, - model_config=model_config, - endpoint_config=endpoint_config, - enable_endpoint=enable_endpoint_detection, - decoding_method=decoding_method, - rule_fsts=rule_fsts, - rule_fars=rule_fars, - ) - - self.recognizer = _Recognizer(recognizer_config) - self.config = recognizer_config - return self - - @classmethod - def from_wenet_ctc( - cls, - tokens: str, - model: str, - chunk_size: int = 16, - num_left_chunks: int = 4, - num_threads: int = 2, - sample_rate: float = 16000, - feature_dim: int = 80, - enable_endpoint_detection: bool = False, - rule1_min_trailing_silence: float = 2.4, - rule2_min_trailing_silence: float = 1.2, - rule3_min_utterance_length: float = 20.0, - decoding_method: str = "greedy_search", - provider: str = "cpu", - debug: bool = False, - rule_fsts: str = "", - rule_fars: str = "", - ): - """ - Please refer to - ``_ - to download pre-trained models for different languages, e.g., Chinese, - English, etc. - - Args: - tokens: - Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two - columns:: - - symbol integer_id - - model: - Path to ``model.onnx``. - chunk_size: - The --chunk-size parameter from WeNet. - num_left_chunks: - The --num-left-chunks parameter from WeNet. - num_threads: - Number of threads for neural network computation. - sample_rate: - Sample rate of the training data used to train the model. - feature_dim: - Dimension of the feature used to train the model. - enable_endpoint_detection: - True to enable endpoint detection. False to disable endpoint - detection. - rule1_min_trailing_silence: - Used only when enable_endpoint_detection is True. If the duration - of trailing silence in seconds is larger than this value, we assume - an endpoint is detected. - rule2_min_trailing_silence: - Used only when enable_endpoint_detection is True. If we have decoded - something that is nonsilence and if the duration of trailing silence - in seconds is larger than this value, we assume an endpoint is - detected. - rule3_min_utterance_length: - Used only when enable_endpoint_detection is True. If the utterance - length in seconds is larger than this value, we assume an endpoint - is detected. - decoding_method: - The only valid value is greedy_search. - provider: - onnxruntime execution providers. Valid values are: cpu, cuda, coreml. - rule_fsts: - If not empty, it specifies fsts for inverse text normalization. - If there are multiple fsts, they are separated by a comma. - rule_fars: - If not empty, it specifies fst archives for inverse text normalization. - If there are multiple archives, they are separated by a comma. 
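-
- A minimal usage sketch (file names are placeholders; chunk_size
- and num_left_chunks mirror the --chunk-size and --num-left-chunks
- options the WeNet model was exported with)::
-
-     recognizer = OnlineRecognizer.from_wenet_ctc(
-         tokens="tokens.txt",
-         model="model.onnx",
-         chunk_size=16,
-         num_left_chunks=4,
-     )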
- """ - self = cls.__new__(cls) - _assert_file_exists(tokens) - _assert_file_exists(model) - - assert num_threads > 0, num_threads - - wenet_ctc_config = OnlineWenetCtcModelConfig( - model=model, - chunk_size=chunk_size, - num_left_chunks=num_left_chunks, - ) - - model_config = OnlineModelConfig( - wenet_ctc=wenet_ctc_config, - tokens=tokens, - num_threads=num_threads, - provider=provider, - debug=debug, - ) - - feat_config = FeatureExtractorConfig( - sampling_rate=sample_rate, - feature_dim=feature_dim, - ) - - endpoint_config = EndpointConfig( - rule1_min_trailing_silence=rule1_min_trailing_silence, - rule2_min_trailing_silence=rule2_min_trailing_silence, - rule3_min_utterance_length=rule3_min_utterance_length, - ) - - recognizer_config = OnlineRecognizerConfig( - feat_config=feat_config, - model_config=model_config, - endpoint_config=endpoint_config, - enable_endpoint=enable_endpoint_detection, - decoding_method=decoding_method, - rule_fsts=rule_fsts, - rule_fars=rule_fars, - ) - - self.recognizer = _Recognizer(recognizer_config) - self.config = recognizer_config - return self - - def create_stream(self, hotwords: Optional[str] = None): - if hotwords is None: - return self.recognizer.create_stream() - else: - return self.recognizer.create_stream(hotwords) - - def decode_stream(self, s: OnlineStream): - self.recognizer.decode_stream(s) - - def decode_streams(self, ss: List[OnlineStream]): - self.recognizer.decode_streams(ss) - - def is_ready(self, s: OnlineStream) -> bool: - return self.recognizer.is_ready(s) - - def get_result_all(self, s: OnlineStream) -> OnlineRecognizerResult: - return self.recognizer.get_result(s) - - def get_result(self, s: OnlineStream) -> str: - return self.recognizer.get_result(s).text.strip() - - def get_result_as_json_string(self, s: OnlineStream) -> str: - return self.recognizer.get_result(s).as_json_string() - - def tokens(self, s: OnlineStream) -> List[str]: - return self.recognizer.get_result(s).tokens - - def timestamps(self, s: OnlineStream) -> List[float]: - return self.recognizer.get_result(s).timestamps - - def start_time(self, s: OnlineStream) -> float: - return self.recognizer.get_result(s).start_time - - def ys_probs(self, s: OnlineStream) -> List[float]: - return self.recognizer.get_result(s).ys_probs - - def lm_probs(self, s: OnlineStream) -> List[float]: - return self.recognizer.get_result(s).lm_probs - - def context_scores(self, s: OnlineStream) -> List[float]: - return self.recognizer.get_result(s).context_scores - - def is_endpoint(self, s: OnlineStream) -> bool: - return self.recognizer.is_endpoint(s) - - def reset(self, s: OnlineStream) -> bool: - return self.recognizer.reset(s) From 6ab0567b5b0d4856fa44eeb2d99f9bdb2aece1c2 Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Thu, 4 Jul 2024 04:14:42 +0000 Subject: [PATCH 30/34] fixing keyword spotter Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/python/sherpa_onnx/keyword_spotter.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py index 218628ea9..c334fb225 100644 --- a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py +++ b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py @@ -9,6 +9,7 @@ OnlineModelConfig, OnlineTransducerModelConfig, OnlineStream, + ProviderConfig, ) from _sherpa_onnx import KeywordSpotter as _KeywordSpotter @@ -41,6 +42,7 @@ def __init__( keywords_threshold: float = 0.25, num_trailing_blanks: int = 1, provider: str = 
"cpu", + device: int = 1, ): """ Please refer to @@ -85,6 +87,9 @@ def __init__( between each other. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + device: + onnxruntime cuda device index. + """ _assert_file_exists(tokens) _assert_file_exists(encoder) @@ -99,11 +104,15 @@ def __init__( joiner=joiner, ) + provider_config = ProviderConfig( + provider=provider, + device = device, + ) model_config = OnlineModelConfig( transducer=transducer_config, tokens=tokens, num_threads=num_threads, - provider=provider, + provider_config=provider_config, ) feat_config = FeatureExtractorConfig( From 04a9d8db8762339c1b492d31e46dc55f332796a7 Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Thu, 4 Jul 2024 06:53:36 +0000 Subject: [PATCH 31/34] keyword device Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/python/sherpa_onnx/keyword_spotter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py index c334fb225..66d716984 100644 --- a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py +++ b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py @@ -42,7 +42,7 @@ def __init__( keywords_threshold: float = 0.25, num_trailing_blanks: int = 1, provider: str = "cpu", - device: int = 1, + device: int = 0, ): """ Please refer to @@ -89,7 +89,6 @@ def __init__( onnxruntime execution providers. Valid values are: cpu, cuda, coreml. device: onnxruntime cuda device index. - """ _assert_file_exists(tokens) _assert_file_exists(encoder) @@ -108,6 +107,7 @@ def __init__( provider=provider, device = device, ) + model_config = OnlineModelConfig( transducer=transducer_config, tokens=tokens, From 1a84eaf66498f711ed1b4cf1d45a5a66181e80aa Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Thu, 4 Jul 2024 12:34:04 +0000 Subject: [PATCH 32/34] clean up Signed-off-by: manickavela1998@gmail.com --- sherpa-onnx/csrc/provider-config.cc | 8 +++++--- sherpa-onnx/python/csrc/sherpa-onnx.h | 2 -- sherpa-onnx/python/sherpa_onnx/__init__.py | 1 - 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index 2d4109040..422130239 100644 --- a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -113,9 +113,11 @@ void ProviderConfig::Register(ParseOptions *po) { } bool ProviderConfig::Validate() const { - if (device < 0) { - SHERPA_ONNX_LOGE("device: '%d' is invalid.", device); - return false; + if(provider == "cuda" || provider == "trt") { + if (device < 0) { + SHERPA_ONNX_LOGE("device: '%d' is invalid.", device); + return false; + } } return true; } diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.h b/sherpa-onnx/python/csrc/sherpa-onnx.h index 54fd4e88f..7bce9f49b 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.h +++ b/sherpa-onnx/python/csrc/sherpa-onnx.h @@ -5,8 +5,6 @@ #ifndef SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_ #define SHERPA_ONNX_PYTHON_CSRC_SHERPA_ONNX_H_ -#define PYBIND11_DETAILED_ERROR_MESSAGES - #include "pybind11/functional.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index faccfe3f5..7a832ba06 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -16,7 +16,6 @@ OfflineTtsVitsModelConfig, OfflineZipformerAudioTaggingModelConfig, OnlineStream, - ProviderConfig, 
SileroVadModelConfig, SpeakerEmbeddingExtractor, SpeakerEmbeddingExtractorConfig, From dfe9a190c9e425b3e0b1cff8372a535e1b282b81 Mon Sep 17 00:00:00 2001 From: "manickavela1998@gmail.com" Date: Thu, 4 Jul 2024 14:35:35 +0000 Subject: [PATCH 33/34] update condition and fix --- sherpa-onnx/csrc/provider-config.cc | 17 ++++++++++++----- sherpa-onnx/csrc/provider-config.h | 2 +- sherpa-onnx/python/csrc/tensorrt-config.cc | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index 422130239..dfc1c00e2 100644 --- a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -113,12 +113,19 @@ void ProviderConfig::Register(ParseOptions *po) { } bool ProviderConfig::Validate() const { - if(provider == "cuda" || provider == "trt") { - if (device < 0) { - SHERPA_ONNX_LOGE("device: '%d' is invalid.", device); - return false; - } + if (device < 0) { + SHERPA_ONNX_LOGE("device: '%d' is invalid.", device); + return false; + } + + if(provider == "cuda" && !cuda_config.Validate()) { + return false; } + + if(provider == "trt" && !trt_config.Validate()) { + return false; + } + return true; } diff --git a/sherpa-onnx/csrc/provider-config.h b/sherpa-onnx/csrc/provider-config.h index 7ec154bf4..ff9607909 100644 --- a/sherpa-onnx/csrc/provider-config.h +++ b/sherpa-onnx/csrc/provider-config.h @@ -27,7 +27,7 @@ struct CudaConfig { }; struct TensorrtConfig { - int32_t trt_max_workspace_size = 2147483648; + int32_t trt_max_workspace_size = 2147483647; int32_t trt_max_partition_iterations = 10; int32_t trt_min_subgraph_size = 5; bool trt_fp16_enable = true; diff --git a/sherpa-onnx/python/csrc/tensorrt-config.cc b/sherpa-onnx/python/csrc/tensorrt-config.cc index 37b3cad2c..87962a2d3 100644 --- a/sherpa-onnx/python/csrc/tensorrt-config.cc +++ b/sherpa-onnx/python/csrc/tensorrt-config.cc @@ -39,7 +39,7 @@ void PybindTensorrtConfig(py::module *m) { return ans; }), - py::arg("trt_max_workspace_size") = 2147483648, + py::arg("trt_max_workspace_size") = 2147483647, py::arg("trt_max_partition_iterations") = 10, py::arg("trt_min_subgraph_size") = 5, py::arg("trt_fp16_enable") = true, From 4cc82fc3ce5c921e441a7224fae7d9aa485b974f Mon Sep 17 00:00:00 2001 From: Manix <50542248+manickavela29@users.noreply.github.com> Date: Fri, 5 Jul 2024 09:12:57 +0530 Subject: [PATCH 34/34] Update provider-config.cc --- sherpa-onnx/csrc/provider-config.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sherpa-onnx/csrc/provider-config.cc b/sherpa-onnx/csrc/provider-config.cc index dfc1c00e2..8a58746c7 100644 --- a/sherpa-onnx/csrc/provider-config.cc +++ b/sherpa-onnx/csrc/provider-config.cc @@ -118,11 +118,11 @@ bool ProviderConfig::Validate() const { return false; } - if(provider == "cuda" && !cuda_config.Validate()) { + if (provider == "cuda" && !cuda_config.Validate()) { return false; } - if(provider == "trt" && !trt_config.Validate()) { + if (provider == "trt" && !trt_config.Validate()) { return false; }
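A note on the final patches: capping trt_max_workspace_size at 2147483647
(INT32_MAX) in provider-config.h and tensorrt-config.cc fixes an overflow,
since the field is declared int32_t and the previous default of 2147483648
does not fit in a signed 32-bit integer. After patches 30-34, the Python
KeywordSpotter forwards provider and device into ProviderConfig internally.
A minimal usage sketch (model paths are placeholders, and keywords_file is
assumed to follow the upstream KeywordSpotter signature):

    from sherpa_onnx import KeywordSpotter

    kws = KeywordSpotter(
        tokens="tokens.txt",
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        keywords_file="keywords.txt",
        provider="cuda",  # "cpu" needs no device index
        device=0,  # CUDA device index; ProviderConfig::Validate() rejects device < 0
    )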