diff --git a/CMakeLists.txt b/CMakeLists.txt index 578dc78d8..36ee6fd36 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -328,6 +328,7 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux) endif() endif() +include(cppinyin) include(kaldi-native-fbank) include(kaldi-decoder) include(onnxruntime) diff --git a/cmake/cppinyin.cmake b/cmake/cppinyin.cmake new file mode 100644 index 000000000..29dec3cd9 --- /dev/null +++ b/cmake/cppinyin.cmake @@ -0,0 +1,63 @@ +function(download_cppinyin) + include(FetchContent) + + set(cppinyin_URL "https://github.com/pkufool/cppinyin/archive/refs/tags/v0.1.tar.gz") + set(cppinyin_URL2 "https://hub.nuaa.cf/pkufool/cppinyin/archive/refs/tags/v0.1.tar.gz") + set(cppinyin_HASH "SHA256=3659bc0c28d17d41ce932807c1cdc1da8c861e6acee969b5844d0d0a3c5ef78b") + + # If you don't have access to the Internet, + # please pre-download cppinyin + set(possible_file_locations + $ENV{HOME}/Downloads/cppinyin-0.1.tar.gz + ${CMAKE_SOURCE_DIR}/cppinyin-0.1.tar.gz + ${CMAKE_BINARY_DIR}/cppinyin-0.1.tar.gz + /tmp/cppinyin-0.1.tar.gz + /star-fj/fangjun/download/github/cppinyin-0.1.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(cppinyin_URL "${f}") + file(TO_CMAKE_PATH "${cppinyin_URL}" cppinyin_URL) + message(STATUS "Found local downloaded cppinyin: ${cppinyin_URL}") + set(cppinyin_URL2) + break() + endif() + endforeach() + + set(CPPINYIN_ENABLE_TESTS OFF CACHE BOOL "" FORCE) + set(CPPINYIN_BUILD_PYTHON OFF CACHE BOOL "" FORCE) + + FetchContent_Declare(cppinyin + URL + ${cppinyin_URL} + ${cppinyin_URL2} + URL_HASH + ${cppinyin_HASH} + ) + + FetchContent_GetProperties(cppinyin) + if(NOT cppinyin_POPULATED) + message(STATUS "Downloading cppinyin ${cppinyin_URL}") + FetchContent_Populate(cppinyin) + endif() + message(STATUS "cppinyin is downloaded to ${cppinyin_SOURCE_DIR}") + add_subdirectory(${cppinyin_SOURCE_DIR} ${cppinyin_BINARY_DIR} EXCLUDE_FROM_ALL) + + target_include_directories(cppinyin_core + PUBLIC + 
${cppinyin_SOURCE_DIR}/ + ) + + if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32) + install(TARGETS cppinyin_core DESTINATION ..) + else() + install(TARGETS cppinyin_core DESTINATION lib) + endif() + + if(WIN32 AND BUILD_SHARED_LIBS) + install(TARGETS cppinyin_core DESTINATION bin) + endif() +endfunction() + +download_cppinyin() diff --git a/python-api-examples/keyword-spotter-from-microphone.py b/python-api-examples/keyword-spotter-from-microphone.py index 65a59fca6..3df1d9b8a 100755 --- a/python-api-examples/keyword-spotter-from-microphone.py +++ b/python-api-examples/keyword-spotter-from-microphone.py @@ -95,15 +95,39 @@ def get_args(): """, ) + parser.add_argument( + "--modeling-unit", + type=str, + help="""The modeling unit of the model, valid values are bpe (for English model) + and ppinyin (For Chinese model). + """, + ) + + parser.add_argument( + "--bpe-vocab", + type=str, + help="""A simple format of bpe model, you can get it from the sentencepiece + generated folder. Used to tokenize the keywords into token ids. Used only + when modeling unit is bpe. + """, + ) + + parser.add_argument( + "--lexicon", + type=str, + help="""The lexicon used to tokenize the keywords into token ids. Used + only when modeling unit is ppinyin. + """, + ) + parser.add_argument( "--keywords-file", type=str, help=""" - The file containing keywords, one words/phrases per line, and for each - phrase the bpe/cjkchar/pinyin are separated by a space. For example: + The file containing keywords, one words/phrases per line. 
For example: - ▁HE LL O ▁WORLD - x iǎo ài t óng x ué + HELLO WORLD + 小爱同学 """, ) @@ -164,6 +188,9 @@ def main(): keywords_score=args.keywords_score, keywords_threshold=args.keywords_threshold, num_trailing_blanks=args.num_trailing_blanks, + modeling_unit=args.modeling_unit, + bpe_vocab=args.bpe_vocab, + lexicon=args.lexicon, provider=args.provider, ) diff --git a/python-api-examples/keyword-spotter.py b/python-api-examples/keyword-spotter.py index 1b1de77e3..a7fcf50db 100755 --- a/python-api-examples/keyword-spotter.py +++ b/python-api-examples/keyword-spotter.py @@ -80,15 +80,39 @@ def get_args(): """, ) + parser.add_argument( + "--modeling-unit", + type=str, + help="""The modeling unit of the model, valid values are bpe (for English model) + and ppinyin (For Chinese model). + """, + ) + + parser.add_argument( + "--bpe-vocab", + type=str, + help="""A simple format of bpe model, you can get it from the sentencepiece + generated folder. Used to tokenize the keywords into token ids. Used only + when modeling unit is bpe. + """, + ) + + parser.add_argument( + "--lexicon", + type=str, + help="""The lexicon used to tokenize the keywords into token ids. Used + only when modeling unit is ppinyin. + """, + ) + parser.add_argument( "--keywords-file", type=str, help=""" - The file containing keywords, one words/phrases per line, and for each - phrase the bpe/cjkchar/pinyin are separated by a space. For example: + The file containing keywords, one words/phrases per line. 
For example: - ▁HE LL O ▁WORLD - x iǎo ài t óng x ué + HELLO WORLD + 小爱同学 """, ) @@ -183,6 +207,9 @@ def main(): keywords_score=args.keywords_score, keywords_threshold=args.keywords_threshold, num_trailing_blanks=args.num_trailing_blanks, + modeling_unit=args.modeling_unit, + bpe_vocab=args.bpe_vocab, + lexicon=args.lexicon, provider=args.provider, ) diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 3e6526563..f01bea8ee 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -199,6 +199,7 @@ if(ANDROID_NDK) endif() target_link_libraries(sherpa-onnx-core + cppinyin_core kaldi-native-fbank-core kaldi-decoder-core ssentencepiece_core diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h index 759639184..61327034e 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h @@ -19,6 +19,7 @@ #include "android/asset_manager_jni.h" #endif +#include "cppinyin/csrc/cppinyin.h" #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/keyword-spotter-impl.h" #include "sherpa-onnx/csrc/keyword-spotter.h" @@ -27,6 +28,7 @@ #include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/transducer-keyword-decoder.h" #include "sherpa-onnx/csrc/utils.h" +#include "ssentencepiece/csrc/ssentencepiece.h" namespace sherpa_onnx { @@ -78,6 +80,17 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { unk_id_ = sym_[""]; } + if (!config_.model_config.bpe_vocab.empty()) { + bpe_encoder_ = std::make_unique( + config_.model_config.bpe_vocab); + } + + if (config_.model_config.modeling_unit == "ppinyin" && + !config_.model_config.lexicon.empty()) { + pinyin_encoder_ = std::make_unique( + config_.model_config.lexicon); + } + model_->SetFeatureDim(config.feat_config.feature_dim); if (config.keywords_buf.empty()) { @@ -103,6 +116,19 @@ class KeywordSpotterTransducerImpl : public 
KeywordSpotterImpl { model_->SetFeatureDim(config.feat_config.feature_dim); + if (!config_.model_config.bpe_vocab.empty()) { + auto buf = ReadFile(mgr, config_.model_config.bpe_vocab); + std::istringstream iss(std::string(buf.begin(), buf.end())); + bpe_encoder_ = std::make_unique(iss); + } + + if (config_.model_config.modeling_unit == "ppinyin" && + !config_.model_config.lexicon.empty()) { + auto buf = ReadFile(mgr, config_.model_config.lexicon); + std::istringstream iss(std::string(buf.begin(), buf.end())); + pinyin_encoder_ = std::make_unique(iss); + } + InitKeywords(mgr); decoder_ = std::make_unique( @@ -128,8 +154,9 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { std::vector current_scores; std::vector current_thresholds; - if (!EncodeKeywords(is, sym_, &current_ids, &current_kws, &current_scores, - &current_thresholds)) { + if (!EncodeKeywords(is, config_.model_config.modeling_unit, sym_, + bpe_encoder_.get(), pinyin_encoder_.get(), &current_ids, + &current_kws, &current_scores, &current_thresholds)) { SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str()); return nullptr; } @@ -269,7 +296,9 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { private: void InitKeywords(std::istream &is) { - if (!EncodeKeywords(is, sym_, &keywords_id_, &keywords_, &boost_scores_, + if (!EncodeKeywords(is, config_.model_config.modeling_unit, sym_, + bpe_encoder_.get(), pinyin_encoder_.get(), + &keywords_id_, &keywords_, &boost_scores_, &thresholds_)) { SHERPA_ONNX_LOGE("Encode keywords failed."); exit(-1); @@ -339,6 +368,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { std::vector thresholds_; std::vector keywords_; ContextGraphPtr keywords_graph_; + std::unique_ptr bpe_encoder_; + std::unique_ptr pinyin_encoder_; std::unique_ptr model_; std::unique_ptr decoder_; SymbolTable sym_; diff --git a/sherpa-onnx/csrc/keyword-spotter.cc b/sherpa-onnx/csrc/keyword-spotter.cc index d1bf6d63b..c426d1dc8 100644 --- a/sherpa-onnx/csrc/keyword-spotter.cc +++ 
b/sherpa-onnx/csrc/keyword-spotter.cc @@ -82,10 +82,9 @@ void KeywordSpotterConfig::Register(ParseOptions *po) { "The acoustic threshold (probability) to trigger the keywords."); po->Register( "keywords-file", &keywords_file, - "The file containing keywords, one word/phrase per line, and for each" - "phrase the bpe/cjkchar are separated by a space. For example: " - "▁HE LL O ▁WORLD" - "你 好 世 界"); + "The file containing keywords, one word/phrase per line. For example: " + "HELLO WORLD" + "你好世界"); } bool KeywordSpotterConfig::Validate() const { diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index 5592c8d0a..1d4a18411 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -30,11 +30,12 @@ void OnlineModelConfig::Register(ParseOptions *po) { po->Register("debug", &debug, "true to print model information while loading it."); - po->Register("modeling-unit", &modeling_unit, - "The modeling unit of the model, commonly used units are bpe, " - "cjkchar, cjkchar+bpe, etc. Currently, it is needed only when " - "hotwords are provided, we need it to encode the hotwords into " - "token sequence."); + po->Register( + "modeling-unit", &modeling_unit, + "The modeling unit of the model, commonly used units are bpe, " + "cjkchar, cjkchar+bpe, ppinyin, etc. Currently, it is needed only when " + "hotwords are provided, we need it to encode the hotwords into " + "token sequence."); po->Register("bpe-vocab", &bpe_vocab, "The vocabulary generated by google's sentencepiece program. " @@ -43,6 +44,10 @@ void OnlineModelConfig::Register(ParseOptions *po) { "your bpe model is generated. Only used when hotwords provided " "and the modeling unit is bpe or cjkchar+bpe"); + po->Register("lexicon", &lexicon, + "The lexicon used to encode words into tokens." + "Only used for keyword spotting now"); + po->Register("model-type", &model_type, "Specify it to reduce model initialization time. 
" "Valid values are: conformer, lstm, zipformer, zipformer2, " @@ -80,6 +85,14 @@ bool OnlineModelConfig::Validate() const { } } + if (!modeling_unit.empty() && + (modeling_unit == "fpinyin" || modeling_unit == "ppinyin")) { + if (!FileExists(lexicon)) { + SHERPA_ONNX_LOGE("lexicon: %s does not exist", lexicon.c_str()); + return false; + } + } + if (!paraformer.encoder.empty()) { return paraformer.Validate(); } @@ -119,6 +132,7 @@ std::string OnlineModelConfig::ToString() const { os << "debug=" << (debug ? "True" : "False") << ", "; os << "model_type=\"" << model_type << "\", "; os << "modeling_unit=\"" << modeling_unit << "\", "; + os << "lexicon=\"" << lexicon << "\", "; os << "bpe_vocab=\"" << bpe_vocab << "\")"; return os.str(); diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h index a920512d8..4663bddd5 100644 --- a/sherpa-onnx/csrc/online-model-config.h +++ b/sherpa-onnx/csrc/online-model-config.h @@ -42,9 +42,17 @@ struct OnlineModelConfig { // - cjkchar // - bpe // - cjkchar+bpe + // - fpinyin + // - ppinyin std::string modeling_unit = "cjkchar"; + // For encoding words into tokens + // Used only for models trained with bpe std::string bpe_vocab; + // For encoding words into tokens + // Used for models trained with pinyin or phone + std::string lexicon; + /// if tokens_buf is non-empty, /// the tokens will be loaded from the buffer instead of from the /// "tokens" file @@ -60,7 +68,7 @@ struct OnlineModelConfig { const std::string &tokens, int32_t num_threads, int32_t warm_up, bool debug, const std::string &model_type, const std::string &modeling_unit, - const std::string &bpe_vocab) + const std::string &bpe_vocab, const std::string &lexicon) : transducer(transducer), paraformer(paraformer), wenet_ctc(wenet_ctc), @@ -73,7 +81,8 @@ struct OnlineModelConfig { debug(debug), model_type(model_type), modeling_unit(modeling_unit), - bpe_vocab(bpe_vocab) {} + bpe_vocab(bpe_vocab), + lexicon(lexicon) {} void 
Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/text2token-test.cc b/sherpa-onnx/csrc/text2token-test.cc index 0ad912df8..de33b70a0 100644 --- a/sherpa-onnx/csrc/text2token-test.cc +++ b/sherpa-onnx/csrc/text2token-test.cc @@ -6,6 +6,7 @@ #include #include +#include "cppinyin/csrc/cppinyin.h" #include "gtest/gtest.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/utils.h" @@ -167,4 +168,99 @@ TEST(TEXT2TOKEN, TEST_bbpe) { EXPECT_EQ(scores, expected_scores); } +TEST(TEXT2TOKEN, TEST_keyword_bpe) { + std::ostringstream oss; + oss << dir << "/text2token/tokens_en.txt"; + std::string tokens = oss.str(); + oss.clear(); + oss.str(""); + oss << dir << "/text2token/bpe_en.vocab"; + std::string bpe = oss.str(); + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) { + SHERPA_ONNX_LOGE( + "No test data found, skipping TEST_bpe()." + "You can download the test data by: " + "git clone https://github.com/pkufool/sherpa-test-data.git " + "/tmp/sherpa-test-data"); + return; + } + + auto sym_table = SymbolTable(tokens); + auto bpe_processor = std::make_unique(bpe); + + std::string text = "HELLO WORLD :2.0 @FUCK_WORLD\nI LOVE YOU #0.25"; + + std::istringstream iss(text); + + std::vector> ids; + std::vector keywords; + std::vector scores; + std::vector thresholds; + + auto r = EncodeKeywords(iss, "bpe", sym_table, bpe_processor.get(), nullptr, + &ids, &keywords, &scores, &thresholds); + + std::vector> expected_ids( + {{22, 58, 24, 425}, {19, 370, 47}}); + EXPECT_EQ(ids, expected_ids); + + std::vector expected_keywords({"FUCK WORLD", "I LOVE YOU"}); + EXPECT_EQ(keywords, expected_keywords); + + std::vector expected_scores({2.0, 0}); + EXPECT_EQ(scores, expected_scores); + + std::vector expected_thresholds({0, 0.25}); + EXPECT_EQ(thresholds, expected_thresholds); +} + +TEST(TEXT2TOKEN, TEST_keyword_ppinyin) { + std::ostringstream oss; + oss << dir << "/text2token/tokens_pinyin.txt"; + std::string tokens = oss.str(); + 
oss.clear(); + oss.str(""); + oss << dir << "/text2token/pinyin.raw"; + std::string lexicon = oss.str(); + if (!std::ifstream(tokens).good() || !std::ifstream(lexicon).good()) { + SHERPA_ONNX_LOGE( + "No test data found, skipping TEST_bpe()." + "You can download the test data by: " + "git clone https://github.com/pkufool/sherpa-test-data.git " + "/tmp/sherpa-test-data"); + return; + } + + auto sym_table = SymbolTable(tokens); + auto py_encoder = std::make_unique(lexicon); + + std::string text = "世界人民大团结 :2.0\n美国大选 #0.4\n中国很美 @中国很漂亮"; + + std::istringstream iss(text); + + std::vector> ids; + std::vector keywords; + std::vector scores; + std::vector thresholds; + + auto r = EncodeKeywords(iss, "ppinyin", sym_table, nullptr, py_encoder.get(), + &ids, &keywords, &scores, &thresholds); + + std::vector expected_keywords( + {"世界人民大团结", "美国大选", "中国很漂亮"}); + EXPECT_EQ(keywords, expected_keywords); + + std::vector expected_scores({2.0, 0, 0}); + EXPECT_EQ(scores, expected_scores); + + std::vector expected_thresholds({0, 0.4, 0}); + EXPECT_EQ(thresholds, expected_thresholds); + + std::vector> expected_ids( + {{13, 36, 24, 155, 39, 41, 58, 137, 53, 71, 77, 114, 24, 138}, + {58, 125, 43, 66, 53, 71, 48, 44}, + {10, 50, 43, 66, 75, 148, 58, 125}}); + EXPECT_EQ(ids, expected_ids); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/utils.cc b/sherpa-onnx/csrc/utils.cc index f40b67697..47a70cda6 100644 --- a/sherpa-onnx/csrc/utils.cc +++ b/sherpa-onnx/csrc/utils.cc @@ -6,6 +6,7 @@ #include #include +#include // NOLINT #include #include #include @@ -58,6 +59,7 @@ static bool EncodeBase(const std::vector &lines, break; case '@': // the original keyword string phrase = word.substr(1); + phrase = std::regex_replace(phrase, std::regex("_"), " "); has_phrases = true; break; default: @@ -187,16 +189,104 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, nullptr); } -bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, +bool 
EncodeKeywords(std::istream &is, const std::string &modeling_unit, + const SymbolTable &symbol_table, + const ssentencepiece::Ssentencepiece *bpe_encoder, + const cppinyin::PinyinEncoder *pinyin_encoder, std::vector> *keywords_id, std::vector *keywords, std::vector *boost_scores, std::vector *threshold) { std::vector lines; std::string line; + std::string word; + while (std::getline(is, line)) { - lines.push_back(line); + std::string score; + std::string phrase; + std::string threshold; + std::string custom_phrase; + + std::ostringstream oss; + std::istringstream iss(line); + while (iss >> word) { + switch (word[0]) { + case ':': // boosting score for current keyword + score = word; + break; + case '#': // triggering threshold for current keyword + threshold = word; + break; + case '@': // the customize phrase for current keyword + custom_phrase = word; + break; + default: + if (!score.empty() || !threshold.empty() || !custom_phrase.empty()) { + SHERPA_ONNX_LOGE( + "Boosting score, threshold and customize phrase should be put " + "after the words/phrase, given %s.", + line.c_str()); + return false; + } + oss << " " << word; + break; + } + } + + phrase = oss.str(); + if (phrase.empty()) { + continue; + } else { + phrase = phrase.substr(1); + } + + std::istringstream piss(phrase); + oss.clear(); + oss.str(""); + std::ostringstream poss; + while (piss >> word) { + poss << "_" << word; + if (modeling_unit == "bpe") { + std::vector bpes; + bpe_encoder->Encode(word, &bpes); + for (const auto &bpe : bpes) { + oss << " " << bpe; + } + } else { + if (modeling_unit != "ppinyin") { + SHERPA_ONNX_LOGE( + "modeling_unit should be one of bpe, ppinyin, " + "given " + "%s", + modeling_unit.c_str()); + exit(-1); + } + std::vector pinyins; + pinyin_encoder->Encode(word, &pinyins, true /* tone */, + true /* partial */); + for (const auto &pinyin : pinyins) { + oss << " " << pinyin; + } + } + } + std::string encoded_phrase = oss.str().substr(1); + oss.clear(); + oss.str(""); + oss << 
encoded_phrase; + if (!score.empty()) { + oss << " " << score; + } + if (!threshold.empty()) { + oss << " " << threshold; + } + if (!custom_phrase.empty()) { + oss << " " << custom_phrase; + } else { + oss << " @" << poss.str().substr(1); + } + lines.push_back(oss.str()); } + return EncodeBase(lines, symbol_table, keywords_id, keywords, boost_scores, threshold); } diff --git a/sherpa-onnx/csrc/utils.h b/sherpa-onnx/csrc/utils.h index a9d59e8a2..fdff06b3d 100644 --- a/sherpa-onnx/csrc/utils.h +++ b/sherpa-onnx/csrc/utils.h @@ -7,6 +7,7 @@ #include #include +#include "cppinyin/csrc/cppinyin.h" #include "sherpa-onnx/csrc/symbol-table.h" #include "ssentencepiece/csrc/ssentencepiece.h" @@ -51,7 +52,10 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, * @return If all the symbols from ``is`` are in the symbol_table, returns true * otherwise returns false. */ -bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, +bool EncodeKeywords(std::istream &is, const std::string &modeling_unit, + const SymbolTable &symbol_table, + const ssentencepiece::Ssentencepiece *bpe_encoder, + const cppinyin::PinyinEncoder *pinyin_encoder, std::vector> *keywords_id, std::vector *keywords, std::vector *boost_scores, diff --git a/sherpa-onnx/jni/keyword-spotter.cc b/sherpa-onnx/jni/keyword-spotter.cc index 4ac80a294..6e8deb70a 100644 --- a/sherpa-onnx/jni/keyword-spotter.cc +++ b/sherpa-onnx/jni/keyword-spotter.cc @@ -103,6 +103,24 @@ static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) { ans.model_config.model_type = p; env->ReleaseStringUTFChars(s, p); + fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.modeling_unit = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p 
= env->GetStringUTFChars(s, nullptr); + ans.model_config.bpe_vocab = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "lexicon", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.lexicon = p; + env->ReleaseStringUTFChars(s, p); + + return ans; } diff --git a/sherpa-onnx/kotlin-api/KeywordSpotter.kt b/sherpa-onnx/kotlin-api/KeywordSpotter.kt index ea2143613..ac343e17c 100644 --- a/sherpa-onnx/kotlin-api/KeywordSpotter.kt +++ b/sherpa-onnx/kotlin-api/KeywordSpotter.kt @@ -112,6 +112,8 @@ fun getKwsModelConfig(type: Int): OnlineModelConfig? { ), tokens = "$modelDir/tokens.txt", modelType = "zipformer2", + modelingUnit = "ppinyin", + lexicon = "$modelDir/pinyin.dict", ) } @@ -125,6 +127,8 @@ fun getKwsModelConfig(type: Int): OnlineModelConfig? { ), tokens = "$modelDir/tokens.txt", modelType = "zipformer2", + modelingUnit = "bpe", + bpeVocab = "$modelDir/bpe.vocab", ) } diff --git a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt index 7ddefdf32..39766e3e9 100644 --- a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt @@ -45,6 +45,7 @@ data class OnlineModelConfig( var modelType: String = "", var modelingUnit: String = "", var bpeVocab: String = "", + var lexicon: String = "", ) data class OnlineLMConfig( diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc index 4ea13fd60..16e9e6e72 100644 --- a/sherpa-onnx/python/csrc/online-model-config.cc +++ b/sherpa-onnx/python/csrc/online-model-config.cc @@ -33,20 +33,20 @@ void PybindOnlineModelConfig(py::module *m) { const OnlineParaformerModelConfig &, const OnlineWenetCtcModelConfig &, const OnlineZipformer2CtcModelConfig &, - const OnlineNeMoCtcModelConfig &, - const ProviderConfig &, - const std::string &, int32_t, int32_t, - bool, const std::string &, const 
std::string &, - const std::string &>(), + const OnlineNeMoCtcModelConfig &, const ProviderConfig &, + const std::string &, int32_t, int32_t, bool, + const std::string &, const std::string &, + const std::string &, const std::string &>(), py::arg("transducer") = OnlineTransducerModelConfig(), py::arg("paraformer") = OnlineParaformerModelConfig(), py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(), py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(), - py::arg("provider_config") = ProviderConfig(), - py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0, + py::arg("provider_config") = ProviderConfig(), py::arg("tokens"), + py::arg("num_threads"), py::arg("warm_up") = 0, py::arg("debug") = false, py::arg("model_type") = "", - py::arg("modeling_unit") = "", py::arg("bpe_vocab") = "") + py::arg("modeling_unit") = "", py::arg("bpe_vocab") = "", + py::arg("lexicon") = "") .def_readwrite("transducer", &PyClass::transducer) .def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) @@ -60,6 +60,7 @@ void PybindOnlineModelConfig(py::module *m) { .def_readwrite("model_type", &PyClass::model_type) .def_readwrite("modeling_unit", &PyClass::modeling_unit) .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) + .def_readwrite("lexicon", &PyClass::lexicon) .def("validate", &PyClass::Validate) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py index 66d716984..1932c46c3 100644 --- a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py +++ b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py @@ -41,6 +41,9 @@ def __init__( keywords_score: float = 1.0, keywords_threshold: float = 0.25, num_trailing_blanks: int = 1, + modeling_unit: str = "cjkchar", + bpe_vocab: str = "", + lexicon: str = "", provider: str = "cpu", device: int = 0, ): @@ -85,6 +88,17 @@ def __init__( The number of 
trailing blanks a keyword should be followed. Setting to a larger value (e.g. 8) when your keywords has overlapping tokens between each other. + modeling_unit: + The modeling unit of the model, commonly used units are bpe, ppinyin, etc. + We need it to encode the keywords into token sequence. + bpe_vocab: + The vocabulary generated by google's sentencepiece program. + It is a file has two columns, one is the token, the other is + the log probability, you can get it from the directory where + your bpe model is generated. Only used when the modeling unit is bpe. + lexicon: + The lexicon used to tokenize keywords into token sequences, Only used + when the modeling unit is pinyin or phone. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. device: @@ -104,14 +118,17 @@ def __init__( ) provider_config = ProviderConfig( - provider=provider, - device = device, + provider=provider, + device=device, ) model_config = OnlineModelConfig( transducer=transducer_config, tokens=tokens, num_threads=num_threads, + modeling_unit=modeling_unit, + bpe_vocab=bpe_vocab, + lexicon=lexicon, provider_config=provider_config, ) diff --git a/sherpa-onnx/python/tests/test_keyword_spotter.py b/sherpa-onnx/python/tests/test_keyword_spotter.py index f4d79830a..4775e5dee 100755 --- a/sherpa-onnx/python/tests/test_keyword_spotter.py +++ b/sherpa-onnx/python/tests/test_keyword_spotter.py @@ -60,6 +60,9 @@ def test_zipformer_transducer_en(self): tokens = ( f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/tokens.txt" ) + bpe_vocab = ( + f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/bpe.vocab" + ) keywords_file = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt" wave0 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/0.wav" wave1 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/1.wav" @@ -74,6 +77,8 @@ def test_zipformer_transducer_en(self): tokens=tokens, num_threads=1, 
keywords_file=keywords_file, + modeling_unit="bpe", + bpe_vocab=bpe_vocab, provider="cpu", ) streams = [] @@ -119,6 +124,9 @@ def test_zipformer_transducer_cn(self): tokens = ( f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt" ) + lexicon = ( + f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/pinyin.dict" + ) keywords_file = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt" wave0 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav" wave1 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/4.wav" @@ -134,6 +142,8 @@ def test_zipformer_transducer_cn(self): tokens=tokens, num_threads=1, keywords_file=keywords_file, + modeling_unit="ppinyin", + lexicon=lexicon, provider="cpu", ) streams = []