From 30a90f968fb6a01ceebe9fe4434a157858e05f3a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 7 Mar 2024 20:33:17 +0800 Subject: [PATCH] Add SHL provider. --- CMakeLists.txt | 6 ++++++ build-riscv64-linux-gnu.sh | 1 + sherpa-onnx/csrc/provider.cc | 2 ++ sherpa-onnx/csrc/provider.h | 1 + sherpa-onnx/csrc/session.cc | 32 ++++++++++++++++++++++++++++++++ 5 files changed, 42 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0d98b00cb..2b74e10ae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,6 +27,7 @@ option(SHERPA_ONNX_ENABLE_WASM_NODEJS "Whether to enable WASM for NodeJS" OFF) option(SHERPA_ONNX_ENABLE_BINARY "Whether to build binaries" ON) option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. Used only when BUILD_SHARED_LIBS is OFF on Linux" ON) option(SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE "True to use pre-installed onnxruntime if available" ON) +option(SHERPA_ONNX_ENABLE_SHL "Whether to enable SHL for RISC-V" OFF) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") @@ -112,6 +113,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_WASM_TTS ${SHERPA_ONNX_ENABLE_WASM_TTS}") message(STATUS "SHERPA_ONNX_ENABLE_WASM_ASR ${SHERPA_ONNX_ENABLE_WASM_ASR}") message(STATUS "SHERPA_ONNX_ENABLE_WASM_NODEJS ${SHERPA_ONNX_ENABLE_WASM_NODEJS}") message(STATUS "SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE ${SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE}") +message(STATUS "SHERPA_ONNX_ENABLE_SHL ${SHERPA_ONNX_ENABLE_SHL}") if(SHERPA_ONNX_ENABLE_WASM_TTS) if(NOT SHERPA_ONNX_ENABLE_WASM) @@ -135,6 +137,10 @@ if(SHERPA_ONNX_ENABLE_WASM) add_definitions(-DSHERPA_ONNX_ENABLE_WASM=1) endif() +if(SHERPA_ONNX_ENABLE_SHL) + add_definitions(-DSHERPA_ONNX_ENABLE_SHL=1) +endif() + if(NOT CMAKE_CXX_STANDARD) set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") endif() diff --git a/build-riscv64-linux-gnu.sh b/build-riscv64-linux-gnu.sh index 16d9c0d6a..9c7cfc958 100755 --- a/build-riscv64-linux-gnu.sh +++ b/build-riscv64-linux-gnu.sh @@ -62,6 +62,7 @@ cmake \ -DSHERPA_ONNX_ENABLE_JNI=OFF \ -DSHERPA_ONNX_ENABLE_C_API=OFF \ -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \ + -DSHERPA_ONNX_ENABLE_SHL=ON \ -DCMAKE_TOOLCHAIN_FILE=../toolchains/riscv64-linux-gnu.toolchain.cmake \ .. diff --git a/sherpa-onnx/csrc/provider.cc b/sherpa-onnx/csrc/provider.cc index 95bc18c5f..b6cc0f2cb 100644 --- a/sherpa-onnx/csrc/provider.cc +++ b/sherpa-onnx/csrc/provider.cc @@ -24,6 +24,8 @@ Provider StringToProvider(std::string s) { return Provider::kXnnpack; } else if (s == "nnapi") { return Provider::kNNAPI; + } else if (s == "shl") { + return Provider::kShl; } else { SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str()); return Provider::kCPU; diff --git a/sherpa-onnx/csrc/provider.h b/sherpa-onnx/csrc/provider.h index 467e5dab5..0c926ac72 100644 --- a/sherpa-onnx/csrc/provider.h +++ b/sherpa-onnx/csrc/provider.h @@ -18,6 +18,7 @@ enum class Provider { kCoreML = 2, // CoreMLExecutionProvider kXnnpack = 3, // XnnpackExecutionProvider kNNAPI = 4, // NnapiExecutionProvider + kShl = 5, // kShlExecutionProvider }; /** diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 94987ebc9..591f45d4e 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -9,6 +9,7 @@ #include #include +#include "onnxruntime_c_api.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/provider.h" #if defined(__APPLE__) @@ -115,6 +116,37 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, #endif break; } + + case Provider::kShl: { +#if defined(SHERPA_ONNX_ENABLE_SHL) + if (std::find(available_providers.begin(), available_providers.end(), + "ShlExecutionProvider") != available_providers.end()) { + // sess_opts.AppendExecutionProvider_SHL({}); + + const auto &api = Ort::GetApi(); + OrtStatus *status = + api.OrtSessionOptionsAppendExecutionProvider_Shl(sess_opts, ""); + + if (status) { + const char *msg = api.GetErrorMessage(status); + SHERPA_ONNX_LOGE( + "Failed to enable Shl: %s. Available providers: %s. Fallback " + "to " + "cpu", + msg, os.str().c_str()); + api.ReleaseStatus(status); + } else { + SHERPA_ONNX_LOGE("Use Shl"); + } + } else +#endif + { + SHERPA_ONNX_LOGE("Available providers: %s. Fallback to cpu!", + os.str().c_str()); + } + + break; + } } return sess_opts;