C++ works; jni; python
pkufool committed Jun 4, 2024
1 parent 656a267 commit ac52984
Showing 12 changed files with 157 additions and 23 deletions.
2 changes: 1 addition & 1 deletion cmake/cppinyin.cmake
@@ -2,7 +2,7 @@ function(download_cppinyin)
include(FetchContent)

set(cppinyin_URL "https://github.com/pkufool/cppinyin/archive/refs/tags/v0.1.tar.gz")
set(cppinyin_URL2 "https://hub.nauu.cf/pkufool/cppinyin/archive/refs/tags/v0.1.tar.gz")
set(cppinyin_URL2 "https://hub.nuaa.cf/pkufool/cppinyin/archive/refs/tags/v0.1.tar.gz")
set(cppinyin_HASH "SHA256=3659bc0c28d17d41ce932807c1cdc1da8c861e6acee969b5844d0d0a3c5ef78b")

# If you don't have access to the Internet,
35 changes: 31 additions & 4 deletions python-api-examples/keyword-spotter-from-microphone.py
@@ -95,15 +95,39 @@ def get_args():
""",
)

parser.add_argument(
"--modeling-unit",
type=str,
help="""The modeling unit of the model, valid values are bpe (for English model)
and ppinyin (For Chinese model).
""",
)

parser.add_argument(
"--bpe-vocab",
type=str,
help="""A simple format of bpe model, you can get it from the sentencepiece
generated folder. Used to tokenize the keywords into token ids. Used only
when modeling unit is bpe.
""",
)

parser.add_argument(
"--lexicon",
type=str,
help="""The lexicon used to tokenize the keywords into token ids. Used
only when modeling unit is ppinyin.
""",
)

parser.add_argument(
"--keywords-file",
type=str,
help="""
-The file containing keywords, one words/phrases per line, and for each
-phrase the bpe/cjkchar/pinyin are separated by a space. For example:
-▁HE LL O ▁WORLD
-x iǎo ài t óng x ué
+The file containing keywords, one word/phrase per line. For example:
+HELLO WORLD
+小爱同学
""",
)

@@ -164,6 +188,9 @@ def main():
keywords_score=args.keywords_score,
keywords_threshold=args.keywords_threshold,
num_trailing_blanks=args.num_trailing_blanks,
modeling_unit=args.modeling_unit,
bpe_vocab=args.bpe_vocab,
lexicon=args.lexicon,
provider=args.provider,
)

35 changes: 31 additions & 4 deletions python-api-examples/keyword-spotter.py
@@ -80,15 +80,39 @@ def get_args():
""",
)

parser.add_argument(
"--modeling-unit",
type=str,
help="""The modeling unit of the model, valid values are bpe (for English model)
and ppinyin (For Chinese model).
""",
)

parser.add_argument(
"--bpe-vocab",
type=str,
help="""A simple format of bpe model, you can get it from the sentencepiece
generated folder. Used to tokenize the keywords into token ids. Used only
when modeling unit is bpe.
""",
)

parser.add_argument(
"--lexicon",
type=str,
help="""The lexicon used to tokenize the keywords into token ids. Used
only when modeling unit is ppinyin.
""",
)

parser.add_argument(
"--keywords-file",
type=str,
help="""
-The file containing keywords, one words/phrases per line, and for each
-phrase the bpe/cjkchar/pinyin are separated by a space. For example:
-▁HE LL O ▁WORLD
-x iǎo ài t óng x ué
+The file containing keywords, one word/phrase per line. For example:
+HELLO WORLD
+小爱同学
""",
)

@@ -183,6 +207,9 @@ def main():
keywords_score=args.keywords_score,
keywords_threshold=args.keywords_threshold,
num_trailing_blanks=args.num_trailing_blanks,
modeling_unit=args.modeling_unit,
bpe_vocab=args.bpe_vocab,
lexicon=args.lexicon,
provider=args.provider,
)

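Taken together, the two scripts above now take plain-text keywords plus the new tokenization flags. For orientation, a minimal sketch of the equivalent programmatic setup for a Chinese model; the encoder/decoder/joiner paths are hypothetical placeholders (this commit only shows tokens.txt, pinyin.dict, and the keywords file), so adjust them to your model:

```python
import sherpa_onnx

# A minimal sketch of the new options for a Chinese (ppinyin) model.
# The encoder/decoder/joiner paths are placeholders, not from this commit.
kws = sherpa_onnx.KeywordSpotter(
    tokens="model/tokens.txt",
    encoder="model/encoder.onnx",  # placeholder path
    decoder="model/decoder.onnx",  # placeholder path
    joiner="model/joiner.onnx",    # placeholder path
    num_threads=1,
    keywords_file="keywords.txt",  # plain text now, e.g. one line: 小爱同学
    modeling_unit="ppinyin",       # use "bpe" for English models
    lexicon="model/pinyin.dict",   # consulted only when modeling_unit=ppinyin
    provider="cpu",
)
```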
35 changes: 32 additions & 3 deletions sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
@@ -74,6 +74,17 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
unk_id_ = sym_["<unk>"];
}

if (!config_.model_config.bpe_vocab.empty()) {
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
config_.model_config.bpe_vocab);
}

if (config_.model_config.modeling_unit == "ppinyin" &&
!config_.model_config.lexicon.empty()) {
pinyin_encoder_ = std::make_unique<cppinyin::PinyinEncoder>(
config_.model_config.lexicon);
}

model_->SetFeatureDim(config.feat_config.feature_dim);

InitKeywords();
@@ -95,6 +106,19 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {

model_->SetFeatureDim(config.feat_config.feature_dim);

if (!config_.model_config.bpe_vocab.empty()) {
auto buf = ReadFile(mgr, config_.model_config.bpe_vocab);
std::istringstream iss(std::string(buf.begin(), buf.end()));
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
}

if (config_.model_config.modeling_unit == "ppinyin" &&
!config_.model_config.lexicon.empty()) {
auto buf = ReadFile(mgr, config_.model_config.lexicon);
std::istringstream iss(std::string(buf.begin(), buf.end()));
pinyin_encoder_ = std::make_unique<cppinyin::PinyinEncoder>(iss);
}

InitKeywords(mgr);

decoder_ = std::make_unique<TransducerKeywordDecoder>(
@@ -120,7 +144,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
std::vector<float> current_scores;
std::vector<float> current_thresholds;

if (!EncodeKeywords(is, "", sym_, nullptr, nullptr, &current_ids,
if (!EncodeKeywords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), pinyin_encoder_.get(), &current_ids,
&current_kws, &current_scores, &current_thresholds)) {
SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str());
return nullptr;
@@ -261,8 +286,10 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {

private:
void InitKeywords(std::istream &is) {
if (!EncodeKeywords(is, "", sym_, nullptr, nullptr, &keywords_id_,
&keywords_, &boost_scores_, &thresholds_)) {
if (!EncodeKeywords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), pinyin_encoder_.get(),
&keywords_id_, &keywords_, &boost_scores_,
&thresholds_)) {
SHERPA_ONNX_LOGE("Encode keywords failed.");
exit(-1);
}
Expand Down Expand Up @@ -325,6 +352,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
std::vector<float> thresholds_;
std::vector<std::string> keywords_;
ContextGraphPtr keywords_graph_;
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
std::unique_ptr<cppinyin::PinyinEncoder> pinyin_encoder_;
std::unique_ptr<OnlineTransducerModel> model_;
std::unique_ptr<TransducerKeywordDecoder> decoder_;
SymbolTable sym_;
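The upshot of the changes in this file: keyword files no longer carry pre-tokenized entries, because EncodeKeywords now receives the modeling unit and the BPE/pinyin encoders and tokenizes plain-text keywords itself. A toy Python sketch of that lookup, with a made-up lexicon and symbol table (the real work is done by ssentencepiece and cppinyin):

```python
# Toy model of what EncodeKeywords now does for ppinyin: split each
# character via the lexicon, then map the tokens to ids via the symbol
# table. The entries below are illustrative, not real model data.
lexicon = {"小": ["x", "iǎo"], "爱": ["ài"], "同": ["t", "óng"], "学": ["x", "ué"]}
sym = {"x": 10, "iǎo": 11, "ài": 12, "t": 13, "óng": 14, "ué": 15}

def encode_keyword(phrase: str) -> list[int]:
    tokens = [t for ch in phrase for t in lexicon[ch]]
    return [sym[t] for t in tokens]

print(encode_keyword("小爱同学"))  # [10, 11, 12, 13, 14, 10, 15]
```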
7 changes: 3 additions & 4 deletions sherpa-onnx/csrc/keyword-spotter.cc
@@ -83,10 +83,9 @@ void KeywordSpotterConfig::Register(ParseOptions *po) {
"The acoustic threshold (probability) to trigger the keywords.");
po->Register(
"keywords-file", &keywords_file,
"The file containing keywords, one word/phrase per line, and for each"
"phrase the bpe/cjkchar are separated by a space. For example: "
"▁HE LL O ▁WORLD"
"你 好 世 界");
"The file containing keywords, one word/phrase per line. For example: "
"HELLO WORLD"
"你好世界");
}

bool KeywordSpotterConfig::Validate() const {
11 changes: 6 additions & 5 deletions sherpa-onnx/csrc/online-model-config.cc
@@ -32,11 +32,12 @@ void OnlineModelConfig::Register(ParseOptions *po) {
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");

po->Register("modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, "
"cjkchar, cjkchar+bpe, etc. Currently, it is needed only when "
"hotwords are provided, we need it to encode the hotwords into "
"token sequence.");
po->Register(
"modeling-unit", &modeling_unit,
"The modeling unit of the model, commonly used units are bpe, "
"cjkchar, cjkchar+bpe, ppinyin, etc. Currently, it is needed only when "
"hotwords are provided, we need it to encode the hotwords into "
"token sequence.");

po->Register("bpe-vocab", &bpe_vocab,
"The vocabulary generated by google's sentencepiece program. "
18 changes: 18 additions & 0 deletions sherpa-onnx/jni/keyword-spotter.cc
@@ -103,6 +103,24 @@ static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(model_config_cls, "modelingUnit", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.modeling_unit = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.bpe_vocab = p;
env->ReleaseStringUTFChars(s, p);

fid = env->GetFieldID(model_config_cls, "lexicon", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.lexicon = p;
env->ReleaseStringUTFChars(s, p);

return ans;
}

4 changes: 4 additions & 0 deletions sherpa-onnx/kotlin-api/KeywordSpotter.kt
@@ -109,6 +109,8 @@ fun getKwsModelConfig(type: Int): OnlineModelConfig? {
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer2",
modelingUnit = "ppinyin",
lexicon = "$modelDir/pinyin.dict",
)
}

@@ -122,6 +124,8 @@ fun getKwsModelConfig(type: Int): OnlineModelConfig? {
),
tokens = "$modelDir/tokens.txt",
modelType = "zipformer2",
modelingUnit = "bpe",
bpeVocab = "$modelDir/bpe.vocab",
)
}

1 change: 1 addition & 0 deletions sherpa-onnx/kotlin-api/OnlineRecognizer.kt
@@ -45,6 +45,7 @@ data class OnlineModelConfig(
var modelType: String = "",
var modelingUnit: String = "",
var bpeVocab: String = "",
var lexicon: String = "",
)

data class OnlineLMConfig(
5 changes: 3 additions & 2 deletions sherpa-onnx/python/csrc/online-model-config.cc
@@ -33,7 +33,7 @@ void PybindOnlineModelConfig(py::module *m) {
const OnlineNeMoCtcModelConfig &, const std::string &,
int32_t, int32_t, bool, const std::string &,
const std::string &, const std::string &,
const std::string &>(),
const std::string &, const std::string &>(),
py::arg("transducer") = OnlineTransducerModelConfig(),
py::arg("paraformer") = OnlineParaformerModelConfig(),
py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
@@ -42,7 +42,7 @@
py::arg("num_threads"), py::arg("warm_up") = 0,
py::arg("debug") = false, py::arg("provider") = "cpu",
py::arg("model_type") = "", py::arg("modeling_unit") = "",
py::arg("bpe_vocab") = "")
py::arg("bpe_vocab") = "", py::arg("lexicon") = "")
.def_readwrite("transducer", &PyClass::transducer)
.def_readwrite("paraformer", &PyClass::paraformer)
.def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
Expand All @@ -55,6 +55,7 @@ void PybindOnlineModelConfig(py::module *m) {
.def_readwrite("model_type", &PyClass::model_type)
.def_readwrite("modeling_unit", &PyClass::modeling_unit)
.def_readwrite("bpe_vocab", &PyClass::bpe_vocab)
.def_readwrite("lexicon", &PyClass::lexicon)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}
17 changes: 17 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/keyword_spotter.py
@@ -40,6 +40,9 @@ def __init__(
keywords_score: float = 1.0,
keywords_threshold: float = 0.25,
num_trailing_blanks: int = 1,
modeling_unit: str = "cjkchar",
bpe_vocab: str = "",
lexicon: str = "",
provider: str = "cpu",
):
"""
@@ -83,6 +86,17 @@ def __init__(
The number of trailing blanks a keyword should be followed by. Set it
to a larger value (e.g., 8) when your keywords have overlapping tokens.
modeling_unit:
The modeling unit of the model; commonly used units are bpe, ppinyin, etc.
We need it to encode the keywords into token sequences.
bpe_vocab:
The vocabulary generated by Google's sentencepiece program. It is a
file with two columns: the token and its log probability. You can get
it from the directory where your bpe model was generated. Used only
when the modeling unit is bpe.
lexicon:
The lexicon used to tokenize keywords into token sequences. Used only
when the modeling unit is ppinyin.
provider:
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
"""
@@ -103,6 +117,9 @@ def __init__(
transducer=transducer_config,
tokens=tokens,
num_threads=num_threads,
modeling_unit=modeling_unit,
bpe_vocab=bpe_vocab,
lexicon=lexicon,
provider=provider,
)

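For the English (bpe) side, a hedged sketch of constructing the spotter and running the usual streaming decode loop; the model paths are placeholders, and the loop follows the pattern of the tests below:

```python
import wave

import numpy as np
import sherpa_onnx

kws = sherpa_onnx.KeywordSpotter(
    tokens="model/tokens.txt",
    encoder="model/encoder.onnx",  # placeholder path
    decoder="model/decoder.onnx",  # placeholder path
    joiner="model/joiner.onnx",    # placeholder path
    num_threads=1,
    keywords_file="keywords.txt",  # e.g. one line: HELLO WORLD
    modeling_unit="bpe",
    bpe_vocab="model/bpe.vocab",   # from the sentencepiece output folder
)

with wave.open("test.wav") as f:  # assumes a 16-bit mono wav
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
    samples = samples.astype(np.float32) / 32768
    stream = kws.create_stream()
    stream.accept_waveform(f.getframerate(), samples)
    while kws.is_ready(stream):
        kws.decode_stream(stream)
        result = kws.get_result(stream)
        if result:
            print("detected:", result)
```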
10 changes: 10 additions & 0 deletions sherpa-onnx/python/tests/test_keyword_spotter.py
@@ -60,6 +60,9 @@ def test_zipformer_transducer_en(self):
tokens = (
f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/tokens.txt"
)
bpe_vocab = (
f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/bpe.vocab"
)
keywords_file = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt"
wave0 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/0.wav"
wave1 = f"{d}/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/test_wavs/1.wav"
Expand All @@ -74,6 +77,8 @@ def test_zipformer_transducer_en(self):
tokens=tokens,
num_threads=1,
keywords_file=keywords_file,
modeling_unit="bpe",
bpe_vocab=bpe_vocab,
provider="cpu",
)
streams = []
@@ -119,6 +124,9 @@ def test_zipformer_transducer_cn(self):
tokens = (
f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt"
)
lexicon = (
f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/pinyin.dict"
)
keywords_file = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt"
wave0 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"
wave1 = f"{d}/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/4.wav"
@@ -134,6 +142,8 @@ def test_zipformer_transducer_cn(self):
tokens=tokens,
num_threads=1,
keywords_file=keywords_file,
modeling_unit="ppinyin",
lexicon=lexicon,
provider="cpu",
)
streams = []
