Skip to content

Commit

Permalink
1. added token_buf_size and hotwords_buf_size to avoid memory overflow
Browse files Browse the repository at this point in the history
2. rewrite some code to make it more readable
3. updated the c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c
  • Loading branch information
xiao committed Sep 10, 2024
1 parent 5b3e6c4 commit 07a8cd7
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// c-api-examples/streaming-zipformer-c-api.c
// c-api-examples/streaming-zipformer-buffered-tokens-hotwords-c-api.c
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Luo Xiao

//
// This file demonstrates how to use streaming Zipformer with sherpa-onnx's C and
Expand All @@ -20,8 +21,33 @@

#include "sherpa-onnx/c-api/c-api.h"

extern const char* tokens_buf_str;
extern const char* hotwords_buf_str;
/// Read the whole content of a file into a newly allocated buffer.
///
/// @param filename   Path of the file to read; it is opened in binary mode.
/// @param buffer_out On success, *buffer_out points to a malloc()-ed buffer
///                   holding the file content followed by a terminating '\0'.
///                   The caller owns it and must release it with free().
///                   On failure, *buffer_out is set to NULL.
/// @return The number of bytes read (excluding the terminating '\0'), or
///         (size_t)-1 on failure. The return type is unsigned, so callers
///         must compare against (size_t)-1 rather than testing `< 0`.
size_t read_file(const char *filename, const char **buffer_out) {
  // Never leave the caller with an indeterminate pointer: after a failure an
  // unconditional free(*buffer_out) must still be safe.
  *buffer_out = NULL;

  FILE *file = fopen(filename, "rb");
  if (file == NULL) {
    fprintf(stderr, "Failed to open %s\n", filename);
    return (size_t)-1;
  }

  // Determine the file size by seeking to the end; both calls can fail
  // (e.g. on non-seekable streams), in which case size stays negative.
  long size = -1;
  if (fseek(file, 0L, SEEK_END) == 0) {
    size = ftell(file);
  }
  if (size < 0) {
    fprintf(stderr, "Failed to get the size of %s\n", filename);
    fclose(file);
    return (size_t)-1;
  }
  rewind(file);

  // One extra byte for the terminating '\0'. A mutable pointer is used while
  // filling the buffer; it is published through buffer_out only on success.
  char *data = (char *)malloc((size_t)size + 1);
  if (data == NULL) {
    fclose(file);
    fprintf(stderr, "Memory error\n");
    return (size_t)-1;
  }

  size_t read_bytes = fread(data, 1, (size_t)size, file);
  fclose(file);
  if (read_bytes != (size_t)size) {
    fprintf(stderr, "Errors occured in reading the file %s\n", filename);
    free(data);
    return (size_t)-1;
  }

  data[read_bytes] = '\0';
  *buffer_out = data;
  return read_bytes;
}

