diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 4ba0a4a60..68f5d5050 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -130,6 +130,10 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer( recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); recognizer_config.hotwords_score = SHERPA_ONNX_OR(config->hotwords_score, 1.5); + + recognizer_config.tokenize_hotwords = + SHERPA_ONNX_OR(config->tokenize_hotwords, true); + if (config->hotwords_buf && config->hotwords_buf_size > 0) { recognizer_config.hotwords_buf = std::string(config->hotwords_buf, config->hotwords_buf_size); @@ -467,6 +471,8 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig( recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); recognizer_config.hotwords_score = SHERPA_ONNX_OR(config->hotwords_score, 1.5); + recognizer_config.tokenize_hotwords = + SHERPA_ONNX_OR(config->tokenize_hotwords, true); recognizer_config.blank_penalty = config->blank_penalty; diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 4b41a81a9..fe23c87e2 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -148,6 +148,9 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { /// Bonus score for each token in hotwords. float hotwords_score; + /// Whether to tokenize hotwords + bool tokenize_hotwords; + SherpaOnnxOnlineCtcFstDecoderConfig ctc_fst_decoder_config; const char *rule_fsts; const char *rule_fars; @@ -438,6 +441,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig { /// Bonus score for each token in hotwords. 
float hotwords_score; + + /// Whether to tokenize hotwords + bool tokenize_hotwords; + const char *rule_fsts; const char *rule_fars; float blank_penalty; diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index 05759ac5b..3bd0b9ba4 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -157,7 +157,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { std::vector> current; std::vector current_scores; if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, - bpe_encoder_.get(), &current, &current_scores)) { + config_.tokenize_hotwords, bpe_encoder_.get(), &current, + &current_scores)) { SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", hotwords.c_str()); } @@ -262,7 +263,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { } if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, - bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + config_.tokenize_hotwords, bpe_encoder_.get(), + &hotwords_, &boost_scores_)) { SHERPA_ONNX_LOGE( "Failed to encode some hotwords, skip them already, see logs above " "for details."); @@ -286,7 +288,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl { } if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, - bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + config_.tokenize_hotwords, bpe_encoder_.get(), + &hotwords_, &boost_scores_)) { SHERPA_ONNX_LOGE( "Failed to encode some hotwords, skip them already, see logs above " "for details."); diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index f73e35ad6..a7cbe3b19 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -45,6 +45,11 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { "The bonus score for each token in 
context word/phrase. " "Used only when decoding_method is modified_beam_search"); + po->Register( + "tokenize-hotwords", &tokenize_hotwords, + "Whether to tokenize hotwords, default true, if false the input hotwords " + "should be tokenized into tokens"); + po->Register( "rule-fsts", &rule_fsts, "If not empty, it specifies fsts for inverse text normalization. " @@ -125,6 +130,7 @@ std::string OfflineRecognizerConfig::ToString() const { os << "max_active_paths=" << max_active_paths << ", "; os << "hotwords_file=\"" << hotwords_file << "\", "; os << "hotwords_score=" << hotwords_score << ", "; + os << "tokenize_hotwords=" << (tokenize_hotwords ? "True" : "False") << ", "; os << "blank_penalty=" << blank_penalty << ", "; os << "rule_fsts=\"" << rule_fsts << "\", "; os << "rule_fars=\"" << rule_fars << "\")"; diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h index 8f0b47a08..04171ea6a 100644 --- a/sherpa-onnx/csrc/offline-recognizer.h +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -37,6 +37,9 @@ struct OfflineRecognizerConfig { std::string hotwords_file; float hotwords_score = 1.5; + /// Whether to tokenize the input hotwords, normally should be true + /// if false, you have to tokenize hotwords by yourself. 
+ bool tokenize_hotwords = true; float blank_penalty = 0.0; @@ -56,7 +59,7 @@ struct OfflineRecognizerConfig { const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, const std::string &decoding_method, int32_t max_active_paths, const std::string &hotwords_file, float hotwords_score, - float blank_penalty, const std::string &rule_fsts, + bool tokenize_hotwords, float blank_penalty, const std::string &rule_fsts, const std::string &rule_fars) : feat_config(feat_config), model_config(model_config), @@ -66,6 +69,7 @@ struct OfflineRecognizerConfig { max_active_paths(max_active_paths), hotwords_file(hotwords_file), hotwords_score(hotwords_score), + tokenize_hotwords(tokenize_hotwords), blank_penalty(blank_penalty), rule_fsts(rule_fsts), rule_fars(rule_fars) {} @@ -94,9 +98,10 @@ class OfflineRecognizer { /** Create a stream for decoding. * * @param The hotwords for this string, it might contain several hotwords, - * the hotwords are separated by "/". In each of the hotwords, there - * are cjkchars or bpes, the bpe/cjkchar are separated by space (" "). - * For example, hotwords I LOVE YOU and HELLO WORLD, looks like: + * the hotwords are separated by "/". For example, I LOVE YOU/HELLO + * WORLD. 
if tokenize_hotwords is false, the hotwords should be + * tokenized, so hotwords I LOVE YOU and HELLO WORLD, should look + * like: * * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD" */ diff --git a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h index 475a90185..17710eeab 100644 --- a/sherpa-onnx/csrc/online-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-transducer-impl.h @@ -194,7 +194,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { std::vector> current; std::vector current_scores; if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, - bpe_encoder_.get(), &current, &current_scores)) { + config_.tokenize_hotwords, bpe_encoder_.get(), &current, + &current_scores)) { SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", hotwords.c_str()); } @@ -420,7 +421,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, - bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + config_.tokenize_hotwords, bpe_encoder_.get(), + &hotwords_, &boost_scores_)) { SHERPA_ONNX_LOGE( "Failed to encode some hotwords, skip them already, see logs above " "for details."); @@ -444,7 +446,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl { } if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_, - bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + config_.tokenize_hotwords, bpe_encoder_.get(), + &hotwords_, &boost_scores_)) { SHERPA_ONNX_LOGE( "Failed to encode some hotwords, skip them already, see logs above " "for details."); diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc index c6b9399d8..842b1116a 100644 --- a/sherpa-onnx/csrc/online-recognizer.cc +++ b/sherpa-onnx/csrc/online-recognizer.cc @@ -94,6 +94,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) { "The file containing hotwords, one words/phrases per line, For 
example: " "HELLO WORLD" "你好世界"); + po->Register( + "tokenize-hotwords", &tokenize_hotwords, + "Whether to tokenize hotwords, default true, if false the input hotwords " + "should be tokenized into tokens"); po->Register("decoding-method", &decoding_method, "decoding method," "now support greedy_search and modified_beam_search."); @@ -181,6 +185,7 @@ std::string OnlineRecognizerConfig::ToString() const { os << "max_active_paths=" << max_active_paths << ", "; os << "hotwords_score=" << hotwords_score << ", "; os << "hotwords_file=\"" << hotwords_file << "\", "; + os << "tokenize_hotwords=" << (tokenize_hotwords ? "True" : "False") << ", "; os << "decoding_method=\"" << decoding_method << "\", "; os << "blank_penalty=" << blank_penalty << ", "; os << "temperature_scale=" << temperature_scale << ", "; diff --git a/sherpa-onnx/csrc/online-recognizer.h b/sherpa-onnx/csrc/online-recognizer.h index 45e0f4237..dec9034f2 100644 --- a/sherpa-onnx/csrc/online-recognizer.h +++ b/sherpa-onnx/csrc/online-recognizer.h @@ -95,6 +95,9 @@ struct OnlineRecognizerConfig { /// used only for modified_beam_search std::string hotwords_file; float hotwords_score = 1.5; + /// Whether to tokenize the input hotwords, normally should be true + /// if false, you have to tokenize hotwords by yourself. 
+ bool tokenize_hotwords = true; float blank_penalty = 0.0; @@ -120,8 +123,9 @@ struct OnlineRecognizerConfig { const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config, bool enable_endpoint, const std::string &decoding_method, int32_t max_active_paths, const std::string &hotwords_file, - float hotwords_score, float blank_penalty, float temperature_scale, - const std::string &rule_fsts, const std::string &rule_fars) + float hotwords_score, bool tokenize_hotwords, float blank_penalty, + float temperature_scale, const std::string &rule_fsts, + const std::string &rule_fars) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), @@ -132,6 +136,7 @@ struct OnlineRecognizerConfig { max_active_paths(max_active_paths), hotwords_file(hotwords_file), hotwords_score(hotwords_score), + tokenize_hotwords(tokenize_hotwords), blank_penalty(blank_penalty), temperature_scale(temperature_scale), rule_fsts(rule_fsts), @@ -161,9 +166,10 @@ class OnlineRecognizer { /** Create a stream for decoding. * * @param The hotwords for this string, it might contain several hotwords, - * the hotwords are separated by "/". In each of the hotwords, there - * are cjkchars or bpes, the bpe/cjkchar are separated by space (" "). - * For example, hotwords I LOVE YOU and HELLO WORLD, looks like: + * the hotwords are separated by "/". For example, I LOVE YOU/HELLO + * WORLD. 
if tokenize_hotwords is false, the hotwords should be + * tokenized, so hotwords I LOVE YOU and HELLO WORLD, should look + * like: * * "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD" */ diff --git a/sherpa-onnx/csrc/text2token-test.cc b/sherpa-onnx/csrc/text2token-test.cc index 0ad912df8..e1a9f091e 100644 --- a/sherpa-onnx/csrc/text2token-test.cc +++ b/sherpa-onnx/csrc/text2token-test.cc @@ -43,12 +43,23 @@ TEST(TEXT2TOKEN, TEST_cjkchar) { std::vector> ids; std::vector scores; - auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids, &scores); + auto r = + EncodeHotwords(iss, "cjkchar", sym_table, true, nullptr, &ids, &scores); std::vector> expected_ids( {{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}}); EXPECT_EQ(ids, expected_ids); + EXPECT_EQ(scores.size(), 0); + + // tokenize_hotwords = false + text = "世 界 人 民 大 团 结\n中 国 V S 美 国\n\n"; // Test blank lines also + iss.clear(); + iss.str(text); + + r = EncodeHotwords(iss, "cjkchar", sym_table, false, nullptr, &ids, &scores); + + EXPECT_EQ(ids, expected_ids); EXPECT_EQ(scores.size(), 0); } @@ -79,8 +90,8 @@ TEST(TEXT2TOKEN, TEST_bpe) { std::vector> ids; std::vector scores; - auto r = - EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores); + auto r = EncodeHotwords(iss, "bpe", sym_table, true, bpe_processor.get(), + &ids, &scores); std::vector> expected_ids( {{22, 58, 24, 425}, {19, 370, 47}}); @@ -117,8 +128,8 @@ TEST(TEXT2TOKEN, TEST_cjkchar_bpe) { std::vector> ids; std::vector scores; - auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(), - &ids, &scores); + auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, true, + bpe_processor.get(), &ids, &scores); std::vector> expected_ids( {{1368, 1392, 557, 680, 275, 178, 475}, @@ -156,8 +167,8 @@ TEST(TEXT2TOKEN, TEST_bbpe) { std::vector> ids; std::vector scores; - auto r = - EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores); + auto r = EncodeHotwords(iss, "bpe", sym_table, true, 
bpe_processor.get(), + &ids, &scores); std::vector> expected_ids( {{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}}); diff --git a/sherpa-onnx/csrc/utils.cc b/sherpa-onnx/csrc/utils.cc index f40b67697..35f7efe22 100644 --- a/sherpa-onnx/csrc/utils.cc +++ b/sherpa-onnx/csrc/utils.cc @@ -70,6 +70,9 @@ static bool EncodeBase(const std::vector &lines, } } } + if (tmp_ids.empty()) { + continue; + } ids->push_back(std::move(tmp_ids)); tmp_ids = {}; tmp_scores.push_back(score); @@ -101,7 +104,7 @@ static bool EncodeBase(const std::vector &lines, } bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, - const SymbolTable &symbol_table, + const SymbolTable &symbol_table, bool tokenize_hotwords, const ssentencepiece::Ssentencepiece *bpe_encoder, std::vector> *hotwords, std::vector *boost_scores) { @@ -109,6 +112,14 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, std::string line; std::string word; + if (!tokenize_hotwords) { + while (std::getline(is, line)) { + lines.push_back(line); + } + return EncodeBase(lines, symbol_table, hotwords, nullptr, boost_scores, + nullptr); + } + while (std::getline(is, line)) { std::string score; std::string phrase; diff --git a/sherpa-onnx/csrc/utils.h b/sherpa-onnx/csrc/utils.h index a9d59e8a2..6a280e670 100644 --- a/sherpa-onnx/csrc/utils.h +++ b/sherpa-onnx/csrc/utils.h @@ -27,7 +27,7 @@ namespace sherpa_onnx { * otherwise returns false. 
*/ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, - const SymbolTable &symbol_table, + const SymbolTable &symbol_table, bool tokenize_hotwords, const ssentencepiece::Ssentencepiece *bpe_encoder_, std::vector> *hotwords_id, std::vector *boost_scores); diff --git a/sherpa-onnx/jni/offline-recognizer.cc b/sherpa-onnx/jni/offline-recognizer.cc index 5e4b359b6..bc72a5f81 100644 --- a/sherpa-onnx/jni/offline-recognizer.cc +++ b/sherpa-onnx/jni/offline-recognizer.cc @@ -34,6 +34,9 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(cls, "hotwordsScore", "F"); ans.hotwords_score = env->GetFloatField(config, fid); + fid = env->GetFieldID(cls, "tokenizeHotwords", "Z"); + ans.tokenize_hotwords = env->GetBooleanField(config, fid); /* jboolean */ + fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(config, fid); p = env->GetStringUTFChars(s, nullptr); diff --git a/sherpa-onnx/jni/online-recognizer.cc b/sherpa-onnx/jni/online-recognizer.cc index dbe205c4e..652990597 100644 --- a/sherpa-onnx/jni/online-recognizer.cc +++ b/sherpa-onnx/jni/online-recognizer.cc @@ -37,6 +37,9 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(cls, "hotwordsScore", "F"); ans.hotwords_score = env->GetFloatField(config, fid); + fid = env->GetFieldID(cls, "tokenizeHotwords", "Z"); + ans.tokenize_hotwords = env->GetBooleanField(config, fid); /* jboolean */ + fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(config, fid); p = env->GetStringUTFChars(s, nullptr); diff --git a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt index 203278cb7..ce119aa7e 100644 --- a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt @@ -63,6 +63,7 @@ data class OfflineRecognizerConfig( var maxActivePaths: Int = 4, var hotwordsFile: String = "", var hotwordsScore: Float = 
1.5f, + var tokenizeHotwords: Boolean = true, var ruleFsts: String = "", var ruleFars: String = "", var blankPenalty: Float = 0.0f, diff --git a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt index 7ddefdf32..35d5959f1 100644 --- a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt @@ -69,6 +69,7 @@ data class OnlineRecognizerConfig( var maxActivePaths: Int = 4, var hotwordsFile: String = "", var hotwordsScore: Float = 1.5f, + var tokenizeHotwords: Boolean = true, var ruleFsts: String = "", var ruleFars: String = "", var blankPenalty: Float = 0.0f, diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc index 2a603e08f..011d8a077 100644 --- a/sherpa-onnx/python/csrc/offline-recognizer.cc +++ b/sherpa-onnx/python/csrc/offline-recognizer.cc @@ -17,14 +17,15 @@ static void PybindOfflineRecognizerConfig(py::module *m) { .def(py::init(), + bool, float, const std::string &, const std::string &>(), py::arg("feat_config"), py::arg("model_config"), py::arg("lm_config") = OfflineLMConfig(), py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), py::arg("decoding_method") = "greedy_search", py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", - py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0, - py::arg("rule_fsts") = "", py::arg("rule_fars") = "") + py::arg("hotwords_score") = 1.5, py::arg("tokenize_hotwords") = true, + py::arg("blank_penalty") = 0.0, py::arg("rule_fsts") = "", + py::arg("rule_fars") = "") .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) @@ -33,6 +34,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) { .def_readwrite("max_active_paths", &PyClass::max_active_paths) .def_readwrite("hotwords_file", &PyClass::hotwords_file) .def_readwrite("hotwords_score", 
&PyClass::hotwords_score) + .def_readwrite("tokenize_hotwords", &PyClass::tokenize_hotwords) .def_readwrite("blank_penalty", &PyClass::blank_penalty) .def_readwrite("rule_fsts", &PyClass::rule_fsts) .def_readwrite("rule_fars", &PyClass::rule_fars) diff --git a/sherpa-onnx/python/csrc/online-recognizer.cc b/sherpa-onnx/python/csrc/online-recognizer.cc index fe6cd454a..1e587f007 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.cc +++ b/sherpa-onnx/python/csrc/online-recognizer.cc @@ -54,20 +54,21 @@ static void PybindOnlineRecognizerResult(py::module *m) { static void PybindOnlineRecognizerConfig(py::module *m) { using PyClass = OnlineRecognizerConfig; py::class_(*m, "OnlineRecognizerConfig") - .def(py::init(), - py::arg("feat_config"), py::arg("model_config"), - py::arg("lm_config") = OnlineLMConfig(), - py::arg("endpoint_config") = EndpointConfig(), - py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), - py::arg("enable_endpoint"), py::arg("decoding_method"), - py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", - py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0, - py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "", - py::arg("rule_fars") = "") + .def( + py::init(), + py::arg("feat_config"), py::arg("model_config"), + py::arg("lm_config") = OnlineLMConfig(), + py::arg("endpoint_config") = EndpointConfig(), + py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(), + py::arg("enable_endpoint"), py::arg("decoding_method"), + py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", + py::arg("hotwords_score") = 0, py::arg("tokenize_hotwords") = true, + py::arg("blank_penalty") = 0.0, py::arg("temperature_scale") = 2.0, + py::arg("rule_fsts") = "", py::arg("rule_fars") = "") .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) @@ -78,6 +79,7 @@ static void PybindOnlineRecognizerConfig(py::module 
*m) { .def_readwrite("max_active_paths", &PyClass::max_active_paths) .def_readwrite("hotwords_file", &PyClass::hotwords_file) .def_readwrite("hotwords_score", &PyClass::hotwords_score) + .def_readwrite("tokenize_hotwords", &PyClass::tokenize_hotwords) .def_readwrite("blank_penalty", &PyClass::blank_penalty) .def_readwrite("temperature_scale", &PyClass::temperature_scale) .def_readwrite("rule_fsts", &PyClass::rule_fsts) diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index e96271a58..6c08bc8a8 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -50,6 +50,7 @@ def from_transducer( max_active_paths: int = 4, hotwords_file: str = "", hotwords_score: float = 1.5, + tokenize_hotwords: bool = True, blank_penalty: float = 0.0, modeling_unit: str = "cjkchar", bpe_vocab: str = "", @@ -97,6 +98,9 @@ def from_transducer( hotwords_score: The hotword score of each token for biasing word/phrase. Used only if hotwords_file is given with modified_beam_search as decoding method. + tokenize_hotwords: + Whether to tokenize hotwords, true will tokenize hotwords in the engine + if false, you have to tokenize hotwords by yourself. blank_penalty: The penalty applied on blank symbol during decoding. 
modeling_unit: @@ -168,6 +172,7 @@ def from_transducer( max_active_paths=max_active_paths, hotwords_file=hotwords_file, hotwords_score=hotwords_score, + tokenize_hotwords=tokenize_hotwords, blank_penalty=blank_penalty, rule_fsts=rule_fsts, rule_fars=rule_fars, diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 321f1cdff..38b0bad1e 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -57,6 +57,7 @@ def from_transducer( decoding_method: str = "greedy_search", max_active_paths: int = 4, hotwords_score: float = 1.5, + tokenize_hotwords: bool = True, blank_penalty: float = 0.0, hotwords_file: str = "", model_type: str = "", @@ -147,6 +148,9 @@ def from_transducer( hotwords_score: The hotword score of each token for biasing word/phrase. Used only if hotwords_file is given with modified_beam_search as decoding method. + tokenize_hotwords: + Whether to tokenize hotwords, true will tokenize hotwords in the engine + if false, you have to tokenize hotwords by yourself. temperature_scale: Temperature scaling for output symbol confidence estiamation. It affects only confidence values, the decoding uses the original @@ -287,6 +291,7 @@ def from_transducer( decoding_method=decoding_method, max_active_paths=max_active_paths, hotwords_score=hotwords_score, + tokenize_hotwords=tokenize_hotwords, hotwords_file=hotwords_file, blank_penalty=blank_penalty, temperature_scale=temperature_scale,