
Onnxrt execution provider configs #992

Merged on Jul 5, 2024 (34 commits)

Commits
3c1879d  onnxrt ep configs (manickavela29, Jun 6, 2024)
e69d495  updating config (manickavela29, Jun 24, 2024)
47cdf99  updating python api (manickavela29, Jun 24, 2024)
e29ff24  attempting fixing pybind (manickavela29, Jun 24, 2024)
d139c1d  patch for pybind and clean (manickavela29, Jun 24, 2024)
79f45ad  clean up (manickavela-uni, Jun 25, 2024)
207f4a1  python api complete (manickavela29, Jun 25, 2024)
df2e2e1  Apply suggestions from code review (manickavela29, Jun 26, 2024)
c477339  fix and update (manickavela29, Jun 26, 2024)
f20157c  lint fix (manickavela29, Jun 26, 2024)
22d73e5  device_id fix (manickavela29, Jun 26, 2024)
e63dc19  tidy-clang and lint (manickavela29, Jun 26, 2024)
021dc01  Update sherpa-onnx/python/csrc/provider-config.cc (manickavela29, Jun 26, 2024)
9c6be4f  updating python api cmake (manickavela29, Jun 26, 2024)
0022ae4  pybind fix (manickavela29, Jun 26, 2024)
b4f1985  fix int32_t (manickavela29, Jun 26, 2024)
4011948  uint32_t back and jni fix (manickavela29, Jun 26, 2024)
1cfd469  JNI and python fix (manickavela29, Jun 26, 2024)
2ef7a7c  Update sherpa-onnx/csrc/provider-config.cc (manickavela29, Jun 26, 2024)
fa001b7  removing from offline (manickavela29, Jun 26, 2024)
be3aa27  Apply suggestions from code review (manickavela29, Jun 27, 2024)
e2dc60c  bug fix (manickavela29, Jun 27, 2024)
4f5a58f  handling uint and attempting python fix (manickavela-uni, Jun 28, 2024)
389cbf5  pybind attempt (manickavela29, Jun 28, 2024)
bbfea38  python-dump (manickavela29, Jun 29, 2024)
2f188bb  python interface (manickavela29, Jul 3, 2024)
dc8bfb2  lint fix (manickavela29, Jul 3, 2024)
0982ee7  Update sherpa-onnx/python/sherpa_onnx/online_recognizer.py (manickavela29, Jul 4, 2024)
4e2ede0  clean up (manickavela29, Jul 4, 2024)
6ab0567  fixing keyword spotter (manickavela29, Jul 4, 2024)
04a9d8d  keyword device (manickavela29, Jul 4, 2024)
1a84eaf  clean up (manickavela29, Jul 4, 2024)
dfe9a19  update condition and fix (manickavela-uni, Jul 4, 2024)
4cc82fc  Update provider-config.cc (manickavela29, Jul 5, 2024)
4 changes: 2 additions & 2 deletions sherpa-onnx/c-api/c-api.cc
@@ -73,7 +73,7 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
- recognizer_config.model_config.provider =
+ recognizer_config.model_config.provider_config.provider =
SHERPA_ONNX_OR(config->model_config.provider, "cpu");
recognizer_config.model_config.model_type =
SHERPA_ONNX_OR(config->model_config.model_type, "");
@@ -570,7 +570,7 @@ SherpaOnnxKeywordSpotter *CreateKeywordSpotter(
SHERPA_ONNX_OR(config->model_config.tokens, "");
spotter_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
- spotter_config.model_config.provider =
+ spotter_config.model_config.provider_config.provider =
SHERPA_ONNX_OR(config->model_config.provider, "cpu");
spotter_config.model_config.model_type =
SHERPA_ONNX_OR(config->model_config.model_type, "");
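For illustration (not part of the diff): the rename is internal to the C++ recognizer config. The public C struct keeps its flat provider field, which CreateOnlineRecognizer copies into the nested provider_config, so existing C callers are unaffected. A minimal sketch, assuming the C API's SherpaOnnxOnlineRecognizerConfig struct and its DestroyOnlineRecognizer teardown call:

#include <string.h>

#include "sherpa-onnx/c-api/c-api.h"

int main() {
  SherpaOnnxOnlineRecognizerConfig config;
  memset(&config, 0, sizeof(config));         // unset fields fall back to the SHERPA_ONNX_OR defaults
  config.model_config.tokens = "tokens.txt";  // illustrative path
  config.model_config.num_threads = 2;
  config.model_config.provider = "cuda";      // same public field as before
  SherpaOnnxOnlineRecognizer *recognizer = CreateOnlineRecognizer(&config);
  DestroyOnlineRecognizer(recognizer);
  return 0;
}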
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
@@ -87,6 +87,7 @@ set(sources
packed-sequence.cc
pad-sequence.cc
parse-options.cc
+ provider-config.cc
provider.cc
resample.cc
session.cc
10 changes: 6 additions & 4 deletions sherpa-onnx/csrc/online-model-config.cc
@@ -16,6 +16,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
wenet_ctc.Register(po);
zipformer2_ctc.Register(po);
nemo_ctc.Register(po);
+ provider_config.Register(po);

po->Register("tokens", &tokens, "Path to tokens.txt");

@@ -29,9 +30,6 @@ void OnlineModelConfig::Register(ParseOptions *po) {
po->Register("debug", &debug,
"true to print model information while loading it.");

po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");

po->Register("modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, "
"cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
@@ -87,6 +85,10 @@ bool OnlineModelConfig::Validate() const {
return nemo_ctc.Validate();
}

+ if (!provider_config.Validate()) {
+   return false;
+ }
+
return transducer.Validate();
}

@@ -99,11 +101,11 @@ std::string OnlineModelConfig::ToString() const {
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
os << "provider_config=" << provider_config.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "warm_up=" << warm_up << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\", ";
os << "model_type=\"" << model_type << "\", ";
os << "modeling_unit=\"" << modeling_unit << "\", ";
os << "bpe_vocab=\"" << bpe_vocab << "\")";
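For illustration (not part of the diff): because OnlineModelConfig::Register now forwards to provider_config.Register, a binary that registers an OnlineModelConfig keeps --provider and gains --device plus the cuda-*/trt-* flags, and Validate() now also rejects bad provider settings. A minimal sketch, assuming the project's Kaldi-style ParseOptions exposes a Read(argc, argv) entry point:

#include <cstdio>

#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/parse-options.h"

int main(int argc, char *argv[]) {
  sherpa_onnx::ParseOptions po("demo");
  sherpa_onnx::OnlineModelConfig model_config;
  model_config.Register(&po);

  // e.g. ./demo --provider=cuda --device=0 --cuda-cudnn-conv-algo-search=2
  po.Read(argc, argv);

  if (!model_config.Validate()) {
    return 1;  // Validate() now also checks provider_config
  }

  printf("%s\n", model_config.ToString().c_str());
  return 0;
}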
8 changes: 5 additions & 3 deletions sherpa-onnx/csrc/online-model-config.h
@@ -11,6 +11,7 @@
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
#include "sherpa-onnx/csrc/provider-config.h"

namespace sherpa_onnx {

@@ -20,11 +21,11 @@ struct OnlineModelConfig {
OnlineWenetCtcModelConfig wenet_ctc;
OnlineZipformer2CtcModelConfig zipformer2_ctc;
OnlineNeMoCtcModelConfig nemo_ctc;
+ ProviderConfig provider_config;
std::string tokens;
int32_t num_threads = 1;
int32_t warm_up = 0;
bool debug = false;
std::string provider = "cpu";

// Valid values:
// - conformer, conformer transducer from icefall
@@ -50,8 +51,9 @@ struct OnlineModelConfig {
const OnlineWenetCtcModelConfig &wenet_ctc,
const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
const OnlineNeMoCtcModelConfig &nemo_ctc,
const ProviderConfig &provider_config,
const std::string &tokens, int32_t num_threads,
int32_t warm_up, bool debug, const std::string &provider,
int32_t warm_up, bool debug,
const std::string &model_type,
const std::string &modeling_unit,
const std::string &bpe_vocab)
@@ -60,11 +62,11 @@
wenet_ctc(wenet_ctc),
zipformer2_ctc(zipformer2_ctc),
nemo_ctc(nemo_ctc),
+ provider_config(provider_config),
tokens(tokens),
num_threads(num_threads),
warm_up(warm_up),
debug(debug),
- provider(provider),
model_type(model_type),
modeling_unit(modeling_unit),
bpe_vocab(bpe_vocab) {}
143 changes: 143 additions & 0 deletions sherpa-onnx/csrc/provider-config.cc
@@ -0,0 +1,143 @@
// sherpa-onnx/csrc/provider-config.cc
//
// Copyright (c) 2024 Uniphore (Author: Manickavela)

#include "sherpa-onnx/csrc/provider-config.h"

#include <sstream>

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

void CudaConfig::Register(ParseOptions *po) {
po->Register("cuda-cudnn-conv-algo-search", &cudnn_conv_algo_search,
"CuDNN convolution algrorithm search");
}

bool CudaConfig::Validate() const {
if (cudnn_conv_algo_search < 1 || cudnn_conv_algo_search > 3) {
SHERPA_ONNX_LOGE("cudnn_conv_algo_search: '%d' is not a valid option."
manickavela29 marked this conversation as resolved.
Show resolved Hide resolved
"Options : [1,3]. Check OnnxRT docs",
cudnn_conv_algo_search);
return false;
}
return true;
}

std::string CudaConfig::ToString() const {
std::ostringstream os;

os << "CudaConfig(";
os << "cudnn_conv_algo_search=" << cudnn_conv_algo_search << ")";

return os.str();
}

void TensorrtConfig::Register(ParseOptions *po) {
po->Register("trt-max-workspace-size", &trt_max_workspace_size,
"Set TensorRT EP GPU memory usage limit.");
po->Register("trt-max-partition-iterations", &trt_max_partition_iterations,
"Limit partitioning iterations for model conversion.");
po->Register("trt-min-subgraph-size", &trt_min_subgraph_size,
"Set minimum size for subgraphs in partitioning.");
po->Register("trt-fp16-enable", &trt_fp16_enable,
"Enable FP16 precision for faster performance.");
po->Register("trt-detailed-build-log", &trt_detailed_build_log,
"Enable detailed logging of build steps.");
po->Register("trt-engine-cache-enable", &trt_engine_cache_enable,
"Enable caching of TensorRT engines.");
po->Register("trt-timing-cache-enable", &trt_timing_cache_enable,
"Enable use of timing cache to speed up builds.");
po->Register("trt-engine-cache-path", &trt_engine_cache_path,
"Set path to store cached TensorRT engines.");
po->Register("trt-timing-cache-path", &trt_timing_cache_path,
"Set path for storing timing cache.");
po->Register("trt-dump-subgraphs", &trt_dump_subgraphs,
"Dump optimized subgraphs for debugging.");
}

bool TensorrtConfig::Validate() const {
if (trt_max_workspace_size < 0) {
SHERPA_ONNX_LOGE("trt_max_workspace_size: %d is not valid.",
trt_max_workspace_size);
return false;
}
if (trt_max_partition_iterations < 0) {
SHERPA_ONNX_LOGE("trt_max_partition_iterations: %d is not valid.",
trt_max_partition_iterations);
return false;
}
if (trt_min_subgraph_size < 0) {
SHERPA_ONNX_LOGE("trt_min_subgraph_size: %d is not valid.",
trt_min_subgraph_size);
return false;
}

return true;
}

std::string TensorrtConfig::ToString() const {
std::ostringstream os;

os << "TensorrtConfig(";
os << "trt_max_workspace_size=" << trt_max_workspace_size << ", ";
os << "trt_max_partition_iterations="
<< trt_max_partition_iterations << ", ";
os << "trt_min_subgraph_size=" << trt_min_subgraph_size << ", ";
os << "trt_fp16_enable=\""
<< (trt_fp16_enable? "True" : "False") << "\", ";
os << "trt_detailed_build_log=\""
<< (trt_detailed_build_log? "True" : "False") << "\", ";
os << "trt_engine_cache_enable=\""
<< (trt_engine_cache_enable? "True" : "False") << "\", ";
os << "trt_engine_cache_path=\""
<< trt_engine_cache_path.c_str() << "\", ";
os << "trt_timing_cache_enable=\""
<< (trt_timing_cache_enable? "True" : "False") << "\", ";
os << "trt_timing_cache_path=\""
<< trt_timing_cache_path.c_str() << "\",";
os << "trt_dump_subgraphs=\""
<< (trt_dump_subgraphs? "True" : "False") << "\" )";
return os.str();
}

void ProviderConfig::Register(ParseOptions *po) {
cuda_config.Register(po);
trt_config.Register(po);

po->Register("device", &device, "GPU device index for CUDA and Trt EP");
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
}

bool ProviderConfig::Validate() const {
if (device < 0) {
SHERPA_ONNX_LOGE("device: '%d' is invalid.", device);
return false;
}

if (provider == "cuda" && !cuda_config.Validate()) {
return false;
}

if (provider == "trt" && !trt_config.Validate()) {
return false;
}

return true;
}

std::string ProviderConfig::ToString() const {
std::ostringstream os;

os << "ProviderConfig(";
os << "device=" << device << ", ";
os << "provider=\"" << provider << "\", ";
os << "cuda_config=" << cuda_config.ToString() << ", ";
os << "trt_config=" << trt_config.ToString() << ")";
return os.str();
}

} // namespace sherpa_onnx
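For a concrete reading of ToString() above (illustrative, and assuming the default OrtCudnnConvAlgoSearchHeuristic maps to 1), a default-constructed ProviderConfig prints:

ProviderConfig(device=0, provider="cpu", cuda_config=CudaConfig(cudnn_conv_algo_search=1), trt_config=TensorrtConfig(trt_max_workspace_size=2147483647, trt_max_partition_iterations=10, trt_min_subgraph_size=5, trt_fp16_enable="True", trt_detailed_build_log="False", trt_engine_cache_enable="True", trt_engine_cache_path=".", trt_timing_cache_enable="True", trt_timing_cache_path=".", trt_dump_subgraphs="False"))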
95 changes: 95 additions & 0 deletions sherpa-onnx/csrc/provider-config.h
@@ -0,0 +1,95 @@
// sherpa-onnx/csrc/provider-config.h
//
// Copyright (c) 2024 Uniphore (Author: Manickavela)

#ifndef SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_
#define SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/macros.h"
#include "onnxruntime_cxx_api.h" // NOLINT

namespace sherpa_onnx {

struct CudaConfig {
int32_t cudnn_conv_algo_search = OrtCudnnConvAlgoSearchHeuristic;

CudaConfig() = default;
explicit CudaConfig(int32_t cudnn_conv_algo_search)
: cudnn_conv_algo_search(cudnn_conv_algo_search) {}

void Register(ParseOptions *po);
bool Validate() const;

std::string ToString() const;
};

struct TensorrtConfig {
int32_t trt_max_workspace_size = 2147483647;
int32_t trt_max_partition_iterations = 10;
int32_t trt_min_subgraph_size = 5;
bool trt_fp16_enable = true;
bool trt_detailed_build_log = false;
bool trt_engine_cache_enable = true;
bool trt_timing_cache_enable = true;
std::string trt_engine_cache_path = ".";
std::string trt_timing_cache_path = ".";
bool trt_dump_subgraphs = false;

TensorrtConfig() = default;
TensorrtConfig(int32_t trt_max_workspace_size,
int32_t trt_max_partition_iterations,
int32_t trt_min_subgraph_size,
bool trt_fp16_enable,
bool trt_detailed_build_log,
bool trt_engine_cache_enable,
bool trt_timing_cache_enable,
const std::string &trt_engine_cache_path,
const std::string &trt_timing_cache_path,
bool trt_dump_subgraphs)
: trt_max_workspace_size(trt_max_workspace_size),
trt_max_partition_iterations(trt_max_partition_iterations),
trt_min_subgraph_size(trt_min_subgraph_size),
trt_fp16_enable(trt_fp16_enable),
trt_detailed_build_log(trt_detailed_build_log),
trt_engine_cache_enable(trt_engine_cache_enable),
trt_timing_cache_enable(trt_timing_cache_enable),
trt_engine_cache_path(trt_engine_cache_path),
trt_timing_cache_path(trt_timing_cache_path),
trt_dump_subgraphs(trt_dump_subgraphs) {}

void Register(ParseOptions *po);
bool Validate() const;

std::string ToString() const;
};

struct ProviderConfig {
TensorrtConfig trt_config;
CudaConfig cuda_config;
std::string provider = "cpu";
int32_t device = 0;
// device only used for cuda and trt

ProviderConfig() = default;
ProviderConfig(const std::string &provider,
int32_t device)
: provider(provider), device(device) {}
ProviderConfig(const TensorrtConfig &trt_config,
const CudaConfig &cuda_config,
const std::string &provider,
int32_t device)
: trt_config(trt_config), cuda_config(cuda_config),
provider(provider), device(device) {}

void Register(ParseOptions *po);
bool Validate() const;

std::string ToString() const;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_PROVIDER_CONFIG_H_
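For illustration (not part of the diff), the new structs can also be populated directly; a minimal sketch using only the constructors declared above, with illustrative values:

#include <iostream>

#include "sherpa-onnx/csrc/provider-config.h"

int main() {
  sherpa_onnx::TensorrtConfig trt;
  trt.trt_fp16_enable = true;
  trt.trt_engine_cache_path = "/tmp/trt-cache";  // illustrative cache directory

  sherpa_onnx::CudaConfig cuda(/*cudnn_conv_algo_search=*/2);

  sherpa_onnx::ProviderConfig provider_config(trt, cuda, "trt", /*device=*/0);
  if (!provider_config.Validate()) {
    return 1;  // e.g. negative device, or invalid cudnn_conv_algo_search
  }
  std::cout << provider_config.ToString() << "\n";
  return 0;
}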
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/provider.h
@@ -7,6 +7,7 @@

#include <string>

#include "sherpa-onnx/csrc/provider-config.h"
namespace sherpa_onnx {

// Please refer to