int32_t main() {
const char *wav_filename =
"sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/test_wavs/0.wav";
Expand All @@ -36,6 +62,8 @@ int32_t main() {
"joiner-epoch-99-avg-1.onnx";
const char *provider = "cpu";
const char *modeling_unit = "bpe";
const char *tokens_filename = "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/tokens.txt";
const char *hotwords_filename = "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/hotwords.txt";
const char *bpe_vocab =
"sherpa-onnx-streaming-zipformer-en-20M-2023-02-17/"
"bpe.vocab";
Expand All @@ -45,6 +73,22 @@ int32_t main() {
return -1;
}

// reading tokens and hotwords to buffers
const *tokens_buf;
size_t token_buf_size = read_file(tokens_filename, &tokens_buf);
if(token_buf_size < 0) {
fprintf(stderr, "Please check your tokens.txt!\n");
free(tokens_buf);
return -1;
}
const *hotwords_buf;
size_t hotwords_buf_size = read_file(hotwords_filename, &hotwords_buf);
if(hotwords_buf_size < 0) {
fprintf(stderr, "Please check your hotwords.txt!\n");
free(hotwords_buf);
return -1;
}

// Zipformer config
SherpaOnnxOnlineTransducerModelConfig zipformer_config;
memset(&zipformer_config, 0, sizeof(zipformer_config));
Expand All @@ -58,15 +102,17 @@ int32_t main() {
online_model_config.debug = 1;
online_model_config.num_threads = 1;
online_model_config.provider = provider;
online_model_config.tokens_buf_str = tokens_buf_str;
online_model_config.tokens_buf = tokens_buf;
online_model_config.tokens_buf_size = token_buf_size;
online_model_config.transducer = zipformer_config;

// Recognizer config
SherpaOnnxOnlineRecognizerConfig recognizer_config;
memset(&recognizer_config, 0, sizeof(recognizer_config));
recognizer_config.decoding_method = "modified_beam_search";
recognizer_config.model_config = online_model_config;
recognizer_config.hotwords_buf_str = hotwords_buf_str;
recognizer_config.hotwords_buf = hotwords_buf;
recognizer_config.hotwords_buf_size = hotwords_buf_size;

SherpaOnnxOnlineRecognizer *recognizer =
SherpaOnnxCreateOnlineRecognizer(&recognizer_config);
Expand Down Expand Up @@ -148,10 +194,10 @@ int32_t main() {
return 0;
}

const char* hotwords_buf_str = "▁A ▁T ▁P :1.5\n \
const char* hotwords_buf = "▁A ▁T ▁P :1.5\n \
▁A ▁B ▁C :3.0";

const char* tokens_buf_str = "<blk> 0 \n \
const char* tokens_buf = "<blk> 0 \n \
<sos/eos> 1 \n \
<unk> 2 \n \
S 3 \n \
Expand Down
19 changes: 15 additions & 4 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,14 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(

recognizer_config.model_config.tokens =
SHERPA_ONNX_OR(config->model_config.tokens, "");
recognizer_config.model_config.tokens_buf_str =
SHERPA_ONNX_OR(config->model_config.tokens_buf_str, "");
if (config->model_config.tokens_buf &&
config->model_config.tokens_buf_size > 0) {
recognizer_config.model_config.tokens_buf = std::string(config->model_config.tokens_buf,
config->model_config.tokens_buf_size);
} else {
recognizer_config.model_config.tokens_buf = "";
}

recognizer_config.model_config.num_threads =
SHERPA_ONNX_OR(config->model_config.num_threads, 1);
recognizer_config.model_config.provider_config.provider =
Expand Down Expand Up @@ -120,10 +126,15 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
SHERPA_ONNX_OR(config->rule3_min_utterance_length, 20);

recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, "");
recognizer_config.hotwords_buf_str =
SHERPA_ONNX_OR(config->hotwords_buf_str, "");
recognizer_config.hotwords_score =
SHERPA_ONNX_OR(config->hotwords_score, 1.5);
if (config->hotwords_buf &&
config->hotwords_buf_size > 0) {
recognizer_config.hotwords_buf = std::string(config->hotwords_buf,
config->hotwords_buf_size);
} else {
recognizer_config.hotwords_buf = "";
}

recognizer_config.blank_penalty = config->blank_penalty;

Expand Down
8 changes: 6 additions & 2 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineModelConfig {
const char *modeling_unit;
const char *bpe_vocab;
/// if non-null, the tokens are loaded directly from this buffer, which takes priority
const char *tokens_buf_str;
const char *tokens_buf;
/// byte size excluding the trailing '\0'
int32_t tokens_buf_size;
} SherpaOnnxOnlineModelConfig;

/// It expects 16 kHz 16-bit single channel wave format.
Expand Down Expand Up @@ -151,7 +153,9 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig {
float blank_penalty;

/// if non-nullptr, the hotwords are loaded directly from this buffer
const char *hotwords_buf_str;
const char *hotwords_buf;
/// byte size excluding the trailing '\0'
int32_t hotwords_buf_size;
} SherpaOnnxOnlineRecognizerConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerResult {
Expand Down
11 changes: 9 additions & 2 deletions sherpa-onnx/csrc/online-model-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,15 @@ bool OnlineModelConfig::Validate() const {
return false;
}

if (tokens_buf_str.empty() && !FileExists(tokens)) {
SHERPA_ONNX_LOGE("tokens: '%s' does not exist", tokens.c_str());
if (!tokens_buf.empty() && FileExists(tokens)) {
SHERPA_ONNX_LOGE("you can not provide a tokens_buf and a tokens file: '%s', "
"at the same, which is confusing", tokens.c_str());
return false;
}

if (tokens_buf.empty() && !FileExists(tokens)) {
SHERPA_ONNX_LOGE("tokens: '%s' does not exist, you should provide "
"either a tokens buffer or a tokens file" , tokens.c_str());
return false;
}

Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-model-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct OnlineModelConfig {
std::string modeling_unit = "cjkchar";
std::string bpe_vocab;

/// if tokens_buf_str is non-empty,
/// if tokens_buf is non-empty,
/// the tokens will be loaded from the buffered string in preference to the ${tokens} file
std::string tokens_buf;

Expand Down
13 changes: 9 additions & 4 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,14 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
: OnlineRecognizerImpl(config),
config_(config),
model_(OnlineTransducerModel::Create(config.model_config)),
sym_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens :
config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? true : false),
endpoint_(config_.endpoint_config) {
if(!config.model_config.tokens_buf.empty()) {
sym_ = std::move(SymbolTable(config.model_config.tokens_buf, false));
} else {
/// assumes that tokens_buf and tokens are guaranteed not to both be empty
sym_ = std::move(SymbolTable(config.model_config.tokens, true));
}

if (sym_.Contains("<unk>")) {
unk_id_ = sym_["<unk>"];
}
Expand All @@ -98,7 +103,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
config_.model_config.bpe_vocab);
}

if(!config_.hotwords_buf_str.empty()) {
if(!config_.hotwords_buf.empty()) {
InitHotwordsFromBufStr();
} else if (!config_.hotwords_file.empty()) {
InitHotwords();
Expand Down Expand Up @@ -443,7 +448,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
void InitHotwordsFromBufStr() {
// each line in hotwords_file contains space-separated words

std::istringstream iss(config_.hotwords_buf_str);
std::istringstream iss(config_.hotwords_buf);
if (!EncodeHotwords(iss, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
SHERPA_ONNX_LOGE(
Expand Down
9 changes: 7 additions & 2 deletions sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
const OnlineRecognizerConfig &config)
: OnlineRecognizerImpl(config),
config_(config),
symbol_table_(config.model_config.tokens_buf_str.empty() ? config.model_config.tokens :
config.model_config.tokens_buf_str, config.model_config.tokens_buf_str.empty() ? true : false),
endpoint_(config_.endpoint_config),
model_(
std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) {
if(!config.model_config.tokens_buf.empty()) {
symbol_table_ = std::move(SymbolTable(config.model_config.tokens_buf, false));
} else {
/// assumes that tokens_buf and tokens are guaranteed not to both be empty
symbol_table_ = std::move(SymbolTable(config.model_config.tokens, true));
}

if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
model_.get(), config_.blank_penalty);
Expand Down
7 changes: 3 additions & 4 deletions sherpa-onnx/csrc/online-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ struct OnlineRecognizerConfig {
// If there are multiple FST archives, they are applied from left to right.
std::string rule_fars;

/// used only for modified_beam_search, if hotwords_buf_str is non-empty,
/// used only for modified_beam_search, if hotwords_buf is non-empty,
/// the hotwords will be loaded from the buffered string in preference to the ${hotwords_file}
std::string hotwords_buf_str;
std::string hotwords_buf;

OnlineRecognizerConfig() = default;

Expand All @@ -134,8 +134,7 @@ struct OnlineRecognizerConfig {
blank_penalty(blank_penalty),
temperature_scale(temperature_scale),
rule_fsts(rule_fsts),
rule_fars(rule_fars),
hotwords_buf_str("") {}
rule_fars(rule_fars) {}

void Register(ParseOptions *po);
bool Validate() const;
Expand Down

0 comments on commit 07a8cd7

Please sign in to comment.