diff --git a/.github/scripts/run-test.sh b/.github/scripts/run-test.sh index 5b714c53..8e3350fc 100755 --- a/.github/scripts/run-test.sh +++ b/.github/scripts/run-test.sh @@ -544,3 +544,48 @@ for wave in ${waves[@]}; do done rm -rf $repo + +log "------------------------------------------------------------" +log "Run hotwords test (Chinese)" +log "------------------------------------------------------------" +repo_url=https://huggingface.co/HalFTeen/sherpa-ncnn-hotwords-test/ +log "Start testing ${repo_url}" +repo=$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +pushd $repo +git lfs pull --include "encoder_jit_trace-pnnx.ncnn.bin" +git lfs pull --include "decoder_jit_trace-pnnx.ncnn.bin" +git lfs pull --include "joiner_jit_trace-pnnx.ncnn.bin" +popd + + +log "----test $m without hotwords---" +time $EXE \ + $repo/tokens.txt \ + $repo/encoder_jit_trace-pnnx.ncnn.param \ + $repo/encoder_jit_trace-pnnx.ncnn.bin \ + $repo/decoder_jit_trace-pnnx.ncnn.param \ + $repo/decoder_jit_trace-pnnx.ncnn.bin \ + $repo/joiner_jit_trace-pnnx.ncnn.param \ + $repo/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/hotwords.wav \ + 4 \ + modified_beam_search + + +log "----test $m with hotwords---" +time $EXE \ + $repo/tokens.txt \ + $repo/encoder_jit_trace-pnnx.ncnn.param \ + $repo/encoder_jit_trace-pnnx.ncnn.bin \ + $repo/decoder_jit_trace-pnnx.ncnn.param \ + $repo/decoder_jit_trace-pnnx.ncnn.bin \ + $repo/joiner_jit_trace-pnnx.ncnn.param \ + $repo/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/hotwords.wav \ + 4 \ + modified_beam_search \ + $repo/hotwords.txt 1.6 + +rm -rf $repo \ No newline at end of file diff --git a/c-api-examples/decode-file-c-api.c b/c-api-examples/decode-file-c-api.c index ddad88aa..3d183bcd 100644 --- a/c-api-examples/decode-file-c-api.c +++ b/c-api-examples/decode-file-c-api.c @@ -40,7 +40,7 @@ const char *kUsage = "for a list of pre-trained models to download.\n"; int32_t main(int32_t argc, char 
*argv[]) { - if (argc < 9 || argc > 11) { + if (argc < 9 || argc > 13) { fprintf(stderr, "%s\n", kUsage); return -1; } @@ -62,7 +62,7 @@ int32_t main(int32_t argc, char *argv[]) { config.decoder_config.decoding_method = "greedy_search"; - if (argc == 11) { + if (argc >= 11) { config.decoder_config.decoding_method = argv[10]; } config.decoder_config.num_active_paths = 4; @@ -73,7 +73,16 @@ int32_t main(int32_t argc, char *argv[]) { config.feat_config.sampling_rate = 16000; config.feat_config.feature_dim = 80; - + if(argc >= 12) { + config.hotwords_file = argv[11]; + } else { + config.hotwords_file = ""; + } + if(argc == 13) { + config.hotwords_score = atof(argv[12]); + } else { + config.hotwords_score = 1.5; + } SherpaNcnnRecognizer *recognizer = CreateRecognizer(&config); const char *wav_filename = argv[8]; @@ -92,7 +101,6 @@ int32_t main(int32_t argc, char *argv[]) { int16_t buffer[N]; float samples[N]; - SherpaNcnnStream *s = CreateStream(recognizer); SherpaNcnnDisplay *display = CreateDisplay(50); diff --git a/sherpa-ncnn/c-api/c-api.cc b/sherpa-ncnn/c-api/c-api.cc index 743bfa75..9bf67b6b 100644 --- a/sherpa-ncnn/c-api/c-api.cc +++ b/sherpa-ncnn/c-api/c-api.cc @@ -66,6 +66,8 @@ SherpaNcnnRecognizer *CreateRecognizer( config.decoder_config.method = in_config->decoder_config.decoding_method; config.decoder_config.num_active_paths = in_config->decoder_config.num_active_paths; + config.hotwords_file = in_config->hotwords_file ? in_config->hotwords_file : ""; /* a NULL hotwords_file (e.g. zero-initialised config) means "no hotwords"; assigning NULL to std::string is undefined behaviour */ + config.hotwords_score = in_config->hotwords_score; config.enable_endpoint = in_config->enable_endpoint; diff --git a/sherpa-ncnn/c-api/c-api.h b/sherpa-ncnn/c-api/c-api.h index e0b600ac..e626cf33 100644 --- a/sherpa-ncnn/c-api/c-api.h +++ b/sherpa-ncnn/c-api/c-api.h @@ -133,6 +133,14 @@ SHERPA_NCNN_API typedef struct SherpaNcnnRecognizerConfig { /// this value. /// Used only when enable_endpoint is not 0. 
float rule3_min_utterance_length; + + /// hotwords file, each line is a hotword which is segmented into char by space + /// if language is something like CJK, segment manually, + /// if language is something like English, segment by bpe model. + const char *hotwords_file; + + /// scale of hotwords, used only when hotwords_file is not empty + float hotwords_score; } SherpaNcnnRecognizerConfig; SHERPA_NCNN_API typedef struct SherpaNcnnResult { diff --git a/sherpa-ncnn/csrc/CMakeLists.txt b/sherpa-ncnn/csrc/CMakeLists.txt index e5d54265..41ed79f0 100644 --- a/sherpa-ncnn/csrc/CMakeLists.txt +++ b/sherpa-ncnn/csrc/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(${CMAKE_SOURCE_DIR}) set(sherpa_ncnn_core_srcs + context-graph.cc conv-emformer-model.cc decoder.cc endpoint.cc diff --git a/sherpa-ncnn/csrc/context-graph.cc b/sherpa-ncnn/csrc/context-graph.cc new file mode 100644 index 00000000..78c08bc6 --- /dev/null +++ b/sherpa-ncnn/csrc/context-graph.cc @@ -0,0 +1,95 @@ +// sherpa-ncnn/csrc/context-graph.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-ncnn/csrc/context-graph.h" + +#include +#include +#include + +namespace sherpa_ncnn { +void ContextGraph::Build( + const std::vector> &token_ids) const { + for (int32_t i = 0; i < token_ids.size(); ++i) { + auto node = root_.get(); + for (int32_t j = 0; j < token_ids[i].size(); ++j) { + int32_t token = token_ids[i][j]; + if (0 == node->next.count(token)) { + bool is_end = j == token_ids[i].size() - 1; + node->next[token] = std::make_unique( + token, context_score_, node->node_score + context_score_, + is_end ? 
node->node_score + context_score_ : 0, is_end); + } + node = node->next[token].get(); + } + } + FillFailOutput(); +} + +std::pair ContextGraph::ForwardOneStep( + const ContextState *state, int32_t token) const { + const ContextState *node; + float score; + if (1 == state->next.count(token)) { + node = state->next.at(token).get(); + score = node->token_score; + } else { + node = state->fail; + while (0 == node->next.count(token)) { + node = node->fail; + if (-1 == node->token) break; // root + } + if (1 == node->next.count(token)) { + node = node->next.at(token).get(); + } + score = node->node_score - state->node_score; + } + return std::make_pair(score + node->output_score, node); +} + +std::pair ContextGraph::Finalize( + const ContextState *state) const { + float score = -state->node_score; + return std::make_pair(score, root_.get()); +} + +void ContextGraph::FillFailOutput() const { + std::queue node_queue; + for (auto &kv : root_->next) { + kv.second->fail = root_.get(); + node_queue.push(kv.second.get()); + } + while (!node_queue.empty()) { + auto current_node = node_queue.front(); + node_queue.pop(); + for (auto &kv : current_node->next) { + auto fail = current_node->fail; + if (1 == fail->next.count(kv.first)) { + fail = fail->next.at(kv.first).get(); + } else { + fail = fail->fail; + while (0 == fail->next.count(kv.first)) { + fail = fail->fail; + if (-1 == fail->token) break; + } + if (1 == fail->next.count(kv.first)) + fail = fail->next.at(kv.first).get(); + } + kv.second->fail = fail; + // fill the output arc + auto output = fail; + while (!output->is_end) { + output = output->fail; + if (-1 == output->token) { + output = nullptr; + break; + } + } + kv.second->output = output; + kv.second->output_score += output == nullptr ? 
0 : output->output_score; + node_queue.push(kv.second.get()); + } + } +} +} // namespace sherpa_ncnn diff --git a/sherpa-ncnn/csrc/context-graph.h b/sherpa-ncnn/csrc/context-graph.h new file mode 100644 index 00000000..0002fa52 --- /dev/null +++ b/sherpa-ncnn/csrc/context-graph.h @@ -0,0 +1,65 @@ +// sherpa-ncnn/csrc/context-graph.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_ +#define SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_ + +#include +#include +#include +#include + + +namespace sherpa_ncnn { + +class ContextGraph; +using ContextGraphPtr = std::shared_ptr; + +struct ContextState { + int32_t token; + float token_score; + float node_score; + float output_score; + bool is_end; + std::unordered_map> next; + const ContextState *fail = nullptr; + const ContextState *output = nullptr; + + ContextState() = default; + ContextState(int32_t token, float token_score, float node_score, + float output_score, bool is_end) + : token(token), + token_score(token_score), + node_score(node_score), + output_score(output_score), + is_end(is_end) {} +}; + +class ContextGraph { + public: + ContextGraph() = default; + ContextGraph(const std::vector> &token_ids, + float hotwords_score) + : context_score_(hotwords_score) { + root_ = std::make_unique(-1, 0, 0, 0, false); + root_->fail = root_.get(); + Build(token_ids); + } + + std::pair ForwardOneStep( + const ContextState *state, int32_t token_id) const; + std::pair Finalize( + const ContextState *state) const; + + const ContextState *Root() const { return root_.get(); } + + private: + float context_score_; + std::unique_ptr root_; + void Build(const std::vector> &token_ids) const; + void FillFailOutput() const; +}; + +} // namespace sherpa_ncnn +#endif // SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_ diff --git a/sherpa-ncnn/csrc/decoder.h b/sherpa-ncnn/csrc/decoder.h index e403ad88..904e19f3 100644 --- a/sherpa-ncnn/csrc/decoder.h +++ b/sherpa-ncnn/csrc/decoder.h @@ -59,7 +59,7 @@ struct DecoderResult { // 
used only for modified_beam_search Hypotheses hyps; }; - +class Stream; class Decoder { public: virtual ~Decoder() = default; @@ -88,6 +88,7 @@ class Decoder { * and there are no paddings. */ virtual void Decode(ncnn::Mat encoder_out, DecoderResult *result) = 0; + virtual void Decode(ncnn::Mat encoder_out, Stream *s, DecoderResult *result){}; }; } // namespace sherpa_ncnn diff --git a/sherpa-ncnn/csrc/hypothesis.h b/sherpa-ncnn/csrc/hypothesis.h index f60fd813..98badb98 100644 --- a/sherpa-ncnn/csrc/hypothesis.h +++ b/sherpa-ncnn/csrc/hypothesis.h @@ -24,6 +24,7 @@ #include #include #include +#include "sherpa-ncnn/csrc/context-graph.h" namespace sherpa_ncnn { @@ -37,12 +38,13 @@ struct Hypothesis { // The total score of ys in log space. double log_prob = 0; - + const ContextState *context_state = nullptr; /* defaulted so Hypothesis() never carries an indeterminate pointer */ int32_t num_trailing_blanks = 0; Hypothesis() = default; - Hypothesis(const std::vector &ys, double log_prob) - : ys(ys), log_prob(log_prob) {} + Hypothesis(const std::vector &ys, double log_prob, + const ContextState *context_state = nullptr) + : ys(ys), log_prob(log_prob), context_state(context_state) {} // If two Hypotheses have the same `Key`, then they contain // the same token sequence. 
@@ -104,6 +106,8 @@ class Hypotheses { const auto begin() const { return hyps_dict_.begin(); } const auto end() const { return hyps_dict_.end(); } + auto begin() { return hyps_dict_.begin(); } + auto end() { return hyps_dict_.end(); } void Clear() { hyps_dict_.clear(); } diff --git a/sherpa-ncnn/csrc/modified-beam-search-decoder.cc b/sherpa-ncnn/csrc/modified-beam-search-decoder.cc index 441a7b89..e908f47b 100644 --- a/sherpa-ncnn/csrc/modified-beam-search-decoder.cc +++ b/sherpa-ncnn/csrc/modified-beam-search-decoder.cc @@ -37,7 +37,7 @@ DecoderResult ModifiedBeamSearchDecoder::GetEmptyResult() const { Hypotheses blank_hyp({{blanks, 0}}); r.hyps = std::move(blank_hyp); - + r.tokens = std::move(blanks); return r; } @@ -195,4 +195,91 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, result->num_trailing_blanks = hyp.num_trailing_blanks; } + +void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s, + DecoderResult *result) { + int32_t context_size = model_->ContextSize(); + Hypotheses cur = std::move(result->hyps); + /* encoder_out.w == encoder_out_dim, encoder_out.h == num_frames. 
*/ + for (int32_t t = 0; t != encoder_out.h; ++t) { + std::vector prev = cur.GetTopK(num_active_paths_, true); + cur.Clear(); + + + ncnn::Mat decoder_input = BuildDecoderInput(prev); + ncnn::Mat decoder_out; + if (t == 0 && prev.size() == 1 && prev[0].ys.size() == context_size && + !result->decoder_out.empty()) { + // When an endpoint is detected, we keep the decoder_out + decoder_out = result->decoder_out; + } else { + decoder_out = RunDecoder2D(model_, decoder_input); + } + + // decoder_out.w == decoder_dim + // decoder_out.h == num_active_paths + ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t)); + + ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out); + // joiner_out.w == vocab_size + // joiner_out.h == num_active_paths + LogSoftmax(&joiner_out); + + + float *p_joiner_out = joiner_out; + + for (int32_t i = 0; i != joiner_out.h; ++i) { + float prev_log_prob = prev[i].log_prob; + for (int32_t k = 0; k != joiner_out.w; ++k, ++p_joiner_out) { + *p_joiner_out += prev_log_prob; + } + } + + auto topk = TopkIndex(static_cast(joiner_out), + joiner_out.w * joiner_out.h, num_active_paths_); + + int32_t frame_offset = result->frame_offset; + for (auto i : topk) { + int32_t hyp_index = i / joiner_out.w; + int32_t new_token = i % joiner_out.w; + + const float *p = joiner_out.row(hyp_index); + + Hypothesis new_hyp = prev[hyp_index]; + // const float prev_lm_log_prob = new_hyp.lm_log_prob; + float context_score = 0; + auto context_state = new_hyp.context_state; + // blank id is fixed to 0 + if (new_token != 0) { + new_hyp.ys.push_back(new_token); + new_hyp.num_trailing_blanks = 0; + new_hyp.timestamps.push_back(t + frame_offset); + if (s != nullptr && s->GetContextGraph() != nullptr) { + auto context_res = s->GetContextGraph()->ForwardOneStep( + context_state, new_token); + context_score = context_res.first; + new_hyp.context_state = context_res.second; + } + } else { + ++new_hyp.num_trailing_blanks; + } + // We have already added 
prev[hyp_index].log_prob to p[new_token] + new_hyp.log_prob = p[new_token] + context_score; + + cur.Add(std::move(new_hyp)); + } + } + + result->hyps = std::move(cur); + result->frame_offset += encoder_out.h; + auto hyp = result->hyps.GetMostProbable(true); + + // set decoder_out in case of endpointing + ncnn::Mat decoder_input = BuildDecoderInput({hyp}); + result->decoder_out = model_->RunDecoder(decoder_input); + + result->tokens = std::move(hyp.ys); + result->num_trailing_blanks = hyp.num_trailing_blanks; +} + } // namespace sherpa_ncnn diff --git a/sherpa-ncnn/csrc/modified-beam-search-decoder.h b/sherpa-ncnn/csrc/modified-beam-search-decoder.h index f0911643..e84a0c49 100644 --- a/sherpa-ncnn/csrc/modified-beam-search-decoder.h +++ b/sherpa-ncnn/csrc/modified-beam-search-decoder.h @@ -25,6 +25,8 @@ #include "mat.h" // NOLINT #include "sherpa-ncnn/csrc/decoder.h" #include "sherpa-ncnn/csrc/model.h" +#include "sherpa-ncnn/csrc/stream.h" +#include "sherpa-ncnn/csrc/context-graph.h" namespace sherpa_ncnn { @@ -38,6 +40,7 @@ class ModifiedBeamSearchDecoder : public Decoder { void StripLeadingBlanks(DecoderResult *r) const override; void Decode(ncnn::Mat encoder_out, DecoderResult *result) override; + void Decode(ncnn::Mat encoder_out, Stream *s, DecoderResult *result) override; private: ncnn::Mat BuildDecoderInput(const std::vector &hyps) const; diff --git a/sherpa-ncnn/csrc/recognizer.cc b/sherpa-ncnn/csrc/recognizer.cc index 83f79ee1..4175c732 100644 --- a/sherpa-ncnn/csrc/recognizer.cc +++ b/sherpa-ncnn/csrc/recognizer.cc @@ -18,7 +18,8 @@ */ #include "sherpa-ncnn/csrc/recognizer.h" - +#include +#include #include #include #include @@ -75,8 +76,12 @@ std::string RecognizerConfig::ToString() const { os << "feat_config=" << feat_config.ToString() << ", "; os << "model_config=" << model_config.ToString() << ", "; os << "decoder_config=" << decoder_config.ToString() << ", "; + os << "max_active_paths=" << max_active_paths << ", "; os << "endpoint_config=" << 
endpoint_config.ToString() << ", "; - os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")"; + os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", "; + os << "hotwords_file=\"" << hotwords_file << "\", "; + os << "hotwords_score=" << hotwords_score << ", "; + os << "decoding_method=\"" << decoding_method << "\")"; return os.str(); } @@ -93,6 +98,30 @@ class Recognizer::Impl { } else if (config.decoder_config.method == "modified_beam_search") { decoder_ = std::make_unique( model_.get(), config.decoder_config.num_active_paths); + std::vector tmp; + /*each line in hotwords file is a string which is segmented by space*/ + std::ifstream file(config_.hotwords_file); + if (file) { + std::string line; + std::string word; + while (std::getline(file, line)) { + std::istringstream iss(line); + while(iss >> word){ + if (sym_.contains(word)) { + int number = sym_[word]; + tmp.push_back(number); + } else { + NCNN_LOGE("hotword %s can't find id. line: %s", word.c_str(), line.c_str()); + exit(-1); + } + } + hotwords_.push_back(tmp); + tmp.clear(); + } + } else { + NCNN_LOGE("open file failed: %s, hotwords will not be used", + config_.hotwords_file.c_str()); + } } else { NCNN_LOGE("Unsupported method: %s", config.decoder_config.method.c_str()); exit(-1); @@ -110,6 +139,30 @@ class Recognizer::Impl { } else if (config.decoder_config.method == "modified_beam_search") { decoder_ = std::make_unique( model_.get(), config.decoder_config.num_active_paths); + std::vector tmp; + /*each line in hotwords file is a string which is segmented by space*/ + std::ifstream file(config_.hotwords_file); + if (file) { + std::string line; + std::string word; + while (std::getline(file, line)) { + std::istringstream iss(line); + while(iss >> word){ + if (sym_.contains(word)) { + int number = sym_[word]; + tmp.push_back(number); + } else { + NCNN_LOGE("hotword %s can't find id. 
line: %s", word.c_str(), line.c_str()); + exit(-1); + } + } + hotwords_.push_back(tmp); + tmp.clear(); + } + } else { + NCNN_LOGE("open file failed: %s, hotwords will not be used", + config_.hotwords_file.c_str()); + } } else { NCNN_LOGE("Unsupported method: %s", config.decoder_config.method.c_str()); exit(-1); @@ -118,10 +171,29 @@ class Recognizer::Impl { #endif std::unique_ptr CreateStream() const { - auto stream = std::make_unique(config_.feat_config); - stream->SetResult(decoder_->GetEmptyResult()); - stream->SetStates(model_->GetEncoderInitStates()); - return stream; + if(hotwords_.empty()) { + auto stream = std::make_unique(config_.feat_config); + stream->SetResult(decoder_->GetEmptyResult()); + stream->SetStates(model_->GetEncoderInitStates()); + return stream; + } else { + auto r = decoder_->GetEmptyResult(); + auto context_graph = + std::make_shared(hotwords_, config_.hotwords_score); + auto stream = + std::make_unique(config_.feat_config, context_graph); + if (config_.decoder_config.method == "modified_beam_search" && + nullptr != stream->GetContextGraph()) { + std::cout<<"create contexts stream"<second.context_state = stream->GetContextGraph()->Root(); + } + } + stream->SetResult(r); + stream->SetStates(model_->GetEncoderInitStates()); + return stream; + } } bool IsReady(Stream *s) const { @@ -131,16 +203,24 @@ class Recognizer::Impl { void DecodeStream(Stream *s) const { int32_t segment = model_->Segment(); int32_t offset = model_->Offset(); + bool has_context_graph = false; + if (!has_context_graph && s->GetContextGraph()) { + has_context_graph = true; + } ncnn::Mat features = s->GetFrames(s->GetNumProcessedFrames(), segment); s->GetNumProcessedFrames() += offset; std::vector states = s->GetStates(); ncnn::Mat encoder_out; std::tie(encoder_out, states) = model_->RunEncoder(features, states); - s->SetStates(states); - decoder_->Decode(encoder_out, &s->GetResult()); + if (has_context_graph) { + decoder_->Decode(encoder_out, s, &s->GetResult()); + } else 
{ + decoder_->Decode(encoder_out, &s->GetResult()); + } + s->SetStates(states); } bool IsEndpoint(Stream *s) const { @@ -158,6 +238,13 @@ class Recognizer::Impl { } void Reset(Stream *s) const { + auto r = decoder_->GetEmptyResult(); + if (config_.decoder_config.method == "modified_beam_search" && + nullptr != s->GetContextGraph()) { + for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) { + it->second.context_state = s->GetContextGraph()->Root(); + } + } // Caution: We need to keep the decoder output state ncnn::Mat decoder_out = s->GetResult().decoder_out; - s->SetResult(decoder_->GetEmptyResult()); + s->SetResult(r); /* install the result whose hypotheses were re-rooted at the context graph above, not a fresh empty one */ @@ -190,6 +277,7 @@ class Recognizer::Impl { std::unique_ptr decoder_; Endpoint endpoint_; SymbolTable sym_; + std::vector> hotwords_; }; Recognizer::Recognizer(const RecognizerConfig &config) @@ -206,6 +294,7 @@ std::unique_ptr Recognizer::CreateStream() const { return impl_->CreateStream(); } + bool Recognizer::IsReady(Stream *s) const { return impl_->IsReady(s); } void Recognizer::DecodeStream(Stream *s) const { impl_->DecodeStream(s); } diff --git a/sherpa-ncnn/csrc/recognizer.h b/sherpa-ncnn/csrc/recognizer.h index a2948456..97f95dd8 100644 --- a/sherpa-ncnn/csrc/recognizer.h +++ b/sherpa-ncnn/csrc/recognizer.h @@ -48,21 +48,32 @@ struct RecognizerConfig { FeatureExtractorConfig feat_config; ModelConfig model_config; DecoderConfig decoder_config; - + std::string decoding_method; + std::string hotwords_file; EndpointConfig endpoint_config; bool enable_endpoint = false; - + // used only for modified_beam_search + int32_t max_active_paths = 4; + /// used only for modified_beam_search + float hotwords_score = 1.5; RecognizerConfig() = default; RecognizerConfig(const FeatureExtractorConfig &feat_config, const ModelConfig &model_config, const DecoderConfig decoder_config, - const EndpointConfig &endpoint_config, bool enable_endpoint) + const EndpointConfig &endpoint_config, bool enable_endpoint, + const std::string &decoding_method, + const std::string &hotwords_file, + 
int32_t max_active_paths, float hotwords_score) : feat_config(feat_config), model_config(model_config), decoder_config(decoder_config), endpoint_config(endpoint_config), - enable_endpoint(enable_endpoint) {} + enable_endpoint(enable_endpoint), + decoding_method(decoding_method), + hotwords_file(hotwords_file), + max_active_paths(max_active_paths), + hotwords_score(hotwords_score) {} std::string ToString() const; }; diff --git a/sherpa-ncnn/csrc/sherpa-ncnn.cc b/sherpa-ncnn/csrc/sherpa-ncnn.cc index 2db8d2a4..b4e3cff9 100644 --- a/sherpa-ncnn/csrc/sherpa-ncnn.cc +++ b/sherpa-ncnn/csrc/sherpa-ncnn.cc @@ -18,7 +18,7 @@ */ #include - +#include #include #include // NOLINT #include @@ -28,7 +28,7 @@ #include "sherpa-ncnn/csrc/wave-reader.h" int32_t main(int32_t argc, char *argv[]) { - if (argc < 9 || argc > 11) { + if (argc < 9 || argc > 13) { const char *usage = R"usage( Usage: ./bin/sherpa-ncnn \ @@ -66,13 +66,23 @@ for a list of pre-trained models to download. config.model_config.joiner_opt.num_threads = num_threads; float expected_sampling_rate = 16000; - if (argc == 11) { + if (argc >= 11) { std::string method = argv[10]; if (method == "greedy_search" || method == "modified_beam_search") { config.decoder_config.method = method; } } - + std::cout<<"decode method:"<= 12) { + config.hotwords_file = argv[11]; + } else { + config.hotwords_file = ""; + } + if(argc == 13) { + config.hotwords_score = atof(argv[12]); + } else { + config.hotwords_score = 1.5; /* default score; must not overwrite hotwords_file when only the file argument is given */ + } config.feat_config.sampling_rate = expected_sampling_rate; config.feat_config.feature_dim = 80; @@ -96,7 +106,6 @@ for a list of pre-trained models to download. 
auto begin = std::chrono::steady_clock::now(); std::cout << "Started!\n"; - auto stream = recognizer.CreateStream(); stream->AcceptWaveform(expected_sampling_rate, samples.data(), samples.size()); diff --git a/sherpa-ncnn/csrc/stream.cc b/sherpa-ncnn/csrc/stream.cc index 5f26548b..f4d1eec7 100644 --- a/sherpa-ncnn/csrc/stream.cc +++ b/sherpa-ncnn/csrc/stream.cc @@ -22,8 +22,8 @@ namespace sherpa_ncnn { class Stream::Impl { public: - explicit Impl(const FeatureExtractorConfig &config) - : feat_extractor_(config) {} + explicit Impl(const FeatureExtractorConfig &config,ContextGraphPtr context_graph) + : feat_extractor_(config), context_graph_(context_graph) {} void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); @@ -62,16 +62,19 @@ class Stream::Impl { std::vector &GetStates() { return states_; } + const ContextGraphPtr &GetContextGraph() const { return context_graph_; } + private: FeatureExtractor feat_extractor_; + ContextGraphPtr context_graph_; int32_t num_processed_frames_ = 0; // before subsampling int32_t start_frame_index_ = 0; DecoderResult result_; std::vector states_; }; -Stream::Stream(const FeatureExtractorConfig &config) - : impl_(std::make_unique(config)) {} +Stream::Stream(const FeatureExtractorConfig &config, ContextGraphPtr context_graph) + : impl_(std::make_unique(config, context_graph)) {} Stream::~Stream() = default; @@ -108,4 +111,7 @@ void Stream::SetStates(const std::vector &states) { std::vector &Stream::GetStates() { return impl_->GetStates(); } +const ContextGraphPtr &Stream::GetContextGraph() const { + return impl_->GetContextGraph(); + } } // namespace sherpa_ncnn diff --git a/sherpa-ncnn/csrc/stream.h b/sherpa-ncnn/csrc/stream.h index 889f1936..87d536a6 100644 --- a/sherpa-ncnn/csrc/stream.h +++ b/sherpa-ncnn/csrc/stream.h @@ -24,11 +24,13 @@ #include "sherpa-ncnn/csrc/decoder.h" #include "sherpa-ncnn/csrc/features.h" +#include 
"sherpa-ncnn/csrc/context-graph.h" namespace sherpa_ncnn { class Stream { public: - explicit Stream(const FeatureExtractorConfig &config); + explicit Stream(const FeatureExtractorConfig &config = {}, + ContextGraphPtr context_graph = nullptr); ~Stream(); /** @@ -80,6 +82,12 @@ class Stream { void SetStates(const std::vector &states); std::vector &GetStates(); + /** + * Get the context graph corresponding to this stream. + * + * @return Return the context graph for this stream. + */ + const ContextGraphPtr &GetContextGraph() const; private: class Impl;