diff --git a/CMakeLists.txt b/CMakeLists.txt index 392c7c64..d59a019e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-ncnn) -set(SHERPA_NCNN_VERSION "2.0.7") +set(SHERPA_NCNN_VERSION "2.1.0") # Disable warning about # diff --git a/android/SherpaNcnn/app/src/main/java/com/k2fsa/sherpa/ncnn/SherpaNcnn.kt b/android/SherpaNcnn/app/src/main/java/com/k2fsa/sherpa/ncnn/SherpaNcnn.kt index 9391d1cd..594b5791 100644 --- a/android/SherpaNcnn/app/src/main/java/com/k2fsa/sherpa/ncnn/SherpaNcnn.kt +++ b/android/SherpaNcnn/app/src/main/java/com/k2fsa/sherpa/ncnn/SherpaNcnn.kt @@ -33,6 +33,8 @@ data class RecognizerConfig( var rule1MinTrailingSilence: Float = 2.4f, var rule2MinTrailingSilence: Float = 1.0f, var rule3MinUtteranceLength: Float = 30.0f, + var hotwordsFile: String = "", + var hotwordsScore: Float = 1.5f, ) class SherpaNcnn( diff --git a/c-api-examples/decode-file-c-api.c b/c-api-examples/decode-file-c-api.c index 3d183bcd..06ce858b 100644 --- a/c-api-examples/decode-file-c-api.c +++ b/c-api-examples/decode-file-c-api.c @@ -45,6 +45,8 @@ int32_t main(int32_t argc, char *argv[]) { return -1; } SherpaNcnnRecognizerConfig config; + memset(&config, 0, sizeof(config)); + config.model_config.tokens = argv[1]; config.model_config.encoder_param = argv[2]; config.model_config.encoder_bin = argv[3]; @@ -57,6 +59,7 @@ int32_t main(int32_t argc, char *argv[]) { if (argc >= 10 && atoi(argv[9]) > 0) { num_threads = atoi(argv[9]); } + config.model_config.num_threads = num_threads; config.model_config.use_vulkan_compute = 0; @@ -65,6 +68,7 @@ int32_t main(int32_t argc, char *argv[]) { if (argc >= 11) { config.decoder_config.decoding_method = argv[10]; } + config.decoder_config.num_active_paths = 4; config.enable_endpoint = 0; config.rule1_min_trailing_silence = 2.4; @@ -73,16 +77,14 @@ int32_t main(int32_t argc, char *argv[]) { config.feat_config.sampling_rate = 16000; config.feat_config.feature_dim = 
80; - if(argc >= 12) { + if (argc >= 12) { config.hotwords_file = argv[11]; - } else { - config.hotwords_file = ""; } - if(argc == 13) { + + if (argc == 13) { config.hotwords_score = atof(argv[12]); - } else { - config.hotwords_score = 1.5; } + SherpaNcnnRecognizer *recognizer = CreateRecognizer(&config); const char *wav_filename = argv[8]; diff --git a/scripts/dotnet/sherpa-ncnn.cs b/scripts/dotnet/sherpa-ncnn.cs index c6311d97..9a26ae9b 100644 --- a/scripts/dotnet/sherpa-ncnn.cs +++ b/scripts/dotnet/sherpa-ncnn.cs @@ -63,6 +63,11 @@ public struct OnlineRecognizerConfig public float Rule1MinTrailingSilence; public float Rule2MinTrailingSilence; public float Rule3MinUtteranceLength; + + [MarshalAs(UnmanagedType.LPStr)] + public string HotwordsFile; + + public float HotwordsScore; } // please see diff --git a/scripts/go/sherpa_ncnn.go b/scripts/go/sherpa_ncnn.go index b2a06496..ef155b61 100644 --- a/scripts/go/sherpa_ncnn.go +++ b/scripts/go/sherpa_ncnn.go @@ -84,6 +84,9 @@ type RecognizerConfig struct { Rule1MinTrailingSilence float32 Rule2MinTrailingSilence float32 Rule3MinUtteranceLength float32 + + HotwordsFile string + HotwordsScore float32 } // It contains the recognition result for a online stream. @@ -148,6 +151,11 @@ func NewRecognizer(config *RecognizerConfig) *Recognizer { c.rule2_min_trailing_silence = C.float(config.Rule2MinTrailingSilence) c.rule3_min_utterance_length = C.float(config.Rule3MinUtteranceLength) + c.hotwords_file = C.CString(config.HotwordsFile) + defer C.free(unsafe.Pointer(c.hotwords_file)) + + c.hotwords_score = C.float(config.HotwordsScore) + recognizer := &Recognizer{} recognizer.impl = C.CreateRecognizer(&c) diff --git a/sherpa-ncnn/c-api/c-api.cc b/sherpa-ncnn/c-api/c-api.cc index 9bf67b6b..0500451d 100644 --- a/sherpa-ncnn/c-api/c-api.cc +++ b/sherpa-ncnn/c-api/c-api.cc @@ -39,6 +39,8 @@ struct SherpaNcnnDisplay { std::unique_ptr impl; }; +#define SHERPA_NCNN_OR(x, y) (x ? 
x : y) + SherpaNcnnRecognizer *CreateRecognizer( const SherpaNcnnRecognizerConfig *in_config) { // model_config @@ -56,7 +58,7 @@ SherpaNcnnRecognizer *CreateRecognizer( config.model_config.use_vulkan_compute = in_config->model_config.use_vulkan_compute; - int32_t num_threads = in_config->model_config.num_threads; + int32_t num_threads = SHERPA_NCNN_OR(in_config->model_config.num_threads, 1); config.model_config.encoder_opt.num_threads = num_threads; config.model_config.decoder_opt.num_threads = num_threads; @@ -66,8 +68,9 @@ SherpaNcnnRecognizer *CreateRecognizer( config.decoder_config.method = in_config->decoder_config.decoding_method; config.decoder_config.num_active_paths = in_config->decoder_config.num_active_paths; - config.hotwords_file = in_config->hotwords_file; - config.hotwords_score = in_config->hotwords_score; + + config.hotwords_file = SHERPA_NCNN_OR(in_config->hotwords_file, ""); + config.hotwords_score = SHERPA_NCNN_OR(in_config->hotwords_score, 1.5); config.enable_endpoint = in_config->enable_endpoint; @@ -80,11 +83,17 @@ SherpaNcnnRecognizer *CreateRecognizer( config.endpoint_config.rule3.min_utterance_length = in_config->rule3_min_utterance_length; - config.feat_config.sampling_rate = in_config->feat_config.sampling_rate; - config.feat_config.feature_dim = in_config->feat_config.feature_dim; + config.feat_config.sampling_rate = + SHERPA_NCNN_OR(in_config->feat_config.sampling_rate, 16000); + + config.feat_config.feature_dim = + SHERPA_NCNN_OR(in_config->feat_config.feature_dim, 80); auto recognizer = std::make_unique(config); + if (!recognizer->GetModel()) { + NCNN_LOGE("Failed to create the recognizer! Please check your config: %s", + config.ToString().c_str()); return nullptr; } diff --git a/sherpa-ncnn/c-api/c-api.h b/sherpa-ncnn/c-api/c-api.h index e626cf33..9077c981 100644 --- a/sherpa-ncnn/c-api/c-api.h +++ b/sherpa-ncnn/c-api/c-api.h @@ -133,10 +133,10 @@ SHERPA_NCNN_API typedef struct SherpaNcnnRecognizerConfig { /// this value. 
/// Used only when enable_endpoint is not 0. float rule3_min_utterance_length; - - /// hotwords file, each line is a hotword which is segmented into char by space - /// if language is something like CJK, segment manually, - /// if language is something like English, segment by bpe model. + + /// hotwords file, each line is a hotword which is segmented into char by + /// space if language is something like CJK, segment manually, if language is + /// something like English, segment by bpe model. const char *hotwords_file; /// scale of hotwords, used only when hotwords_file is not empty diff --git a/sherpa-ncnn/csrc/decoder.h b/sherpa-ncnn/csrc/decoder.h index 904e19f3..22a17a79 100644 --- a/sherpa-ncnn/csrc/decoder.h +++ b/sherpa-ncnn/csrc/decoder.h @@ -59,7 +59,9 @@ struct DecoderResult { // used only for modified_beam_search Hypotheses hyps; }; + class Stream; + class Decoder { public: virtual ~Decoder() = default; @@ -88,7 +90,11 @@ class Decoder { * and there are no paddings. */ virtual void Decode(ncnn::Mat encoder_out, DecoderResult *result) = 0; - virtual void Decode(ncnn::Mat encoder_out, Stream *s, DecoderResult *result){}; + + virtual void Decode(ncnn::Mat encoder_out, Stream *s, DecoderResult *result) { + NCNN_LOGE("Please override it!"); + exit(-1); + } }; } // namespace sherpa_ncnn diff --git a/sherpa-ncnn/csrc/hypothesis.h b/sherpa-ncnn/csrc/hypothesis.h index 98badb98..c4349d01 100644 --- a/sherpa-ncnn/csrc/hypothesis.h +++ b/sherpa-ncnn/csrc/hypothesis.h @@ -24,6 +24,7 @@ #include #include #include + #include "sherpa-ncnn/csrc/context-graph.h" namespace sherpa_ncnn { @@ -43,7 +44,7 @@ struct Hypothesis { Hypothesis() = default; Hypothesis(const std::vector &ys, double log_prob, - const ContextState *context_state = nullptr) + const ContextState *context_state = nullptr) : ys(ys), log_prob(log_prob), context_state(context_state) {} // If two Hypotheses have the same `Key`, then they contain diff --git a/sherpa-ncnn/csrc/modified-beam-search-decoder.cc 
b/sherpa-ncnn/csrc/modified-beam-search-decoder.cc index e908f47b..e72a6791 100644 --- a/sherpa-ncnn/csrc/modified-beam-search-decoder.cc +++ b/sherpa-ncnn/csrc/modified-beam-search-decoder.cc @@ -195,7 +195,6 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, result->num_trailing_blanks = hyp.num_trailing_blanks; } - void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s, DecoderResult *result) { int32_t context_size = model_->ContextSize(); @@ -205,7 +204,6 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s, std::vector prev = cur.GetTopK(num_active_paths_, true); cur.Clear(); - ncnn::Mat decoder_input = BuildDecoderInput(prev); ncnn::Mat decoder_out; if (t == 0 && prev.size() == 1 && prev[0].ys.size() == context_size && @@ -218,14 +216,13 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s, // decoder_out.w == decoder_dim // decoder_out.h == num_active_paths - ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t)); + ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t)); ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out); // joiner_out.w == vocab_size // joiner_out.h == num_active_paths LogSoftmax(&joiner_out); - float *p_joiner_out = joiner_out; for (int32_t i = 0; i != joiner_out.h; ++i) { @@ -255,8 +252,8 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s, new_hyp.num_trailing_blanks = 0; new_hyp.timestamps.push_back(t + frame_offset); if (s != nullptr && s->GetContextGraph() != nullptr) { - auto context_res = s->GetContextGraph()->ForwardOneStep( - context_state, new_token); + auto context_res = + s->GetContextGraph()->ForwardOneStep(context_state, new_token); context_score = context_res.first; new_hyp.context_state = context_res.second; } diff --git a/sherpa-ncnn/csrc/recognizer.cc b/sherpa-ncnn/csrc/recognizer.cc index 4175c732..236a757e 100644 --- a/sherpa-ncnn/csrc/recognizer.cc +++ b/sherpa-ncnn/csrc/recognizer.cc @@ 
-18,7 +18,7 @@ */ #include "sherpa-ncnn/csrc/recognizer.h" -#include + #include #include #include @@ -29,6 +29,14 @@ #include "sherpa-ncnn/csrc/greedy-search-decoder.h" #include "sherpa-ncnn/csrc/modified-beam-search-decoder.h" +#if __ANDROID_API__ >= 9 +#include + +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#include "android/log.h" +#endif + namespace sherpa_ncnn { static RecognitionResult Convert(const DecoderResult &src, @@ -76,12 +84,10 @@ std::string RecognizerConfig::ToString() const { os << "feat_config=" << feat_config.ToString() << ", "; os << "model_config=" << model_config.ToString() << ", "; os << "decoder_config=" << decoder_config.ToString() << ", "; - os << "max_active_paths=" << max_active_paths << ", "; os << "endpoint_config=" << endpoint_config.ToString() << ", "; os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; os << "hotwords_file=\"" << hotwords_file << "\", "; - os << "hotwrods_score=" << hotwords_score << ", "; - os << "decoding_method=\"" << decoding_method << "\")"; + os << "hotwords_score=" << hotwords_score << ")"; return os.str(); } @@ -98,29 +104,9 @@ class Recognizer::Impl { } else if (config.decoder_config.method == "modified_beam_search") { decoder_ = std::make_unique( model_.get(), config.decoder_config.num_active_paths); - std::vector tmp; - /*each line in hotwords file is a string which is segmented by space*/ - std::ifstream file(config_.hotwords_file); - if (file) { - std::string line; - std::string word; - while (std::getline(file, line)) { - std::istringstream iss(line); - while(iss >> word){ - if (sym_.contains(word)) { - int number = sym_[word]; - tmp.push_back(number); - } else { - NCNN_LOGE("hotword %s can't find id.
line: %s", word.c_str(), line.c_str()); - exit(-1); - } - } - hotwords_.push_back(tmp); - tmp.clear(); - } - } else { - NCNN_LOGE("open file failed: %s, hotwords will not be used", - config_.hotwords_file.c_str()); + + if (!config_.hotwords_file.empty()) { + InitHotwords(); } } else { NCNN_LOGE("Unsupported method: %s", config.decoder_config.method.c_str()); @@ -139,29 +125,9 @@ class Recognizer::Impl { } else if (config.decoder_config.method == "modified_beam_search") { decoder_ = std::make_unique( model_.get(), config.decoder_config.num_active_paths); - std::vector tmp; - /*each line in hotwords file is a string which is segmented by space*/ - std::ifstream file(config_.hotwords_file); - if (file) { - std::string line; - std::string word; - while (std::getline(file, line)) { - std::istringstream iss(line); - while(iss >> word){ - if (sym_.contains(word)) { - int number = sym_[word]; - tmp.push_back(number); - } else { - NCNN_LOGE("hotword %s can't find id. line: %s", word.c_str(), line.c_str()); - exit(-1); - } - } - hotwords_.push_back(tmp); - tmp.clear(); - } - } else { - NCNN_LOGE("open file failed: %s, hotwords will not be used", - config_.hotwords_file.c_str()); + + if (!config_.hotwords_file.empty()) { + InitHotwords(mgr); } } else { NCNN_LOGE("Unsupported method: %s", config.decoder_config.method.c_str()); @@ -171,27 +137,30 @@ class Recognizer::Impl { #endif std::unique_ptr CreateStream() const { - if(hotwords_.empty()) { + if (hotwords_.empty()) { auto stream = std::make_unique(config_.feat_config); stream->SetResult(decoder_->GetEmptyResult()); stream->SetStates(model_->GetEncoderInitStates()); return stream; } else { auto r = decoder_->GetEmptyResult(); + auto context_graph = std::make_shared(hotwords_, config_.hotwords_score); + auto stream = std::make_unique(config_.feat_config, context_graph); - if (config_.decoder_config.method == "modified_beam_search" && - nullptr != stream->GetContextGraph()) { - std::cout<<"create contexts 
stream"<GetContextGraph()) { // r.hyps has only one element. for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) { it->second.context_state = stream->GetContextGraph()->Root(); } } + stream->SetResult(r); stream->SetStates(model_->GetEncoderInitStates()); + return stream; } } @@ -239,8 +208,8 @@ class Recognizer::Impl { void Reset(Stream *s) const { auto r = decoder_->GetEmptyResult(); - if (config_.decoding_method == "modified_beam_search" && - nullptr != s->GetContextGraph()) { + + if (s->GetContextGraph()) { for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) { it->second.context_state = s->GetContextGraph()->Root(); } @@ -271,6 +240,60 @@ class Recognizer::Impl { const Model *GetModel() const { return model_.get(); } + private: +#if __ANDROID_API__ >= 9 + void InitHotwords(AAssetManager *mgr) { + AAsset *asset = AAssetManager_open(mgr, config_.hotwords_file.c_str(), + AASSET_MODE_BUFFER); + if (!asset) { + __android_log_print(ANDROID_LOG_FATAL, "sherpa-ncnn", + "hotwords_file: Load %s failed", + config_.hotwords_file.c_str()); + exit(-1); + } + + auto p = reinterpret_cast(AAsset_getBuffer(asset)); + size_t asset_length = AAsset_getLength(asset); + std::istrstream is(p, asset_length); + InitHotwords(is); + AAsset_close(asset); + } +#endif + + void InitHotwords() { + // each line in hotwords_file contains space-separated words + + std::ifstream is(config_.hotwords_file); + if (!is) { + NCNN_LOGE("Open hotwords file failed: %s", config_.hotwords_file.c_str()); + exit(-1); + } + + InitHotwords(is); + } + + void InitHotwords(std::istream &is) { + std::vector tmp; + std::string line; + std::string word; + + while (std::getline(is, line)) { + std::istringstream iss(line); + while (iss >> word) { + if (sym_.contains(word)) { + int32_t number = sym_[word]; + tmp.push_back(number); + } else { + NCNN_LOGE("Cannot find ID for hotword %s at line: %s", word.c_str(), + line.c_str()); + exit(-1); + } + } + + hotwords_.push_back(std::move(tmp)); + } + } + private: 
RecognizerConfig config_; std::unique_ptr model_; @@ -294,7 +317,6 @@ std::unique_ptr Recognizer::CreateStream() const { return impl_->CreateStream(); } - bool Recognizer::IsReady(Stream *s) const { return impl_->IsReady(s); } void Recognizer::DecodeStream(Stream *s) const { impl_->DecodeStream(s); } diff --git a/sherpa-ncnn/csrc/recognizer.h b/sherpa-ncnn/csrc/recognizer.h index 97f95dd8..45a63b6b 100644 --- a/sherpa-ncnn/csrc/recognizer.h +++ b/sherpa-ncnn/csrc/recognizer.h @@ -31,6 +31,11 @@ #include "sherpa-ncnn/csrc/stream.h" #include "sherpa-ncnn/csrc/symbol-table.h" +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + namespace sherpa_ncnn { struct RecognitionResult { @@ -48,31 +53,27 @@ struct RecognizerConfig { FeatureExtractorConfig feat_config; ModelConfig model_config; DecoderConfig decoder_config; - std::string decoding_method; - std::string hotwords_file; EndpointConfig endpoint_config; bool enable_endpoint = false; - // used only for modified_beam_search - int32_t max_active_paths = 4; + + std::string hotwords_file; + /// used only for modified_beam_search float hotwords_score = 1.5; + RecognizerConfig() = default; RecognizerConfig(const FeatureExtractorConfig &feat_config, const ModelConfig &model_config, const DecoderConfig decoder_config, const EndpointConfig &endpoint_config, bool enable_endpoint, - const std::string &decoding_method, - const std::string &hotwords_file, - int32_t max_active_paths, float hotwords_score) + const std::string &hotwords_file, float hotwords_score) : feat_config(feat_config), model_config(model_config), decoder_config(decoder_config), endpoint_config(endpoint_config), enable_endpoint(enable_endpoint), - decoding_method(decoding_method), hotwords_file(hotwords_file), - max_active_paths(max_active_paths), hotwords_score(hotwords_score) {} std::string ToString() const; diff --git a/sherpa-ncnn/csrc/sherpa-ncnn.cc b/sherpa-ncnn/csrc/sherpa-ncnn.cc index 
b4e3cff9..ccc79ece 100644 --- a/sherpa-ncnn/csrc/sherpa-ncnn.cc +++ b/sherpa-ncnn/csrc/sherpa-ncnn.cc @@ -18,9 +18,10 @@ */ #include -#include + #include #include // NOLINT +#include #include #include "net.h" // NOLINT @@ -72,26 +73,24 @@ for a list of pre-trained models to download. config.decoder_config.method = method; } } - std::cout<<"decode method:"<= 12) { - config.hotwords_file = argv[11]; - } else { - config.hotwords_file = ""; + + if (argc >= 12) { + config.hotwords_file = argv[11]; } - if(argc == 13) { + + if (argc == 13) { config.hotwords_score = atof(argv[12]); - } else { - config.hotwords_file = 1.5; } + config.feat_config.sampling_rate = expected_sampling_rate; config.feat_config.feature_dim = 80; + std::cout << config.ToString() << "\n"; + sherpa_ncnn::Recognizer recognizer(config); std::string wav_filename = argv[8]; - std::cout << config.ToString() << "\n"; - bool is_ok = false; std::vector samples = sherpa_ncnn::ReadWave(wav_filename, expected_sampling_rate, &is_ok); diff --git a/sherpa-ncnn/csrc/stream.cc b/sherpa-ncnn/csrc/stream.cc index f4d1eec7..7b7af4d0 100644 --- a/sherpa-ncnn/csrc/stream.cc +++ b/sherpa-ncnn/csrc/stream.cc @@ -22,7 +22,8 @@ namespace sherpa_ncnn { class Stream::Impl { public: - explicit Impl(const FeatureExtractorConfig &config,ContextGraphPtr context_graph) + explicit Impl(const FeatureExtractorConfig &config, + ContextGraphPtr context_graph) : feat_extractor_(config), context_graph_(context_graph) {} void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { @@ -73,7 +74,8 @@ class Stream::Impl { std::vector states_; }; -Stream::Stream(const FeatureExtractorConfig &config, ContextGraphPtr context_graph) +Stream::Stream(const FeatureExtractorConfig &config, + ContextGraphPtr context_graph) : impl_(std::make_unique(config, context_graph)) {} Stream::~Stream() = default; @@ -113,5 +115,5 @@ std::vector &Stream::GetStates() { return impl_->GetStates(); } const ContextGraphPtr &Stream::GetContextGraph() 
const { return impl_->GetContextGraph(); - } +} } // namespace sherpa_ncnn diff --git a/sherpa-ncnn/csrc/stream.h b/sherpa-ncnn/csrc/stream.h index 87d536a6..9b3f4248 100644 --- a/sherpa-ncnn/csrc/stream.h +++ b/sherpa-ncnn/csrc/stream.h @@ -22,15 +22,15 @@ #include #include +#include "sherpa-ncnn/csrc/context-graph.h" #include "sherpa-ncnn/csrc/decoder.h" #include "sherpa-ncnn/csrc/features.h" -#include "sherpa-ncnn/csrc/context-graph.h" namespace sherpa_ncnn { class Stream { public: explicit Stream(const FeatureExtractorConfig &config = {}, - ContextGraphPtr context_graph = nullptr); + ContextGraphPtr context_graph = nullptr); ~Stream(); /** diff --git a/sherpa-ncnn/csrc/symbol-table.cc b/sherpa-ncnn/csrc/symbol-table.cc index 2a358f12..e3dbdcaa 100644 --- a/sherpa-ncnn/csrc/symbol-table.cc +++ b/sherpa-ncnn/csrc/symbol-table.cc @@ -23,12 +23,12 @@ #include #include - #if __ANDROID_API__ >= 9 +#include + #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #include "android/log.h" -#include #endif namespace sherpa_ncnn { diff --git a/sherpa-ncnn/jni/jni.cc b/sherpa-ncnn/jni/jni.cc index 00e7892c..74b39940 100644 --- a/sherpa-ncnn/jni/jni.cc +++ b/sherpa-ncnn/jni/jni.cc @@ -259,6 +259,15 @@ static RecognizerConfig ParseConfig(JNIEnv *env, jobject _config) { config.endpoint_config.rule3.min_utterance_length = env->GetFloatField(_config, fid); + fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;"); + jstring s = (jstring)env->GetObjectField(_config, fid); + const char *p = env->GetStringUTFChars(s, nullptr); + config.hotwords_file = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(cls, "hotwordsScore", "F"); + config.hotwords_score = env->GetFloatField(_config, fid); + NCNN_LOGE("------config------\n%s\n", config.ToString().c_str()); return config; diff --git a/sherpa-ncnn/python/csrc/recognizer.cc b/sherpa-ncnn/python/csrc/recognizer.cc index a5cc82f9..0e314e0f 100644 --- a/sherpa-ncnn/python/csrc/recognizer.cc +++ 
b/sherpa-ncnn/python/csrc/recognizer.cc @@ -63,16 +63,20 @@ static void PybindRecognizerConfig(py::module *m) { using PyClass = RecognizerConfig; py::class_(*m, "RecognizerConfig") .def(py::init(), + const DecoderConfig &, const EndpointConfig &, bool, + const std::string &, float>(), py::arg("feat_config"), py::arg("model_config"), py::arg("decoder_config"), py::arg("endpoint_config"), - py::arg("enable_endpoint"), kRecognizerConfigInitDoc) + py::arg("enable_endpoint"), py::arg("hotwords_file") = "", + py::arg("hotwords_score") = 1.5, kRecognizerConfigInitDoc) .def("__str__", &PyClass::ToString) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("decoder_config", &PyClass::decoder_config) .def_readwrite("endpoint_config", &PyClass::endpoint_config) - .def_readwrite("enable_endpoint", &PyClass::enable_endpoint); + .def_readwrite("enable_endpoint", &PyClass::enable_endpoint) + .def_readwrite("hotwords_file", &PyClass::hotwords_file) + .def_readwrite("hotwords_score", &PyClass::hotwords_score); } void PybindRecognizer(py::module *m) { diff --git a/sherpa-ncnn/python/sherpa_ncnn/recognizer.py b/sherpa-ncnn/python/sherpa_ncnn/recognizer.py index ae074c45..299fe364 100644 --- a/sherpa-ncnn/python/sherpa_ncnn/recognizer.py +++ b/sherpa-ncnn/python/sherpa_ncnn/recognizer.py @@ -91,6 +91,8 @@ def __init__( rule2_min_trailing_silence: int = 1.2, rule3_min_utterance_length: int = 20, model_sample_rate: int = 16000, + hotwords_file: str = "", + hotwords_score: float = 1.5, ): """ Please refer to @@ -143,6 +145,14 @@ def __init__( is detected. model_sample_rate: Sample rate expected by the model + hotwords_file: + Optional. If not empty, it specifies the hotwords file. + Each line in the hotwords file is a hotword. A hotword + consists of words separated by spaces. + Used only when decoding_method is modified_beam_search. + hotwords_score: + The scale applied to hotwords score.
Used only + when hotwords_file is not empty. """ _assert_file_exists(tokens) _assert_file_exists(encoder_param) @@ -190,6 +200,8 @@ def __init__( decoder_config=decoder_config, endpoint_config=endpoint_config, enable_endpoint=enable_endpoint_detection, + hotwords_file=hotwords_file, + hotwords_score=hotwords_score, ) self.sample_rate = self.config.feat_config.sampling_rate diff --git a/swift-api-examples/SherpaNcnn.swift b/swift-api-examples/SherpaNcnn.swift index 77a70b91..990964e3 100644 --- a/swift-api-examples/SherpaNcnn.swift +++ b/swift-api-examples/SherpaNcnn.swift @@ -118,7 +118,9 @@ func sherpaNcnnRecognizerConfig( enableEndpoint: Bool = false, rule1MinTrailingSilence: Float = 2.4, rule2MinTrailingSilence: Float = 1.2, - rule3MinUtteranceLength: Float = 30 + rule3MinUtteranceLength: Float = 30, + hotwordsFile: String = "", + hotwordsScore: Float = 1.5 ) -> SherpaNcnnRecognizerConfig { return SherpaNcnnRecognizerConfig( feat_config: featConfig, @@ -127,7 +129,9 @@ func sherpaNcnnRecognizerConfig( enable_endpoint: enableEndpoint ? 1 : 0, rule1_min_trailing_silence: rule1MinTrailingSilence, rule2_min_trailing_silence: rule2MinTrailingSilence, - rule3_min_utterance_length: rule3MinUtteranceLength) + rule3_min_utterance_length: rule3MinUtteranceLength, + hotwords_file: toCPointer(hotwordsFile), + hotwords_score: hotwordsScore) } /// Wrapper for recognition result.