From 12e1d4cda426462fccd8a739cd097abf15100bf7 Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 21 Jun 2024 16:35:54 +0800 Subject: [PATCH 1/2] Add tokenize-hotwords option --- .../csrc/offline-recognizer-transducer-impl.h | 9 ++++-- sherpa-onnx/csrc/offline-recognizer.cc | 6 ++++ sherpa-onnx/csrc/offline-recognizer.h | 13 +++++--- .../csrc/online-recognizer-transducer-impl.h | 9 ++++-- sherpa-onnx/csrc/online-recognizer.cc | 5 ++++ sherpa-onnx/csrc/online-recognizer.h | 16 ++++++---- sherpa-onnx/csrc/text2token-test.cc | 25 +++++++++++----- sherpa-onnx/csrc/utils.cc | 13 +++++++- sherpa-onnx/csrc/utils.h | 2 +- sherpa-onnx/python/csrc/offline-recognizer.cc | 8 +++-- sherpa-onnx/python/csrc/online-recognizer.cc | 30 ++++++++++--------- 11 files changed, 95 insertions(+), 41 deletions(-) diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index c439319eb..b8b722190 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -157,7 +157,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { std::vector> current; std::vector current_scores; if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, - bpe_encoder_.get(), &current, &current_scores)) { + config_.tokenize_hotwords, bpe_encoder_.get(), &current, + &current_scores)) { SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", hotwords.c_str()); } @@ -257,7 +258,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { } if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, - bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + config_.tokenize_hotwords, bpe_encoder_.get(), + &hotwords_, &boost_scores_)) { SHERPA_ONNX_LOGE( "Failed to encode some hotwords, skip them already, see logs above " "for details."); @@ -281,7 +283,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { } if 
(!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, - bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + config_.tokenize_hotwords, bpe_encoder_.get(), + &hotwords_, &boost_scores_)) { SHERPA_ONNX_LOGE( "Failed to encode some hotwords, skip them already, see logs above " "for details."); diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index 1285a5cd3..87c37ecd8 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -45,6 +45,11 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { "The bonus score for each token in context word/phrase. " "Used only when decoding_method is modified_beam_search"); + po->Register( + "tokenize-hotwords", &tokenize_hotwords, + "Whether to tokenize hotwords, default true, if false the input hotwords " + "should be tokenized into tokens"); + po->Register( "rule-fsts", &rule_fsts, "If not empty, it specifies fsts for inverse text normalization. " @@ -125,6 +130,7 @@ std::string OfflineRecognizerConfig::ToString() const { os << "max_active_paths=" << max_active_paths << ", "; os << "hotwords_file=\"" << hotwords_file << "\", "; os << "hotwords_score=" << hotwords_score << ", "; + os << "tokenize_hotwords=" << (tokenize_hotwords ? "True" : "False") << ", "; os << "blank_penalty=" << blank_penalty << ", "; os << "rule_fsts=\"" << rule_fsts << "\", "; os << "rule_fars=\"" << rule_fars << "\")"; diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h index 9290a53b5..b0f4711aa 100644 --- a/sherpa-onnx/csrc/offline-recognizer.h +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -37,6 +37,9 @@ struct OfflineRecognizerConfig { std::string hotwords_file; float hotwords_score = 1.5; + /// Whether to tokenize the input hotwords, normally should be true + /// if false, you have to tokenize hotwords by yourself. 
+ bool tokenize_hotwords = true; float blank_penalty = 0.0; @@ -56,7 +59,7 @@ struct OfflineRecognizerConfig { const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, const std::string &decoding_method, int32_t max_active_paths, const std::string &hotwords_file, float hotwords_score, - float blank_penalty, const std::string &rule_fsts, + bool tokenize_hotwords, float blank_penalty, const std::string &rule_fsts, const std::string &rule_fars) : feat_config(feat_config), model_config(model_config), @@ -66,6 +69,7 @@ struct OfflineRecognizerConfig { max_active_paths(max_active_paths), hotwords_file(hotwords_file), hotwords_score(hotwords_score), + tokenize_hotwords(tokenize_hotwords), blank_penalty(blank_penalty), rule_fsts(rule_fsts), rule_fars(rule_fars) {} @@ -94,9 +98,10 @@ class OfflineRecognizer { /** Create a stream for decoding. * * @param The hotwords for this string, it might contain several hotwords, - * the hotwords are separated by "/". In each of the hotwords, there - * are cjkchars or bpes, the bpe/cjkchar are separated by space (" "). - * For example, hotwords I LOVE YOU and HELLO WORLD, looks like: + * the hotwords are separated by "/". For example, I LOVE YOU/HELLO + * WORLD. 
if tokenize_hotwords is false, the hotwords should be + * tokenized, so hotwords I LOVE YOU and HELLO WORLD, should look + * like: * * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD" */ diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 2bea765cb..219a65d75 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -186,7 +186,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { std::vector> current; std::vector current_scores; if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, - bpe_encoder_.get(), &current, &current_scores)) { + config_.tokenize_hotwords, bpe_encoder_.get(), &current, + &current_scores)) { SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", hotwords.c_str()); } @@ -401,7 +402,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, - bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + config_.tokenize_hotwords, bpe_encoder_.get(), + &hotwords_, &boost_scores_)) { SHERPA_ONNX_LOGE( "Failed to encode some hotwords, skip them already, see logs above " "for details."); @@ -425,7 +427,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, - bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + config_.tokenize_hotwords, bpe_encoder_.get(), + &hotwords_, &boost_scores_)) { SHERPA_ONNX_LOGE( "Failed to encode some hotwords, skip them already, see logs above " "for details."); diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index 599a0553d..232749cf3 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -96,6 +96,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { "The file containing hotwords, one words/phrases per line, For 
example: " "HELLO WORLD" "你好世界"); + po->Register( + "tokenize-hotwords", &tokenize_hotwords, + "Whether to tokenize hotwords, default true, if false the input hotwords " + "should be tokenized into tokens"); po->Register("decoding-method", &decoding_method, "decoding method," "now support greedy_search and modified_beam_search."); @@ -183,6 +187,7 @@ std::string OnlineRecognizerConfig::ToString() const { os << "max_active_paths=" << max_active_paths << ", "; os << "hotwords_score=" << hotwords_score << ", "; os << "hotwords_file=\"" << hotwords_file << "\", "; + os << "tokenize_hotwords=" << (tokenize_hotwords ? "True" : "False") << ", "; os << "decoding_method=\"" << decoding_method << "\", "; os << "blank_penalty=" << blank_penalty << ", "; os << "temperature_scale=" << temperature_scale << ", "; diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 7fde367fb..150e613b9 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -95,6 +95,9 @@ struct OnlineRecognizerConfig { /// used only for modified_beam_search std::string hotwords_file; float hotwords_score = 1.5; + /// Whether to tokenize the input hotwords, normally should be true + /// if false, you have to tokenize hotwords by yourself. 
+ bool tokenize_hotwords = true; float blank_penalty = 0.0; @@ -115,8 +118,9 @@ struct OnlineRecognizerConfig { const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config, bool enable_endpoint, const std::string &decoding_method, int32_t max_active_paths, const std::string &hotwords_file, - float hotwords_score, float blank_penalty, float temperature_scale, - const std::string &rule_fsts, const std::string &rule_fars) + float hotwords_score, bool tokenize_hotwords, float blank_penalty, + float temperature_scale, const std::string &rule_fsts, + const std::string &rule_fars) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), @@ -127,6 +131,7 @@ struct OnlineRecognizerConfig { max_active_paths(max_active_paths), hotwords_file(hotwords_file), hotwords_score(hotwords_score), + tokenize_hotwords(tokenize_hotwords), blank_penalty(blank_penalty), temperature_scale(temperature_scale), rule_fsts(rule_fsts), @@ -156,9 +161,10 @@ class OnlineRecognizer { /** Create a stream for decoding. * * @param The hotwords for this string, it might contain several hotwords, - * the hotwords are separated by "/". In each of the hotwords, there - * are cjkchars or bpes, the bpe/cjkchar are separated by space (" "). - * For example, hotwords I LOVE YOU and HELLO WORLD, looks like: + * the hotwords are separated by "/". For example, I LOVE YOU/HELLO + * WORLD. 
if tokenize_hotwords is false, the hotwords should be + * tokenized, so hotwords I LOVE YOU and HELLO WORLD, should look + * like: * * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD" */ diff --git a/sherpa-onnx/csrc/text2token-test.cc b/sherpa-onnx/csrc/text2token-test.cc index 0ad912df8..e1a9f091e 100644 --- a/sherpa-onnx/csrc/text2token-test.cc +++ b/sherpa-onnx/csrc/text2token-test.cc @@ -43,12 +43,23 @@ TEST(TEXT2TOKEN, TEST_cjkchar) { std::vector> ids; std::vector scores; - auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids, &scores); + auto r = + EncodeHotwords(iss, "cjkchar", sym_table, true, nullptr, &ids, &scores); std::vector> expected_ids( {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}}); EXPECT_EQ(ids, expected_ids); + EXPECT_EQ(scores.size(), 0); + + // tokenize_hotwords = false + text = "世 界 人 民 大 团 结\n中 国 V S 美 国\n\n"; // Test blank lines also + iss.clear(); + iss.str(text); + + r = EncodeHotwords(iss, "cjkchar", sym_table, false, nullptr, &ids, &scores); + + EXPECT_EQ(ids, expected_ids); EXPECT_EQ(scores.size(), 0); } @@ -79,8 +90,8 @@ TEST(TEXT2TOKEN, TEST_bpe) { std::vector> ids; std::vector scores; - auto r = - EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores); + auto r = EncodeHotwords(iss, "bpe", sym_table, true, bpe_processor.get(), + &ids, &scores); std::vector> expected_ids( {{22, 58, 24, 425}, {19, 370, 47}}); @@ -117,8 +128,8 @@ TEST(TEXT2TOKEN, TEST_cjkchar_bpe) { std::vector> ids; std::vector scores; - auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), - &ids, &scores); + auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, true, + bpe_processor.get(), &ids, &scores); std::vector> expected_ids( {{1368, 1392, 557, 680, 275, 178, 475}, @@ -156,8 +167,8 @@ TEST(TEXT2TOKEN, TEST_bbpe) { std::vector> ids; std::vector scores; - auto r = - EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores); + auto r = EncodeHotwords(iss, "bpe", sym_table, true, 
bpe_processor.get(), + &ids, &scores); std::vector> expected_ids( {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}}); diff --git a/sherpa-onnx/csrc/utils.cc b/sherpa-onnx/csrc/utils.cc index b5df9682f..71a095aba 100644 --- a/sherpa-onnx/csrc/utils.cc +++ b/sherpa-onnx/csrc/utils.cc @@ -70,6 +70,9 @@ static bool EncodeBase(const std::vector &lines, } } } + if (tmp_ids.empty()) { + continue; + } ids->push_back(std::move(tmp_ids)); tmp_ids = {}; tmp_scores.push_back(score); @@ -101,7 +104,7 @@ static bool EncodeBase(const std::vector &lines, } bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, - const SymbolTable &symbol_table, + const SymbolTable &symbol_table, bool tokenize_hotwords, const ssentencepiece::Ssentencepiece *bpe_encoder, std::vector> *hotwords, std::vector *boost_scores) { @@ -109,6 +112,14 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, std::string line; std::string word; + if (!tokenize_hotwords) { + while (std::getline(is, line)) { + lines.push_back(line); + } + return EncodeBase(lines, symbol_table, hotwords, nullptr, boost_scores, + nullptr); + } + while (std::getline(is, line)) { std::string score; std::string phrase; diff --git a/sherpa-onnx/csrc/utils.h b/sherpa-onnx/csrc/utils.h index a9d59e8a2..6a280e670 100644 --- a/sherpa-onnx/csrc/utils.h +++ b/sherpa-onnx/csrc/utils.h @@ -27,7 +27,7 @@ namespace sherpa_onnx { * otherwise returns false. 
*/ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, - const SymbolTable &symbol_table, + const SymbolTable &symbol_table, bool tokenize_hotwords, const ssentencepiece::Ssentencepiece *bpe_encoder_, std::vector> *hotwords_id, std::vector *boost_scores); diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc index 2a603e08f..011d8a077 100644 --- a/sherpa-onnx/python/csrc/offline-recognizer.cc +++ b/sherpa-onnx/python/csrc/offline-recognizer.cc @@ -17,14 +17,15 @@ static void PybindOfflineRecognizerConfig(py::module *m) { .def(py::init(), + bool, float, const std::string &, const std::string &>(), py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OfflineLMConfig(), py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), py::arg("decoding_method") = "greedy_search", py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", - py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0, - py::arg("rule_fsts") = "", py::arg("rule_fars") = "") + py::arg("hotwords_score") = 1.5, py::arg("tokenize_hotwords") = true, + py::arg("blank_penalty") = 0.0, py::arg("rule_fsts") = "", + py::arg("rule_fars") = "") .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) @@ -33,6 +34,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) { .def_readwrite("max_active_paths", &PyClass::max_active_paths) .def_readwrite("hotwords_file", &PyClass::hotwords_file) .def_readwrite("hotwords_score", &PyClass::hotwords_score) + .def_readwrite("tokenize_hotwords", &PyClass::tokenize_hotwords) .def_readwrite("blank_penalty", &PyClass::blank_penalty) .def_readwrite("rule_fsts", &PyClass::rule_fsts) .def_readwrite("rule_fars", &PyClass::rule_fars) diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index fe6cd454a..1e587f007 
100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -54,20 +54,21 @@ static void PybindOnlineRecognizerResult(py::module *m) { static void PybindOnlineRecognizerConfig(py::module *m) { using PyClass = OnlineRecognizerConfig; py::class_(*m, "OnlineRecognizerConfig") - .def(py::init(), - py::arg("feat_config"), py::arg("model_config"), - py::arg("lm_config") = OnlineLMConfig(), - py::arg("endpoint_config") = EndpointConfig(), - py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), - py::arg("enable_endpoint"), py::arg("decoding_method"), - py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", - py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, - py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", - py::arg("rule_fars") = "") + .def( + py::init(), + py::arg("feat_config"), py::arg("model_config"), + py::arg("lm_config") = OnlineLMConfig(), + py::arg("endpoint_config") = EndpointConfig(), + py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), + py::arg("enable_endpoint"), py::arg("decoding_method"), + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", + py::arg("hotwords_score") = 0, py::arg("tokenize_hotwords") = true, + py::arg("blank_penalty") = 0.0, py::arg("temperature_scale") = 2.0, + py::arg("rule_fsts") = "", py::arg("rule_fars") = "") .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) @@ -78,6 +79,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) { .def_readwrite("max_active_paths", &PyClass::max_active_paths) .def_readwrite("hotwords_file", &PyClass::hotwords_file) .def_readwrite("hotwords_score", &PyClass::hotwords_score) + .def_readwrite("tokenize_hotwords", &PyClass::tokenize_hotwords) .def_readwrite("blank_penalty", &PyClass::blank_penalty) .def_readwrite("temperature_scale", &PyClass::temperature_scale) 
.def_readwrite("rule_fsts", &PyClass::rule_fsts) From abecc1ce94e251622c45eb70280aa6e562a12b1b Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 21 Jun 2024 17:07:54 +0800 Subject: [PATCH 2/2] Add c-api, python api, jni --- sherpa-onnx/c-api/c-api.cc | 4 ++++ sherpa-onnx/c-api/c-api.h | 7 +++++++ sherpa-onnx/jni/offline-recognizer.cc | 3 +++ sherpa-onnx/jni/online-recognizer.cc | 3 +++ sherpa-onnx/kotlin-api/OfflineRecognizer.kt | 1 + sherpa-onnx/kotlin-api/OnlineRecognizer.kt | 1 + sherpa-onnx/python/sherpa_onnx/offline_recognizer.py | 5 +++++ sherpa-onnx/python/sherpa_onnx/online_recognizer.py | 5 +++++ 8 files changed, 29 insertions(+) diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 2d0118833..ba6f9d162 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -104,6 +104,8 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); recognizer_config.hotwords_score = SHERPA_ONNX_OR(config->hotwords_score, 1.5); + recognizer_config.tokenize_hotwords = + SHERPA_ONNX_OR(config->tokenize_hotwords, true); recognizer_config.ctc_fst_decoder_config.graph = SHERPA_ONNX_OR(config->ctc_fst_decoder_config.graph, ""); @@ -390,6 +392,8 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); recognizer_config.hotwords_score = SHERPA_ONNX_OR(config->hotwords_score, 1.5); + recognizer_config.tokenize_hotwords = + SHERPA_ONNX_OR(config->tokenize_hotwords, true); recognizer_config.rule_fsts = SHERPA_ONNX_OR(config->rule_fsts, ""); recognizer_config.rule_fars = SHERPA_ONNX_OR(config->rule_fars, ""); diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index e9637ae7c..14fa03770 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -143,6 +143,9 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { /// Bonus score for each token in hotwords. 
float hotwords_score; + /// Whether to tokenize hotwords + bool tokenize_hotwords; + SherpaOnnxOnlineCtcFstDecoderConfig ctc_fst_decoder_config; const char *rule_fsts; const char *rule_fars; @@ -413,6 +416,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig { /// Bonus score for each token in hotwords. float hotwords_score; + + /// Whether to tokenize hotwords + bool tokenize_hotwords; + const char *rule_fsts; const char *rule_fars; } SherpaOnnxOfflineRecognizerConfig; diff --git a/sherpa-onnx/jni/offline-recognizer.cc b/sherpa-onnx/jni/offline-recognizer.cc index 070d46f08..21a252bfd 100644 --- a/sherpa-onnx/jni/offline-recognizer.cc +++ b/sherpa-onnx/jni/offline-recognizer.cc @@ -34,6 +34,9 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(cls, "hotwordsScore", "F"); ans.hotwords_score = env->GetFloatField(config, fid); + fid = env->GetFieldID(cls, "tokenizeHotwords", "Z"); + ans.tokenize_hotwords = env->GetBooleanField(config, fid); + fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(config, fid); p = env->GetStringUTFChars(s, nullptr); diff --git a/sherpa-onnx/jni/online-recognizer.cc b/sherpa-onnx/jni/online-recognizer.cc index d8acd0fed..03d92e48b 100644 --- a/sherpa-onnx/jni/online-recognizer.cc +++ b/sherpa-onnx/jni/online-recognizer.cc @@ -37,6 +37,9 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(cls, "hotwordsScore", "F"); ans.hotwords_score = env->GetFloatField(config, fid); + fid = env->GetFieldID(cls, "tokenizeHotwords", "Z"); + ans.tokenize_hotwords = env->GetBooleanField(config, fid); + fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(config, fid); p = env->GetStringUTFChars(s, nullptr); diff --git a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt index 7163d3d10..058873259 100644 --- 
a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt @@ -53,6 +53,7 @@ data class OfflineRecognizerConfig( var maxActivePaths: Int = 4, var hotwordsFile: String = "", var hotwordsScore: Float = 1.5f, + var tokenizeHotwords: Boolean = true, var ruleFsts: String = "", var ruleFars: String = "", ) diff --git a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt index de47a5ebd..5a6f1b214 100644 --- a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt @@ -69,6 +69,7 @@ data class OnlineRecognizerConfig( var maxActivePaths: Int = 4, var hotwordsFile: String = "", var hotwordsScore: Float = 1.5f, + var tokenizeHotwords: Boolean = true, var ruleFsts: String = "", var ruleFars: String = "", ) diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index f0e9a45f2..1f38e056f 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -49,6 +49,7 @@ def from_transducer( max_active_paths: int = 4, hotwords_file: str = "", hotwords_score: float = 1.5, + tokenize_hotwords: bool = True, blank_penalty: float = 0.0, modeling_unit: str = "cjkchar", bpe_vocab: str = "", @@ -96,6 +97,9 @@ def from_transducer( hotwords_score: The hotword score of each token for biasing word/phrase. Used only if hotwords_file is given with modified_beam_search as decoding method. + tokenize_hotwords: + Whether to tokenize hotwords, true will tokenize hotwords in the engine + if false, you have to tokenize hotwords by yourself. blank_penalty: The penalty applied on blank symbol during decoding. 
modeling_unit: @@ -165,6 +169,7 @@ def from_transducer( max_active_paths=max_active_paths, hotwords_file=hotwords_file, hotwords_score=hotwords_score, + tokenize_hotwords=tokenize_hotwords, blank_penalty=blank_penalty, rule_fsts=rule_fsts, rule_fars=rule_fars, diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 82b2e3b42..f8936dfd4 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -54,6 +54,7 @@ def from_transducer( decoding_method: str = "greedy_search", max_active_paths: int = 4, hotwords_score: float = 1.5, + tokenize_hotwords: bool = True, blank_penalty: float = 0.0, hotwords_file: str = "", provider: str = "cpu", @@ -131,6 +132,9 @@ def from_transducer( hotwords_score: The hotword score of each token for biasing word/phrase. Used only if hotwords_file is given with modified_beam_search as decoding method. + tokenize_hotwords: + Whether to tokenize hotwords, true will tokenize hotwords in the engine + if false, you have to tokenize hotwords by yourself. temperature_scale: Temperature scaling for output symbol confidence estiamation. It affects only confidence values, the decoding uses the original @@ -222,6 +226,7 @@ def from_transducer( decoding_method=decoding_method, max_active_paths=max_active_paths, hotwords_score=hotwords_score, + tokenize_hotwords=tokenize_hotwords, hotwords_file=hotwords_file, blank_penalty=blank_penalty, temperature_scale=temperature_scale,