diff --git a/.github/workflows/c-api-test-loading-tokens-hotwords-from-memory.yaml b/.github/workflows/c-api-test-loading-tokens-hotwords-from-memory.yaml new file mode 100644 index 000000000..6ce4372ba --- /dev/null +++ b/.github/workflows/c-api-test-loading-tokens-hotwords-from-memory.yaml @@ -0,0 +1,121 @@ +name: c-api-test-loading-tokens-hotwords-from-memory + +on: + push: + branches: + - master + tags: + - 'v[0-9]+.[0-9]+.[0-9]+*' + paths: + - '.github/workflows/c-api.yaml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'sherpa-onnx/csrc/*' + - 'sherpa-onnx/c-api/*' + - 'c-api-examples/**' + - 'ffmpeg-examples/**' + pull_request: + branches: + - master + paths: + - '.github/workflows/c-api.yaml' + - 'CMakeLists.txt' + - 'cmake/**' + - 'sherpa-onnx/csrc/*' + - 'sherpa-onnx/c-api/*' + - 'c-api-examples/**' + - 'ffmpeg-examples/**' + + workflow_dispatch: + +concurrency: + group: c-api-${{ github.ref }} + cancel-in-progress: true + +jobs: + c_api: + name: ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2 + with: + key: ${{ matrix.os }}-c-api-shared + + - name: Build sherpa-onnx + shell: bash + run: | + export CMAKE_CXX_COMPILER_LAUNCHER=ccache + export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH" + cmake --version + + mkdir build + cd build + + cmake \ + -D CMAKE_BUILD_TYPE=Release \ + -D BUILD_SHARED_LIBS=ON \ + -D CMAKE_INSTALL_PREFIX=./install \ + -D SHERPA_ONNX_ENABLE_BINARY=OFF \ + .. + + make -j2 install + + ls -lh install/lib + ls -lh install/include + + if [[ ${{ matrix.os }} == ubuntu-latest ]]; then + ldd ./install/lib/libsherpa-onnx-c-api.so + echo "---" + readelf -d ./install/lib/libsherpa-onnx-c-api.so + fi + + if [[ ${{ matrix.os }} == macos-latest ]]; then + otool -L ./install/lib/libsherpa-onnx-c-api.dylib + fi + + - name: Test streaming zipformer with tokens and hotwords loaded from buffers + shell: bash + run: | + gcc -o streaming-zipformer-buffered-tokens-hotwords-c-api ./c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c \ + -I ./build/install/include \ + -L ./build/install/lib/ \ + -l sherpa-onnx-c-api \ + -l onnxruntime + + ls -lh streaming-zipformer-buffered-tokens-hotwords-c-api + + if [[ ${{ matrix.os }} == ubuntu-latest ]]; then + ldd ./streaming-zipformer-buffered-tokens-hotwords-c-api + echo "----" + readelf -d ./streaming-zipformer-buffered-tokens-hotwords-c-api + fi + + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 + rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 + curl -SL -O https://huggingface.co/desh2608/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-small/blob/main/data/lang_bpe_500/bpe.model + cp bpe.model sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/ + rm bpe.model + + printf "▁A ▁T ▁P :1.5\n▁A ▁B ▁C :3.0" > hotwords.txt + + ls -lh sherpa-onnx-streaming-zipformer-en-20M-2023-02-17 + echo "---" + ls -lh sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs + + export LD_LIBRARY_PATH=$PWD/build/install/lib:$LD_LIBRARY_PATH + export DYLD_LIBRARY_PATH=$PWD/build/install/lib:$DYLD_LIBRARY_PATH + + ./streaming-zipformer-buffered-tokens-hotwords-c-api + + rm -rf sherpa-onnx-streaming-zipformer-* diff --git a/c-api-examples/CMakeLists.txt b/c-api-examples/CMakeLists.txt index 49fb8fad7..954487f9f 100644 --- a/c-api-examples/CMakeLists.txt +++ b/c-api-examples/CMakeLists.txt @@ -48,6 +48,10 @@ target_link_libraries(telespeech-c-api sherpa-onnx-c-api) add_executable(vad-sense-voice-c-api vad-sense-voice-c-api.c) target_link_libraries(vad-sense-voice-c-api sherpa-onnx-c-api) +add_executable(streaming-zipformer-buffered-tokens-hotwords-c-api + streaming-zipformer-buffered-tokens-hotwords-c-api.c) +target_link_libraries(streaming-zipformer-buffered-tokens-hotwords-c-api sherpa-onnx-c-api) + if(SHERPA_ONNX_HAS_ALSA) add_subdirectory(./asr-microphone-example) elseif((UNIX AND NOT APPLE) OR LINUX) diff --git a/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c b/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c new file mode 100644 index 000000000..0da5f3317 --- /dev/null +++ b/c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c @@ -0,0 +1,202 @@ +// c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation +// Copyright (c) 2024 Luo Xiao + +// +// This file demonstrates how to use streaming Zipformer with sherpa-onnx's C +// and with tokens and hotwords loaded from buffered strings instead of from external +// files API. +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 +// tar xvf sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 +// rm sherpa-onnx-streaming-zipformer-en-20M-2023-02-17.tar.bz2 +// +// clang-format on + +#include +#include +#include + +#include "sherpa-onnx/c-api/c-api.h" + +static size_t ReadFile(const char *filename, const char **buffer_out) { + FILE *file = fopen(filename, "rb"); + if (file == NULL) { + fprintf(stderr, "Failed to open %s\n", filename); + return -1; + } + fseek(file, 0L, SEEK_END); + long size = ftell(file); + rewind(file); + *buffer_out = malloc(size); + if (*buffer_out == NULL) { + fclose(file); + fprintf(stderr, "Memory error\n"); + return -1; + } + size_t read_bytes = fread(*buffer_out, 1, size, file); + if (read_bytes != size) { + printf("Errors occured in reading the file %s\n", filename); + free(*buffer_out); + *buffer_out = NULL; + fclose(file); + return -1; + } + fclose(file); + return read_bytes; +} + +int32_t main() { + const char *wav_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav"; + const char *encoder_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" + "encoder-epoch-99-avg-1.onnx"; + const char *decoder_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" + "decoder-epoch-99-avg-1.onnx"; + const char *joiner_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" + "joiner-epoch-99-avg-1.onnx"; + const char *provider = "cpu"; + const char *modeling_unit = "bpe"; + const char *tokens_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/tokens.txt"; + const char *hotwords_filename = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/hotwords.txt"; + const char *bpe_vocab = + "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/" + "bpe.vocab"; + const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + // reading tokens and hotwords to buffers + const char *tokens_buf; + size_t token_buf_size = ReadFile(tokens_filename, &tokens_buf); + if (token_buf_size < 1) { + fprintf(stderr, "Please check your tokens.txt!\n"); + free(tokens_buf); + return -1; + } + const char *hotwords_buf; + size_t hotwords_buf_size = ReadFile(hotwords_filename, &hotwords_buf); + if (hotwords_buf_size < 1) { + fprintf(stderr, "Please check your hotwords.txt!\n"); + free(hotwords_buf); + return -1; + } + + // Zipformer config + SherpaOnnxOnlineTransducerModelConfig zipformer_config; + memset(&zipformer_config, 0, sizeof(zipformer_config)); + zipformer_config.encoder = encoder_filename; + zipformer_config.decoder = decoder_filename; + zipformer_config.joiner = joiner_filename; + + // Online model config + SherpaOnnxOnlineModelConfig online_model_config; + memset(&online_model_config, 0, sizeof(online_model_config)); + online_model_config.debug = 1; + online_model_config.num_threads = 1; + online_model_config.provider = provider; + online_model_config.tokens_buf = tokens_buf; + online_model_config.tokens_buf_size = token_buf_size; + online_model_config.transducer = zipformer_config; + + // Recognizer config + SherpaOnnxOnlineRecognizerConfig recognizer_config; + memset(&recognizer_config, 0, sizeof(recognizer_config)); + recognizer_config.decoding_method = "modified_beam_search"; + recognizer_config.model_config = online_model_config; + recognizer_config.hotwords_buf = hotwords_buf; + recognizer_config.hotwords_buf_size = hotwords_buf_size; + + SherpaOnnxOnlineRecognizer *recognizer = + SherpaOnnxCreateOnlineRecognizer(&recognizer_config); + + free(tokens_buf); + tokens_buf = NULL; + free(hotwords_buf); + hotwords_buf = NULL; + + if (recognizer == NULL) { + fprintf(stderr, "Please check your config!\n"); + SherpaOnnxFreeWave(wave); + return -1; + } + + SherpaOnnxOnlineStream *stream = SherpaOnnxCreateOnlineStream(recognizer); + + const SherpaOnnxDisplay *display = SherpaOnnxCreateDisplay(50); + int32_t segment_id = 0; + +// simulate streaming. You can choose an arbitrary N +#define N 3200 + + fprintf(stderr, "sample rate: %d, num samples: %d, duration: %.2f s\n", + wave->sample_rate, wave->num_samples, + (float)wave->num_samples / wave->sample_rate); + + int32_t k = 0; + while (k < wave->num_samples) { + int32_t start = k; + int32_t end = + (start + N > wave->num_samples) ? wave->num_samples : (start + N); + k += N; + + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, + wave->samples + start, end - start); + while (SherpaOnnxIsOnlineStreamReady(recognizer, stream)) { + SherpaOnnxDecodeOnlineStream(recognizer, stream); + } + + const SherpaOnnxOnlineRecognizerResult *r = + SherpaOnnxGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaOnnxPrint(display, segment_id, r->text); + } + + if (SherpaOnnxOnlineStreamIsEndpoint(recognizer, stream)) { + if (strlen(r->text)) { + ++segment_id; + } + SherpaOnnxOnlineStreamReset(recognizer, stream); + } + + SherpaOnnxDestroyOnlineRecognizerResult(r); + } + + // add some tail padding + float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate + SherpaOnnxOnlineStreamAcceptWaveform(stream, wave->sample_rate, tail_paddings, + 4800); + + SherpaOnnxFreeWave(wave); + + SherpaOnnxOnlineStreamInputFinished(stream); + while (SherpaOnnxIsOnlineStreamReady(recognizer, stream)) { + SherpaOnnxDecodeOnlineStream(recognizer, stream); + } + + const SherpaOnnxOnlineRecognizerResult *r = + SherpaOnnxGetOnlineStreamResult(recognizer, stream); + + if (strlen(r->text)) { + SherpaOnnxPrint(display, segment_id, r->text); + } + + SherpaOnnxDestroyOnlineRecognizerResult(r); + + SherpaOnnxDestroyDisplay(display); + SherpaOnnxDestroyOnlineStream(stream); + SherpaOnnxDestroyOnlineRecognizer(recognizer); + fprintf(stderr, "\n"); + + return 0; +} \ No newline at end of file diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index cdee9a209..0a4c683ee 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -73,6 +73,12 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( recognizer_config.model_config.tokens = SHERPA_ONNX_OR(config->model_config.tokens, ""); + if (config->model_config.tokens_buf && + config->model_config.tokens_buf_size > 0) { + recognizer_config.model_config.tokens_buf = std::string( + config->model_config.tokens_buf, config->model_config.tokens_buf_size); + } + recognizer_config.model_config.num_threads = SHERPA_ONNX_OR(config->model_config.num_threads, 1); recognizer_config.model_config.provider_config.provider = @@ -120,6 +126,10 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); recognizer_config.hotwords_score = SHERPA_ONNX_OR(config->hotwords_score, 1.5); + if (config->hotwords_buf && config->hotwords_buf_size > 0) { + recognizer_config.hotwords_buf = + std::string(config->hotwords_buf, config->hotwords_buf_size); + } recognizer_config.blank_penalty = config->blank_penalty; diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index d4844aed1..11dba9816 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -88,6 +88,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig { // - cjkchar+bpe const char *modeling_unit; const char *bpe_vocab; + /// if non-null, loading the tokens from the buffered string directly in + /// prioriy + const char *tokens_buf; + /// byte size excluding the tailing '\0' + int32_t tokens_buf_size; } SherpaOnnxOnlineModelConfig; /// It expects 16 kHz 16-bit single channel wave format. @@ -147,6 +152,11 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { const char *rule_fsts; const char *rule_fars; float blank_penalty; + + /// if non-nullptr, loading the hotwords from the buffered string directly in + const char *hotwords_buf; + /// byte size excluding the tailing '\0' + int32_t hotwords_buf_size; } SherpaOnnxOnlineRecognizerConfig; SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult { diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index 9913fa9ed..5592c8d0a 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -56,8 +56,19 @@ bool OnlineModelConfig::Validate() const { return false; } - if (!FileExists(tokens)) { - SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str()); + if (!tokens_buf.empty() && FileExists(tokens)) { + SHERPA_ONNX_LOGE( + "you can not provide a tokens_buf and a tokens file: '%s', " + "at the same time, which is confusing", + tokens.c_str()); + return false; + } + + if (tokens_buf.empty() && !FileExists(tokens)) { + SHERPA_ONNX_LOGE( + "tokens: '%s' does not exist, you should provide " + "either a tokens buffer or a tokens file", + tokens.c_str()); return false; } diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h index 0b64e06de..a2aaae038 100644 --- a/sherpa-onnx/csrc/online-model-config.h +++ b/sherpa-onnx/csrc/online-model-config.h @@ -45,6 +45,11 @@ struct OnlineModelConfig { std::string modeling_unit = "cjkchar"; std::string bpe_vocab; + /// if tokens_buf is non-empty, + /// the tokens will be loaded from the buffered string instead of from the + /// ${tokens} file + std::string tokens_buf; + OnlineModelConfig() = default; OnlineModelConfig(const OnlineTransducerModelConfig &transducer, const OnlineParaformerModelConfig ¶former, @@ -53,8 +58,7 @@ struct OnlineModelConfig { const OnlineNeMoCtcModelConfig &nemo_ctc, const ProviderConfig &provider_config, const std::string &tokens, int32_t num_threads, - int32_t warm_up, bool debug, - const std::string &model_type, + int32_t warm_up, bool debug, const std::string &model_type, const std::string &modeling_unit, const std::string &bpe_vocab) : transducer(transducer), diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index ab1e165f3..0bdb1cca4 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -83,8 +83,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { : OnlineRecognizerImpl(config), config_(config), model_(OnlineTransducerModel::Create(config.model_config)), - sym_(config.model_config.tokens), endpoint_(config_.endpoint_config) { + if (!config.model_config.tokens_buf.empty()) { + sym_ = SymbolTable(config.model_config.tokens_buf, false); + } else { + /// assuming tokens_buf and tokens are guaranteed not being both empty + sym_ = SymbolTable(config.model_config.tokens, true); + } + if (sym_.Contains("")) { unk_id_ = sym_[""]; } @@ -97,7 +103,9 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { config_.model_config.bpe_vocab); } - if (!config_.hotwords_file.empty()) { + if (!config_.hotwords_buf.empty()) { + InitHotwordsFromBufStr(); + } else if (!config_.hotwords_file.empty()) { InitHotwords(); } @@ -108,8 +116,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, - config_.blank_penalty, - config_.temperature_scale); + config_.blank_penalty, config_.temperature_scale); } else if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( @@ -158,8 +165,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { decoder_ = std::make_unique( model_.get(), lm_.get(), config_.max_active_paths, config_.lm_config.scale, config_.lm_config.shallow_fusion, unk_id_, - config_.blank_penalty, - config_.temperature_scale); + config_.blank_penalty, config_.temperature_scale); } else if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( @@ -437,6 +443,20 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } #endif + void InitHotwordsFromBufStr() { + // each line in hotwords_file contains space-separated words + + std::istringstream iss(config_.hotwords_buf); + if (!EncodeHotwords(iss, config_.model_config.modeling_unit, sym_, + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); + } + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); + } + void InitOnlineStream(OnlineStream *stream) const { auto r = decoder_->GetEmptyResult(); diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h index 700054dc2..0a09fdb01 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h @@ -44,10 +44,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl { const OnlineRecognizerConfig &config) : OnlineRecognizerImpl(config), config_(config), - symbol_table_(config.model_config.tokens), endpoint_(config_.endpoint_config), model_( std::make_unique(config.model_config)) { + if (!config.model_config.tokens_buf.empty()) { + symbol_table_ = SymbolTable(config.model_config.tokens_buf, false); + } else { + /// assuming tokens_buf and tokens are guaranteed not being both empty + symbol_table_ = SymbolTable(config.model_config.tokens, true); + } + if (config.decoding_method == "greedy_search") { decoder_ = std::make_unique( model_.get(), config_.blank_penalty); diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 7fde367fb..eedd30b21 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -106,6 +106,11 @@ struct OnlineRecognizerConfig { // If there are multiple FST archives, they are applied from left to right. std::string rule_fars; + /// used only for modified_beam_search, if hotwords_buf is non-empty, + /// the hotwords will be loaded from the buffered string instead of from + /// ${hotwords_file} + std::string hotwords_buf; + OnlineRecognizerConfig() = default; OnlineRecognizerConfig( diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc index 8862972b7..5655c03a8 100644 --- a/sherpa-onnx/csrc/symbol-table.cc +++ b/sherpa-onnx/csrc/symbol-table.cc @@ -20,9 +20,14 @@ namespace sherpa_onnx { -SymbolTable::SymbolTable(const std::string &filename) { - std::ifstream is(filename); - Init(is); +SymbolTable::SymbolTable(const std::string &filename, bool is_file) { + if (is_file) { + std::ifstream is(filename); + Init(is); + } else { + std::istringstream iss(filename); + Init(iss); + } } #if __ANDROID_API__ >= 9 diff --git a/sherpa-onnx/csrc/symbol-table.h b/sherpa-onnx/csrc/symbol-table.h index 00d7a69e2..2c17b4d5e 100644 --- a/sherpa-onnx/csrc/symbol-table.h +++ b/sherpa-onnx/csrc/symbol-table.h @@ -19,13 +19,13 @@ namespace sherpa_onnx { class SymbolTable { public: SymbolTable() = default; - /// Construct a symbol table from a file. + /// Construct a symbol table from a file or from a buffered string. /// Each line in the file contains two fields: /// /// sym ID /// /// Fields are separated by space(s). - explicit SymbolTable(const std::string &filename); + explicit SymbolTable(const std::string &filename, bool is_file = true); #if __ANDROID_API__ >= 9 SymbolTable(AAssetManager *mgr, const std::string &filename);