From 6a348f9bd4054a83acc634a889e2e42165fad2fa Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 21 May 2024 11:23:32 +0800 Subject: [PATCH 1/4] Add cppinyin CMakeList --- cmake/cppinyin.cmake | 63 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 cmake/cppinyin.cmake diff --git a/cmake/cppinyin.cmake b/cmake/cppinyin.cmake new file mode 100644 index 000000000..c2019c497 --- /dev/null +++ b/cmake/cppinyin.cmake @@ -0,0 +1,63 @@ +function(download_cppinyin) + include(FetchContent) + + set(cppinyin_URL "https://github.com/pkufool/cppinyin/archive/refs/tags/v0.1.tar.gz") + set(cppinyin_URL2 "https://hub.nauu.cf/pkufool/cppinyin/archive/refs/tags/v0.1.tar.gz") + set(cppinyin_HASH "SHA256=3659bc0c28d17d41ce932807c1cdc1da8c861e6acee969b5844d0d0a3c5ef78b") + + # If you don't have access to the Internet, + # please pre-download cppinyin + set(possible_file_locations + $ENV{HOME}/Downloads/cppinyin-0.1.tar.gz + ${CMAKE_SOURCE_DIR}/cppinyin-0.1.tar.gz + ${CMAKE_BINARY_DIR}/cppinyin-0.1.tar.gz + /tmp/cppinyin-0.1.tar.gz + /star-fj/fangjun/download/github/cppinyin-0.1.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(cppinyin_URL "${f}") + file(TO_CMAKE_PATH "${cppinyin_URL}" cppinyin_URL) + message(STATUS "Found local downloaded cppinyin: ${cppinyin_URL}") + set(cppinyin_URL2) + break() + endif() + endforeach() + + set(CPPINYIN_ENABLE_TESTS OFF CACHE BOOL "" FORCE) + set(CPPINYIN_BUILD_PYTHON OFF CACHE BOOL "" FORCE) + + FetchContent_Declare(cppinyin + URL + ${cppinyin_URL} + ${cppinyin_URL2} + URL_HASH + ${cppinyin_HASH} + ) + + FetchContent_GetProperties(cppinyin) + if(NOT cppinyin_POPULATED) + message(STATUS "Downloading cppinyin ${cppinyin_URL}") + FetchContent_Populate(cppinyin) + endif() + message(STATUS "cppinyin is downloaded to ${cppinyin_SOURCE_DIR}") + add_subdirectory(${cppinyin_SOURCE_DIR} ${cppinyin_BINARY_DIR} EXCLUDE_FROM_ALL) + + target_include_directories(cppinyin_core + PUBLIC + ${cppinyin_SOURCE_DIR}/ + ) + + if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32) + install(TARGETS cppinyin_core DESTINATION ..) + else() + install(TARGETS cppinyin_core DESTINATION lib) + endif() + + if(WIN32 AND BUILD_SHARED_LIBS) + install(TARGETS cppinyin_core DESTINATION bin) + endif() +endfunction() + +download_cppinyin() From 68405f2a140e4707fe2fe67b2cee2445b28a6ba9 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 27 May 2024 19:19:47 +0800 Subject: [PATCH 2/4] Add lexicon in online-model --- CMakeLists.txt | 1 + sherpa-onnx/csrc/CMakeLists.txt | 1 + sherpa-onnx/csrc/online-model-config.cc | 13 ++++ sherpa-onnx/csrc/online-model-config.h | 13 +++- sherpa-onnx/csrc/text2token-test.cc | 79 +++++++++++++++++++++++ sherpa-onnx/csrc/utils.cc | 83 ++++++++++++++++++++++++- sherpa-onnx/csrc/utils.h | 5 +- 7 files changed, 191 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a3e4ffef4..8b9bc2969 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -231,6 +231,7 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux) endif() endif() +include(cppinyin) include(kaldi-native-fbank) include(kaldi-decoder) include(onnxruntime) diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 4ed2cb119..a939caa3b 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -163,6 +163,7 @@ if(ANDROID_NDK) endif() target_link_libraries(sherpa-onnx-core + cppinyin_core kaldi-native-fbank-core kaldi-decoder-core ssentencepiece_core diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index 5ea24babe..dff45692a 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -45,6 +45,10 @@ void OnlineModelConfig::Register(ParseOptions *po) { "your bpe model is generated. Only used when hotwords provided " "and the modeling unit is bpe or cjkchar+bpe"); + po->Register("lexicon", &lexicon, + "The lexicon used to encode words into tokens." + "Only used for keyword spotting now"); + po->Register("model-type", &model_type, "Specify it to reduce model initialization time. " "Valid values are: conformer, lstm, zipformer, zipformer2, " @@ -71,6 +75,14 @@ bool OnlineModelConfig::Validate() const { } } + if (!modeling_unit.empty() && + (modeling_unit == "fpinyin" || modeling_unit == "ppinyin")) { + if (!FileExists(lexicon)) { + SHERPA_ONNX_LOGE("lexicon: %s does not exist", lexicon.c_str()); + return false; + } + } + if (!paraformer.encoder.empty()) { return paraformer.Validate(); } @@ -106,6 +118,7 @@ std::string OnlineModelConfig::ToString() const { os << "provider=\"" << provider << "\", "; os << "model_type=\"" << model_type << "\", "; os << "modeling_unit=\"" << modeling_unit << "\", "; + os << "lexicon=\"" << lexicon << "\", "; os << "bpe_vocab=\"" << bpe_vocab << "\")"; return os.str(); diff --git a/sherpa-onnx/csrc/online-model-config.h b/sherpa-onnx/csrc/online-model-config.h index 1509bd5b0..223599571 100644 --- a/sherpa-onnx/csrc/online-model-config.h +++ b/sherpa-onnx/csrc/online-model-config.h @@ -41,9 +41,17 @@ struct OnlineModelConfig { // - cjkchar // - bpe // - cjkchar+bpe + // - fpinyin + // - ppinyin std::string modeling_unit = "cjkchar"; + // For encoding words into tokens + // Used only for models trained with bpe std::string bpe_vocab; + // For encoding words into tokens + // Used for models trained with pinyin or phone + std::string lexicon; + OnlineModelConfig() = default; OnlineModelConfig(const OnlineTransducerModelConfig &transducer, const OnlineParaformerModelConfig ¶former, @@ -54,7 +62,7 @@ struct OnlineModelConfig { int32_t warm_up, bool debug, const std::string &provider, const std::string &model_type, const std::string &modeling_unit, - const std::string &bpe_vocab) + const std::string &bpe_vocab, const std::string &lexicon) : transducer(transducer), paraformer(paraformer), wenet_ctc(wenet_ctc), @@ -67,7 +75,8 @@ struct OnlineModelConfig { provider(provider), model_type(model_type), modeling_unit(modeling_unit), - bpe_vocab(bpe_vocab) {} + bpe_vocab(bpe_vocab), + lexicon(lexicon) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/text2token-test.cc b/sherpa-onnx/csrc/text2token-test.cc index ef07797db..21950fec6 100644 --- a/sherpa-onnx/csrc/text2token-test.cc +++ b/sherpa-onnx/csrc/text2token-test.cc @@ -149,4 +149,83 @@ TEST(TEXT2TOKEN, TEST_bbpe) { EXPECT_EQ(ids, expected_ids); } +TEST(TEXT2TOKEN, TEST_keyword_bpe) { + std::ostringstream oss; + oss << dir << "/text2token/tokens_en.txt"; + std::string tokens = oss.str(); + oss.clear(); + oss.str(""); + oss << dir << "/text2token/bpe_en.vocab"; + std::string bpe = oss.str(); + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) { + SHERPA_ONNX_LOGE( + "No test data found, skipping TEST_bpe()." + "You can download the test data by: " + "git clone https://github.com/pkufool/sherpa-test-data.git " + "/tmp/sherpa-test-data"); + return; + } + + auto sym_table = SymbolTable(tokens); + auto bpe_processor = std::make_unique(bpe); + + std::string text = "HELLO WORLD :2.0 @FUCK_WORLD\nI LOVE YOU #0.25"; + + std::istringstream iss(text); + + std::vector> ids; + std::vector keywords; + std::vector scores; + std::vector thresholds; + + auto r = EncodeKeywords(iss, "bpe", sym_table, bpe_processor.get(), nullptr, + &ids, &keywords, &scores, &thresholds); + + std::vector> expected_ids( + {{22, 58, 24, 425}, {19, 370, 47}}); + EXPECT_EQ(ids, expected_ids); + + std::vector expected_keywords({"FUCK WORLD", "I LOVE YOU"}); + EXPECT_EQ(keywords, expected_keywords); + + std::vector expected_scores({2.0, 0}); + EXPECT_EQ(scores, expected_scores); + + std::vector expected_thresholds({0, 0.25}); + EXPECT_EQ(thresholds, expected_thresholds); +} + +TEST(TEXT2TOKEN, TEST_keyword_ppinyin) { + std::ostringstream oss; + oss << dir << "/text2token/tokens_en.txt"; + std::string tokens = oss.str(); + oss.clear(); + oss.str(""); + oss << dir << "/text2token/bpe_en.vocab"; + std::string bpe = oss.str(); + if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) { + SHERPA_ONNX_LOGE( + "No test data found, skipping TEST_bpe()." + "You can download the test data by: " + "git clone https://github.com/pkufool/sherpa-test-data.git " + "/tmp/sherpa-test-data"); + return; + } + + auto sym_table = SymbolTable(tokens); + auto bpe_processor = std::make_unique(bpe); + + std::string text = "HELLO WORLD\nI LOVE YOU"; + + std::istringstream iss(text); + + std::vector> ids; + + auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids); + + std::vector> expected_ids( + {{22, 58, 24, 425}, {19, 370, 47}}); + EXPECT_EQ(ids, expected_ids); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/utils.cc b/sherpa-onnx/csrc/utils.cc index 6363f03c4..35f4216c3 100644 --- a/sherpa-onnx/csrc/utils.cc +++ b/sherpa-onnx/csrc/utils.cc @@ -6,6 +6,7 @@ #include #include +#include // NOLINT #include #include #include @@ -59,6 +60,7 @@ static bool EncodeBase(const std::vector &lines, break; case '@': // the original keyword string phrase = word.substr(1); + phrase = std::regex_replace(phrase, std::regex("_"), " "); has_phrases = true; break; default: @@ -181,15 +183,94 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, } bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, + const ssentencepiece::Ssentencepiece *bpe_encoder, + const cppinyin::PinyinEncoder *pinyin_encoder, std::vector> *keywords_id, std::vector *keywords, std::vector *boost_scores, std::vector *threshold) { std::vector lines; std::string line; + std::string word; + while (std::getline(is, line)) { - lines.push_back(line); + std::string score; + std::string phrase; + std::string threshold; + std::string custom_phrase; + + std::ostringstream oss; + std::istringstream iss(line); + while (iss >> word) { + switch (word[0]) { + case ':': // boosting score for current keyword + score = word; + break; + case '#': // triggering threshold for current keyword + threshold = word; + break; + case '@': // the customize phrase for current keyword + custom_phrase = word; + break; + default: + if (!score.empty() || !threshold.empty() || !custom_phrase.empty()) { + SHERPA_ONNX_LOGE( + "Boosting score, threshold and customize phrase should be put " + "after the words/phrase, given %s.", + line.c_str()); + return false; + } + oss << " " << word; + break; + } + } + phrase = oss.str().substr(1); + std::istringstream piss(phrase); + oss.clear(); + oss.str(""); + std::ostringstream poss; + while (piss >> word) { + poss << "_" << word; + if (modeling_unit == "bpe") { + std::vector bpes; + bpe_encoder->Encode(word, &bpes); + for (const auto &bpe : bpes) { + oss << " " << bpe; + } + } else { + if (modeling_unit != "ppinyin") { + SHERPA_ONNX_LOGE( + "modeling_unit should be one of bpe, ppinyin, " + "given " + "%s", + modeling_unit.c_str()); + exit(-1); + } + std::vector pinyins; + pinyin_encoder->Encode(word, &pinyins); + for (const auto &pinyin : pinyins) { + oss << " " << pinyin; + } + } + } + std::string encoded_phrase = oss.str().substr(1); + oss.clear(); + oss.str(""); + oss << encoded_phrase; + if (!score.empty()) { + oss << " " << score; + } + if (!threshold.empty()) { + oss << " " << threshold; + } + if (!custom_phrase.empty()) { + oss << " " << custom_phrase; + } else { + oss << " @" << poss.str().substr(1); + } + lines.push_back(oss.str()); } + return EncodeBase(lines, symbol_table, keywords_id, keywords, boost_scores, threshold); } diff --git a/sherpa-onnx/csrc/utils.h b/sherpa-onnx/csrc/utils.h index a3189a20a..ba7d6448b 100644 --- a/sherpa-onnx/csrc/utils.h +++ b/sherpa-onnx/csrc/utils.h @@ -7,6 +7,7 @@ #include #include +#include "cppinyin/csrc/cppinyin.h" #include "sherpa-onnx/csrc/symbol-table.h" #include "ssentencepiece/csrc/ssentencepiece.h" @@ -28,7 +29,7 @@ namespace sherpa_onnx { */ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, const SymbolTable &symbol_table, - const ssentencepiece::Ssentencepiece *bpe_encoder_, + const ssentencepiece::Ssentencepiece *bpe_encoder, std::vector> *hotwords_id); /* Encode the keywords in an input stream to be tokens ids. @@ -51,6 +52,8 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, * otherwise returns false. */ bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, + const ssentencepiece::Ssentencepiece *bpe_encoder, + const cppinyin::PinyinEncoder *pinyin_encoder, std::vector> *keywords_id, std::vector *keywords, std::vector *boost_scores, From 656a267d6e1509715f4339ed27f2f014bfe69262 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 3 Jun 2024 07:35:44 +0800 Subject: [PATCH 3/4] pinyin encoder test pass --- .../csrc/keyword-spotter-transducer-impl.h | 10 +++--- sherpa-onnx/csrc/text2token-test.cc | 33 ++++++++++++++----- sherpa-onnx/csrc/utils.cc | 15 +++++++-- sherpa-onnx/csrc/utils.h | 3 +- 4 files changed, 45 insertions(+), 16 deletions(-) diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h index 2300839f3..832878525 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h @@ -19,6 +19,7 @@ #include "android/asset_manager_jni.h" #endif +#include "cppinyin/csrc/cppinyin.h" #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/keyword-spotter-impl.h" #include "sherpa-onnx/csrc/keyword-spotter.h" @@ -27,6 +28,7 @@ #include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/transducer-keyword-decoder.h" #include "sherpa-onnx/csrc/utils.h" +#include "ssentencepiece/csrc/ssentencepiece.h" namespace sherpa_onnx { @@ -118,8 +120,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { std::vector current_scores; std::vector current_thresholds; - if (!EncodeKeywords(is, sym_, ¤t_ids, ¤t_kws, ¤t_scores, - ¤t_thresholds)) { + if (!EncodeKeywords(is, "", sym_, nullptr, nullptr, ¤t_ids, + ¤t_kws, ¤t_scores, ¤t_thresholds)) { SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str()); return nullptr; } @@ -259,8 +261,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { private: void InitKeywords(std::istream &is) { - if (!EncodeKeywords(is, sym_, &keywords_id_, &keywords_, &boost_scores_, - &thresholds_)) { + if (!EncodeKeywords(is, "", sym_, nullptr, nullptr, &keywords_id_, + &keywords_, &boost_scores_, &thresholds_)) { SHERPA_ONNX_LOGE("Encode keywords failed."); exit(-1); } diff --git a/sherpa-onnx/csrc/text2token-test.cc b/sherpa-onnx/csrc/text2token-test.cc index 4fcc9565c..de33b70a0 100644 --- a/sherpa-onnx/csrc/text2token-test.cc +++ b/sherpa-onnx/csrc/text2token-test.cc @@ -6,6 +6,7 @@ #include #include +#include "cppinyin/csrc/cppinyin.h" #include "gtest/gtest.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/utils.h" @@ -215,13 +216,13 @@ TEST(TEXT2TOKEN, TEST_keyword_bpe) { TEST(TEXT2TOKEN, TEST_keyword_ppinyin) { std::ostringstream oss; - oss << dir << "/text2token/tokens_en.txt"; + oss << dir << "/text2token/tokens_pinyin.txt"; std::string tokens = oss.str(); oss.clear(); oss.str(""); - oss << dir << "/text2token/bpe_en.vocab"; - std::string bpe = oss.str(); - if (!std::ifstream(tokens).good() || !std::ifstream(bpe).good()) { + oss << dir << "/text2token/pinyin.raw"; + std::string lexicon = oss.str(); + if (!std::ifstream(tokens).good() || !std::ifstream(lexicon).good()) { SHERPA_ONNX_LOGE( "No test data found, skipping TEST_bpe()." "You can download the test data by: " @@ -231,18 +232,34 @@ TEST(TEXT2TOKEN, TEST_keyword_ppinyin) { } auto sym_table = SymbolTable(tokens); - auto bpe_processor = std::make_unique(bpe); + auto py_encoder = std::make_unique(lexicon); - std::string text = "HELLO WORLD\nI LOVE YOU"; + std::string text = "世界人民大团结 :2.0\n美国大选 #0.4\n中国很美 @中国很漂亮"; std::istringstream iss(text); std::vector> ids; + std::vector keywords; + std::vector scores; + std::vector thresholds; - auto r = EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids); + auto r = EncodeKeywords(iss, "ppinyin", sym_table, nullptr, py_encoder.get(), + &ids, &keywords, &scores, &thresholds); + + std::vector expected_keywords( + {"世界人民大团结", "美国大选", "中国很漂亮"}); + EXPECT_EQ(keywords, expected_keywords); + + std::vector expected_scores({2.0, 0, 0}); + EXPECT_EQ(scores, expected_scores); + + std::vector expected_thresholds({0, 0.4, 0}); + EXPECT_EQ(thresholds, expected_thresholds); std::vector> expected_ids( - {{22, 58, 24, 425}, {19, 370, 47}}); + {{13, 36, 24, 155, 39, 41, 58, 137, 53, 71, 77, 114, 24, 138}, + {58, 125, 43, 66, 53, 71, 48, 44}, + {10, 50, 43, 66, 75, 148, 58, 125}}); EXPECT_EQ(ids, expected_ids); } diff --git a/sherpa-onnx/csrc/utils.cc b/sherpa-onnx/csrc/utils.cc index 0016469f6..4711d592e 100644 --- a/sherpa-onnx/csrc/utils.cc +++ b/sherpa-onnx/csrc/utils.cc @@ -189,7 +189,8 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, nullptr); } -bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, +bool EncodeKeywords(std::istream &is, const std::string &modeling_unit, + const SymbolTable &symbol_table, const ssentencepiece::Ssentencepiece *bpe_encoder, const cppinyin::PinyinEncoder *pinyin_encoder, std::vector> *keywords_id, @@ -231,7 +232,14 @@ bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, break; } } - phrase = oss.str().substr(1); + + phrase = oss.str(); + if (phrase.empty()) { + continue; + } else { + phrase = phrase.substr(1); + } + std::istringstream piss(phrase); oss.clear(); oss.str(""); @@ -254,7 +262,8 @@ bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, exit(-1); } std::vector pinyins; - pinyin_encoder->Encode(word, &pinyins); + pinyin_encoder->Encode(word, &pinyins, true /* tone */, + true /* partial */); for (const auto &pinyin : pinyins) { oss << " " << pinyin; } diff --git a/sherpa-onnx/csrc/utils.h b/sherpa-onnx/csrc/utils.h index d4da59d03..fdff06b3d 100644 --- a/sherpa-onnx/csrc/utils.h +++ b/sherpa-onnx/csrc/utils.h @@ -52,7 +52,8 @@ bool EncodeHotwords(std::istream &is, const std::string &modeling_unit, * @return If all the symbols from ``is`` are in the symbol_table, returns true * otherwise returns false. */ -bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table, +bool EncodeKeywords(std::istream &is, const std::string &modeling_unit, + const SymbolTable &symbol_table, const ssentencepiece::Ssentencepiece *bpe_encoder, const cppinyin::PinyinEncoder *pinyin_encoder, std::vector> *keywords_id, From ac52984d811143a25d9fed2982c058413c890204 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 4 Jun 2024 08:28:50 +0800 Subject: [PATCH 4/4] C++ works; jni; python --- cmake/cppinyin.cmake | 2 +- .../keyword-spotter-from-microphone.py | 35 ++++++++++++++++--- python-api-examples/keyword-spotter.py | 35 ++++++++++++++++--- .../csrc/keyword-spotter-transducer-impl.h | 35 +++++++++++++++++-- sherpa-onnx/csrc/keyword-spotter.cc | 7 ++-- sherpa-onnx/csrc/online-model-config.cc | 11 +++--- sherpa-onnx/jni/keyword-spotter.cc | 18 ++++++++++ sherpa-onnx/kotlin-api/KeywordSpotter.kt | 4 +++ sherpa-onnx/kotlin-api/OnlineRecognizer.kt | 1 + .../python/csrc/online-model-config.cc | 5 +-- .../python/sherpa_onnx/keyword_spotter.py | 17 +++++++++ .../python/tests/test_keyword_spotter.py | 10 ++++++ 12 files changed, 157 insertions(+), 23 deletions(-) diff --git a/cmake/cppinyin.cmake b/cmake/cppinyin.cmake index c2019c497..29dec3cd9 100644 --- a/cmake/cppinyin.cmake +++ b/cmake/cppinyin.cmake @@ -2,7 +2,7 @@ function(download_cppinyin) include(FetchContent) set(cppinyin_URL "https://github.com/pkufool/cppinyin/archive/refs/tags/v0.1.tar.gz") - set(cppinyin_URL2 "https://hub.nauu.cf/pkufool/cppinyin/archive/refs/tags/v0.1.tar.gz") + set(cppinyin_URL2 "https://hub.nuaa.cf/pkufool/cppinyin/archive/refs/tags/v0.1.tar.gz") set(cppinyin_HASH "SHA256=3659bc0c28d17d41ce932807c1cdc1da8c861e6acee969b5844d0d0a3c5ef78b") # If you don't have access to the Internet, diff --git a/python-api-examples/keyword-spotter-from-microphone.py b/python-api-examples/keyword-spotter-from-microphone.py index 65a59fca6..3df1d9b8a 100755 --- a/python-api-examples/keyword-spotter-from-microphone.py +++ b/python-api-examples/keyword-spotter-from-microphone.py @@ -95,15 +95,39 @@ def get_args(): """, ) + parser.add_argument( + "--modeling-unit", + type=str, + help="""The modeling unit of the model, valid values are bpe (for English model) + and ppinyin (For Chinese model). + """, + ) + + parser.add_argument( + "--bpe-vocab", + type=str, + help="""A simple format of bpe model, you can get it from the sentencepiece + generated folder. Used to tokenize the keywords into token ids. Used only + when modeling unit is bpe. + """, + ) + + parser.add_argument( + "--lexicon", + type=str, + help="""The lexicon used to tokenize the keywords into token ids. Used + only when modeling unit is ppinyin. + """, + ) + parser.add_argument( "--keywords-file", type=str, help=""" - The file containing keywords, one words/phrases per line, and for each - phrase the bpe/cjkchar/pinyin are separated by a space. For example: + The file containing keywords, one words/phrases per line. For example: - ▁HE LL O ▁WORLD - x iǎo ài t óng x ué + HELLO WORLD + 小爱同学 """, ) @@ -164,6 +188,9 @@ def main(): keywords_score=args.keywords_score, keywords_threshold=args.keywords_threshold, num_trailing_blanks=args.num_trailing_blanks, + modeling_unit=args.modeling_unit, + bpe_vocab=args.bpe_vocab, + lexicon=args.lexicon, provider=args.provider, ) diff --git a/python-api-examples/keyword-spotter.py b/python-api-examples/keyword-spotter.py index 1b1de77e3..a7fcf50db 100755 --- a/python-api-examples/keyword-spotter.py +++ b/python-api-examples/keyword-spotter.py @@ -80,15 +80,39 @@ def get_args(): """, ) + parser.add_argument( + "--modeling-unit", + type=str, + help="""The modeling unit of the model, valid values are bpe (for English model) + and ppinyin (For Chinese model). + """, + ) + + parser.add_argument( + "--bpe-vocab", + type=str, + help="""A simple format of bpe model, you can get it from the sentencepiece + generated folder. Used to tokenize the keywords into token ids. Used only + when modeling unit is bpe. + """, + ) + + parser.add_argument( + "--lexicon", + type=str, + help="""The lexicon used to tokenize the keywords into token ids. Used + only when modeling unit is ppinyin. + """, + ) + parser.add_argument( "--keywords-file", type=str, help=""" - The file containing keywords, one words/phrases per line, and for each - phrase the bpe/cjkchar/pinyin are separated by a space. For example: + The file containing keywords, one words/phrases per line. For example: - ▁HE LL O ▁WORLD - x iǎo ài t óng x ué + HELLO WORLD + 小爱同学 """, ) @@ -183,6 +207,9 @@ def main(): keywords_score=args.keywords_score, keywords_threshold=args.keywords_threshold, num_trailing_blanks=args.num_trailing_blanks, + modeling_unit=args.modeling_unit, + bpe_vocab=args.bpe_vocab, + lexicon=args.lexicon, provider=args.provider, ) diff --git a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h index 832878525..fca052e2e 100644 --- a/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h +++ b/sherpa-onnx/csrc/keyword-spotter-transducer-impl.h @@ -74,6 +74,17 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { unk_id_ = sym_[""]; } + if (!config_.model_config.bpe_vocab.empty()) { + bpe_encoder_ = std::make_unique( + config_.model_config.bpe_vocab); + } + + if (config_.model_config.modeling_unit == "ppinyin" && + !config_.model_config.lexicon.empty()) { + pinyin_encoder_ = std::make_unique( + config_.model_config.lexicon); + } + model_->SetFeatureDim(config.feat_config.feature_dim); InitKeywords(); @@ -95,6 +106,19 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { model_->SetFeatureDim(config.feat_config.feature_dim); + if (!config_.model_config.bpe_vocab.empty()) { + auto buf = ReadFile(mgr, config_.model_config.bpe_vocab); + std::istringstream iss(std::string(buf.begin(), buf.end())); + bpe_encoder_ = std::make_unique(iss); + } + + if (config_.model_config.modeling_unit == "ppinyin" && + !config_.model_config.lexicon.empty()) { + auto buf = ReadFile(mgr, config_.model_config.lexicon); + std::istringstream iss(std::string(buf.begin(), buf.end())); + pinyin_encoder_ = std::make_unique(iss); + } + InitKeywords(mgr); decoder_ = std::make_unique( @@ -120,7 +144,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { std::vector current_scores; std::vector current_thresholds; - if (!EncodeKeywords(is, "", sym_, nullptr, nullptr, ¤t_ids, + if (!EncodeKeywords(is, config_.model_config.modeling_unit, sym_, + bpe_encoder_.get(), pinyin_encoder_.get(), ¤t_ids, ¤t_kws, ¤t_scores, ¤t_thresholds)) { SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str()); return nullptr; @@ -261,8 +286,10 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { private: void InitKeywords(std::istream &is) { - if (!EncodeKeywords(is, "", sym_, nullptr, nullptr, &keywords_id_, - &keywords_, &boost_scores_, &thresholds_)) { + if (!EncodeKeywords(is, config_.model_config.modeling_unit, sym_, + bpe_encoder_.get(), pinyin_encoder_.get(), + &keywords_id_, &keywords_, &boost_scores_, + &thresholds_)) { SHERPA_ONNX_LOGE("Encode keywords failed."); exit(-1); } @@ -325,6 +352,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl { std::vector thresholds_; std::vector keywords_; ContextGraphPtr keywords_graph_; + std::unique_ptr bpe_encoder_; + std::unique_ptr pinyin_encoder_; std::unique_ptr model_; std::unique_ptr decoder_; SymbolTable sym_; diff --git a/sherpa-onnx/csrc/keyword-spotter.cc b/sherpa-onnx/csrc/keyword-spotter.cc index 7e93d7a04..9cd4b50f9 100644 --- a/sherpa-onnx/csrc/keyword-spotter.cc +++ b/sherpa-onnx/csrc/keyword-spotter.cc @@ -83,10 +83,9 @@ void KeywordSpotterConfig::Register(ParseOptions *po) { "The acoustic threshold (probability) to trigger the keywords."); po->Register( "keywords-file", &keywords_file, - "The file containing keywords, one word/phrase per line, and for each" - "phrase the bpe/cjkchar are separated by a space. For example: " - "▁HE LL O ▁WORLD" - "你 好 世 界"); + "The file containing keywords, one word/phrase per line. For example: " + "HELLO WORLD" + "你好世界"); } bool KeywordSpotterConfig::Validate() const { diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index dff45692a..2fb97860e 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -32,11 +32,12 @@ void OnlineModelConfig::Register(ParseOptions *po) { po->Register("provider", &provider, "Specify a provider to use: cpu, cuda, coreml"); - po->Register("modeling-unit", &modeling_unit, - "The modeling unit of the model, commonly used units are bpe, " - "cjkchar, cjkchar+bpe, etc. Currently, it is needed only when " - "hotwords are provided, we need it to encode the hotwords into " - "token sequence."); + po->Register( + "modeling-unit", &modeling_unit, + "The modeling unit of the model, commonly used units are bpe, " + "cjkchar, cjkchar+bpe, ppinyin, etc. Currently, it is needed only when " + "hotwords are provided, we need it to encode the hotwords into " + "token sequence."); po->Register("bpe-vocab", &bpe_vocab, "The vocabulary generated by google's sentencepiece program. " diff --git a/sherpa-onnx/jni/keyword-spotter.cc b/sherpa-onnx/jni/keyword-spotter.cc index 7a05b4855..0348894be 100644 --- a/sherpa-onnx/jni/keyword-spotter.cc +++ b/sherpa-onnx/jni/keyword-spotter.cc @@ -103,6 +103,24 @@ static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) { ans.model_config.model_type = p; env->ReleaseStringUTFChars(s, p); + fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.modeling_unit = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.bpe_vocab = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "lexicon", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.bpe_vocab = p; + env->ReleaseStringUTFChars(s, p); + return ans; } diff --git a/sherpa-onnx/kotlin-api/KeywordSpotter.kt b/sherpa-onnx/kotlin-api/KeywordSpotter.kt index 803762e51..8f6a88f3e 100644 --- a/sherpa-onnx/kotlin-api/KeywordSpotter.kt +++ b/sherpa-onnx/kotlin-api/KeywordSpotter.kt @@ -109,6 +109,8 @@ fun getKwsModelConfig(type: Int): OnlineModelConfig? { ), tokens = "$modelDir/tokens.txt", modelType = "zipformer2", + modelingUnit = "ppinyin", + lexicon = "$modelDir/pinyin.dict", ) } @@ -122,6 +124,8 @@ fun getKwsModelConfig(type: Int): OnlineModelConfig? { ), tokens = "$modelDir/tokens.txt", modelType = "zipformer2", + modelingUnit = "bpe", + bpeVocab = "$modelDir/bpe.vocab", ) } diff --git a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt index e78fb6549..94f46591d 100644 --- a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt @@ -45,6 +45,7 @@ data class OnlineModelConfig( var modelType: String = "", var modelingUnit: String = "", var bpeVocab: String = "", + var lexicon: String = "", ) data class OnlineLMConfig( diff --git a/sherpa-onnx/python/csrc/online-model-config.cc b/sherpa-onnx/python/csrc/online-model-config.cc index d6db809bd..6e1b8e7e7 100644 --- a/sherpa-onnx/python/csrc/online-model-config.cc +++ b/sherpa-onnx/python/csrc/online-model-config.cc @@ -33,7 +33,7 @@ void PybindOnlineModelConfig(py::module *m) { const OnlineNeMoCtcModelConfig &, const std::string &, int32_t, int32_t, bool, const std::string &, const std::string &, const std::string &, - const std::string &>(), + const std::string &, const std::string &>(), py::arg("transducer") = OnlineTransducerModelConfig(), py::arg("paraformer") = OnlineParaformerModelConfig(), py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(), @@ -42,7 +42,7 @@ void PybindOnlineModelConfig(py::module *m) { py::arg("num_threads"), py::arg("warm_up") = 0, py::arg("debug") = false, py::arg("provider") = "cpu", py::arg("model_type") = "", py::arg("modeling_unit") = "", - py::arg("bpe_vocab") = "") + py::arg("bpe_vocab") = "", py::arg("lexicon") = "") .def_readwrite("transducer", &PyClass::transducer) .def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("wenet_ctc", &PyClass::wenet_ctc) @@ -55,6 +55,7 @@ void PybindOnlineModelConfig(py::module *m) { .def_readwrite("model_type", &PyClass::model_type) .def_readwrite("modeling_unit", &PyClass::modeling_unit) .def_readwrite("bpe_vocab", &PyClass::bpe_vocab) + .def_readwrite("lexicon", &PyClass::lexicon) .def("validate", &PyClass::Validate) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py index 218628ea9..2ac56ec89 100644 --- a/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py +++ b/sherpa-onnx/python/sherpa_onnx/keyword_spotter.py @@ -40,6 +40,9 @@ def __init__( keywords_score: float = 1.0, keywords_threshold: float = 0.25, num_trailing_blanks: int = 1, + modeling_unit: str = "cjkchar", + bpe_vocab: str = "", + lexicon: str = "", provider: str = "cpu", ): """ @@ -83,6 +86,17 @@ def __init__( The number of trailing blanks a keyword should be followed. Setting to a larger value (e.g. 8) when your keywords has overlapping tokens between each other. + modeling_unit: + The modeling unit of the model, commonly used units are bpe, ppinyin, etc. + We need it to encode the keywords into token sequence. + bpe_vocab: + The vocabulary generated by google's sentencepiece program. + It is a file has two columns, one is the token, the other is + the log probability, you can get it from the directory where + your bpe model is generated. Only used when the modeling unit is bpe. + lexicon: + The lexicon used to tokenize keywords into token sequences, Only used + when the modeling unit is pinyin or phone. provider: onnxruntime execution providers. Valid values are: cpu, cuda, coreml. """ @@ -103,6 +117,9 @@ def __init__( transducer=transducer_config, tokens=tokens, num_threads=num_threads, + modeling_unit=modeling_unit, + bpe_vocab=bpe_vocab, + lexicon=lexicon, provider=provider, ) diff --git a/sherpa-onnx/python/tests/test_keyword_spotter.py b/sherpa-onnx/python/tests/test_keyword_spotter.py index bdefa5d10..101c5bc1c 100755 --- a/sherpa-onnx/python/tests/test_keyword_spotter.py +++ b/sherpa-onnx/python/tests/test_keyword_spotter.py @@ -60,6 +60,9 @@ def test_zipformer_transducer_en(self): tokens = ( f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/tokens.txt" ) + bpe_vocab = ( + f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/bpe.vocab" + ) keywords_file = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt" wave0 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/0.wav" wave1 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/1.wav" @@ -74,6 +77,8 @@ def test_zipformer_transducer_en(self): tokens=tokens, num_threads=1, keywords_file=keywords_file, + modeling_unit="bpe", + bpe_vocab=bpe_vocab, provider="cpu", ) streams = [] @@ -119,6 +124,9 @@ def test_zipformer_transducer_cn(self): tokens = ( f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt" ) + lexicon = ( + f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/pinyin.dict" + ) keywords_file = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt" wave0 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav" wave1 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/4.wav" @@ -134,6 +142,8 @@ def test_zipformer_transducer_cn(self): tokens=tokens, num_threads=1, keywords_file=keywords_file, + modeling_unit="ppinyin", + lexicon=lexicon, provider="cpu", ) streams = []