diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 3e6526563..aacabcf47 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -25,6 +25,7 @@ set(sources offline-ctc-fst-decoder.cc offline-ctc-greedy-search-decoder.cc offline-ctc-model.cc + offline-ctc-prefix-beam-search-decoder.cc offline-lm-config.cc offline-lm.cc offline-model-config.cc diff --git a/sherpa-onnx/csrc/hypothesis.cc b/sherpa-onnx/csrc/hypothesis.cc index ea332bcb5..848899acf 100644 --- a/sherpa-onnx/csrc/hypothesis.cc +++ b/sherpa-onnx/csrc/hypothesis.cc @@ -1,5 +1,5 @@ /** - * Copyright (c) 2023 Xiaomi Corporation + * Copyright (c) 2023-2024 Xiaomi Corporation * Copyright (c) 2023 Pingfeng Luo */ @@ -10,37 +10,49 @@ namespace sherpa_onnx { -void Hypotheses::Add(Hypothesis hyp) { +void Hypotheses::Add(Hypothesis hyp, bool use_ctc /*= false */) { auto key = hyp.Key(); auto it = hyps_dict_.find(key); if (it == hyps_dict_.end()) { hyps_dict_[key] = std::move(hyp); } else { - it->second.log_prob = LogAdd()(it->second.log_prob, hyp.log_prob); + if (use_ctc) { + it->second.log_prob_b = + LogAdd()(it->second.log_prob_b, hyp.log_prob_b); + it->second.log_prob_nb = + LogAdd()(it->second.log_prob_nb, hyp.log_prob_nb); + } else { + it->second.log_prob = LogAdd()(it->second.log_prob, hyp.log_prob); + } } } -Hypothesis Hypotheses::GetMostProbable(bool length_norm) const { +Hypothesis Hypotheses::GetMostProbable(bool length_norm, + bool use_ctc /*= false */) const { if (length_norm == false) { - return std::max_element(hyps_dict_.begin(), hyps_dict_.end(), - [](const auto &left, auto &right) -> bool { - return left.second.TotalLogProb() < - right.second.TotalLogProb(); - }) + return std::max_element( + hyps_dict_.begin(), hyps_dict_.end(), + [use_ctc](const auto &left, const auto &right) -> bool { + return left.second.TotalLogProb(use_ctc) < + right.second.TotalLogProb(use_ctc); + }) ->second; } else { // for length_norm is true return 
std::max_element( hyps_dict_.begin(), hyps_dict_.end(), - [](const auto &left, const auto &right) -> bool { - return left.second.TotalLogProb() / left.second.ys.size() < - right.second.TotalLogProb() / right.second.ys.size(); + [use_ctc](const auto &left, const auto &right) -> bool { + return left.second.TotalLogProb(use_ctc) / + left.second.ys.size() < + right.second.TotalLogProb(use_ctc) / + right.second.ys.size(); }) ->second; } } -std::vector Hypotheses::GetTopK(int32_t k, bool length_norm) const { +std::vector Hypotheses::GetTopK(int32_t k, bool length_norm, + bool use_ctc /*= false*/) const { k = std::max(k, 1); k = std::min(k, Size()); @@ -48,15 +60,16 @@ std::vector Hypotheses::GetTopK(int32_t k, bool length_norm) const { if (length_norm == false) { std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), - [](const auto &a, const auto &b) { - return a.TotalLogProb() > b.TotalLogProb(); + [use_ctc](const auto &a, const auto &b) { + return a.TotalLogProb(use_ctc) > + b.TotalLogProb(use_ctc); }); } else { // for length_norm is true std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), - [](const auto &a, const auto &b) { - return a.TotalLogProb() / a.ys.size() > - b.TotalLogProb() / b.ys.size(); + [use_ctc](const auto &a, const auto &b) { + return a.TotalLogProb(use_ctc) / a.ys.size() > + b.TotalLogProb(use_ctc) / b.ys.size(); }); } diff --git a/sherpa-onnx/csrc/hypothesis.h b/sherpa-onnx/csrc/hypothesis.h index 6a49bad35..5cc47c984 100644 --- a/sherpa-onnx/csrc/hypothesis.h +++ b/sherpa-onnx/csrc/hypothesis.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2023 Xiaomi Corporation + * Copyright (c) 2023-2024 Xiaomi Corporation * Copyright (c) 2023 Pingfeng Luo * */ @@ -48,6 +48,12 @@ struct Hypothesis { // It contains only acoustic scores double log_prob = 0; + // The total score of ys which ends with blank token in log space + double log_prob_b = 0; + + // The total score of ys which ends with non blank token in log space + double 
log_prob_nb = -std::numeric_limits<float>::infinity(); + // LM log prob if any. double lm_log_prob = 0; @@ -74,7 +80,21 @@ struct Hypothesis { const ContextState *context_state = nullptr) : ys(ys), log_prob(log_prob), context_state(context_state) {} - double TotalLogProb() const { return log_prob + lm_log_prob; } + explicit Hypothesis(const ContextState *context_state) + : context_state(context_state) {} + + double TotalLogProb(bool use_ctc = false) const { + return LogProb(use_ctc) + lm_log_prob; + } + + // The acoustic log probability + double LogProb(bool use_ctc = false) const { + if (use_ctc) { + return LogAdd<double>()(log_prob_b, log_prob_nb); + } else { + return log_prob; + } + } // If two Hypotheses have the same `Key`, then they contain // the same token sequence. @@ -112,20 +132,23 @@ class Hypotheses { // Add hyp to this object. If it already exists, its log_prob // is updated with the given hyp using log-sum-exp. - void Add(Hypothesis hyp); + void Add(Hypothesis hyp, bool use_ctc = false); // Get the hyp that has the largest log_prob. // If length_norm is true, hyp's log_prob is divided by // len(hyp.ys) before comparison. - Hypothesis GetMostProbable(bool length_norm) const; + Hypothesis GetMostProbable(bool length_norm, bool use_ctc = false) const; // Get the k hyps that have the largest log_prob. // If length_norm is true, hyp's log_prob is divided by // len(hyp.ys) before comparison. 
- std::vector<Hypothesis> GetTopK(int32_t k, bool length_norm) const; + std::vector<Hypothesis> GetTopK(int32_t k, bool length_norm, + bool use_ctc = false) const; int32_t Size() const { return hyps_dict_.size(); } + std::vector<Hypothesis> ToList() const { return Vec(); } + std::string ToString() const { std::ostringstream os; for (const auto &p : hyps_dict_) { diff --git a/sherpa-onnx/csrc/offline-ctc-decoder.h b/sherpa-onnx/csrc/offline-ctc-decoder.h index c9d1b36ff..b5914e939 100644 --- a/sherpa-onnx/csrc/offline-ctc-decoder.h +++ b/sherpa-onnx/csrc/offline-ctc-decoder.h @@ -8,6 +8,7 @@ #include <vector> #include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-stream.h" namespace sherpa_onnx { @@ -42,7 +43,8 @@ class OfflineCtcDecoder { * @return Return a vector of size `N` containing the decoded results. */ virtual std::vector<OfflineCtcDecoderResult> Decode( - Ort::Value log_probs, Ort::Value log_probs_length) = 0; + Ort::Value log_probs, Ort::Value log_probs_length, + OfflineStream **ss = nullptr, int32_t n = 0) = 0; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc b/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc index 6c9df3fd3..e62d76a26 100644 --- a/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc @@ -124,7 +124,8 @@ OfflineCtcFstDecoder::OfflineCtcFstDecoder( : config_(config), fst_(ReadGraph(config_.graph)) {} std::vector<OfflineCtcDecoderResult> OfflineCtcFstDecoder::Decode( - Ort::Value log_probs, Ort::Value log_probs_length) { + Ort::Value log_probs, Ort::Value log_probs_length, + OfflineStream **ss /*= nullptr*/, int32_t n /*= 0*/) { std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); assert(static_cast<int32_t>(shape.size()) == 3); diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder.h b/sherpa-onnx/csrc/offline-ctc-fst-decoder.h index 2b33c14e8..0291f66cf 100644 --- a/sherpa-onnx/csrc/offline-ctc-fst-decoder.h +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder.h @@ -19,8 +19,10 @@ class OfflineCtcFstDecoder : public OfflineCtcDecoder 
{ public: explicit OfflineCtcFstDecoder(const OfflineCtcFstDecoderConfig &config); - std::vector Decode( - Ort::Value log_probs, Ort::Value log_probs_length) override; + std::vector Decode(Ort::Value log_probs, + Ort::Value log_probs_length, + OfflineStream **ss = nullptr, + int32_t n = 0) override; private: OfflineCtcFstDecoderConfig config_; diff --git a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc index 59d16f5d3..2aca90dd4 100644 --- a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc @@ -1,4 +1,4 @@ -// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h +// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc // // Copyright (c) 2023 Xiaomi Corporation @@ -13,7 +13,8 @@ namespace sherpa_onnx { std::vector OfflineCtcGreedySearchDecoder::Decode( - Ort::Value log_probs, Ort::Value log_probs_length) { + Ort::Value log_probs, Ort::Value log_probs_length, + OfflineStream **ss /*= nullptr*/, int32_t n /*= 0*/) { std::vector shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); int32_t batch_size = static_cast(shape[0]); int32_t num_frames = static_cast(shape[1]); diff --git a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h index ccc2f728a..ce2c19904 100644 --- a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h @@ -16,8 +16,10 @@ class OfflineCtcGreedySearchDecoder : public OfflineCtcDecoder { explicit OfflineCtcGreedySearchDecoder(int32_t blank_id) : blank_id_(blank_id) {} - std::vector Decode( - Ort::Value log_probs, Ort::Value log_probs_length) override; + std::vector Decode(Ort::Value log_probs, + Ort::Value log_probs_length, + OfflineStream **ss = nullptr, + int32_t n = 0) override; private: int32_t blank_id_; diff --git a/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc 
b/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc new file mode 100644 index 000000000..273f65210 --- /dev/null +++ b/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc @@ -0,0 +1,132 @@ +// sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h" + +#include <algorithm> +#include <utility> +#include <vector> + +#include "sherpa-onnx/csrc/context-graph.h" +#include "sherpa-onnx/csrc/hypothesis.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +static std::vector<Hypothesis> StepWorker(const float *p_log_probs, + std::vector<Hypothesis> &hyps, + int32_t blank_id, int32_t vocab_size, + int32_t max_active_paths, + const ContextGraph *context_graph) { + auto topk = TopkIndex(p_log_probs, vocab_size, max_active_paths); + Hypotheses next_hyps; + for (auto &hyp : hyps) { + for (auto k : topk) { + Hypothesis new_hyp = hyp; + int32_t new_token = k; + float log_prob = p_log_probs[k]; + bool update_prefix = false; + if (new_token == blank_id) { + // Case 0: *a + ε => *a + // *aε + ε => *a + // Prefix does not change, update log_prob of blank + new_hyp.log_prob_nb = -std::numeric_limits<float>::infinity(); + new_hyp.log_prob_b = hyp.LogProb(true) + log_prob; + next_hyps.Add(std::move(new_hyp)); + } else if (hyp.ys.size() > 0 && hyp.ys.back() == new_token) { + // Case 1: *a + a => *a + // Prefix does not change, update log_prob of non_blank + new_hyp.log_prob_nb = hyp.log_prob_nb + log_prob; + new_hyp.log_prob_b = -std::numeric_limits<float>::infinity(); + next_hyps.Add(std::move(new_hyp)); + + // Case 2: *aε + a => *aa + // Prefix changes, update log_prob of blank + new_hyp = hyp; + new_hyp.ys.push_back(new_token); + new_hyp.log_prob_nb = hyp.log_prob_b + log_prob; + new_hyp.log_prob_b = -std::numeric_limits<float>::infinity(); + update_prefix = true; + } else { + // Case 3: *a + b => *ab, *aε + b => *ab + // Prefix changes, update log_prob of non_blank + // Caution: DO NOT use append, as clone 
is shallow copy + new_hyp.ys.push_back(new_token); + new_hyp.log_prob_nb = hyp.LogProb(true) + log_prob; + new_hyp.log_prob_b = -std::numeric_limits<float>::infinity(); + update_prefix = true; + } + + if (update_prefix) { + float lm_log_prob = hyp.lm_log_prob; + if (context_graph != nullptr && hyp.context_state != nullptr) { + auto context_res = + context_graph->ForwardOneStep(hyp.context_state, new_token); + lm_log_prob = lm_log_prob + std::get<0>(context_res); + new_hyp.context_state = std::get<1>(context_res); + } + new_hyp.lm_log_prob = lm_log_prob; + next_hyps.Add(std::move(new_hyp)); + } + } + } + return next_hyps.GetTopK(max_active_paths, false, true); +} + +std::vector<OfflineCtcDecoderResult> OfflineCtcPrefixBeamSearchDecoder::Decode( + Ort::Value log_probs, Ort::Value log_probs_length, + OfflineStream **ss /*= nullptr*/, int32_t n /*= 0*/) { + std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); + int32_t batch_size = static_cast<int32_t>(shape[0]); + int32_t num_frames = static_cast<int32_t>(shape[1]); + int32_t vocab_size = static_cast<int32_t>(shape[2]); + + const int64_t *p_log_probs_length = log_probs_length.GetTensorData<int64_t>(); + + std::vector<OfflineCtcDecoderResult> ans; + ans.reserve(batch_size); + + std::vector<std::vector<Hypothesis>> cur; + cur.reserve(batch_size); + + std::vector<ContextGraphPtr> context_graphs(batch_size, nullptr); + + for (int32_t i = 0; i < batch_size; ++i) { + const ContextState *context_state = nullptr; + if (ss != nullptr) { + context_graphs[i] = ss[i]->GetContextGraph(); + if (context_graphs[i] != nullptr) + context_state = context_graphs[i]->Root(); + } + Hypothesis hyp(context_state); + cur.emplace_back(std::vector<Hypothesis>({hyp})); + } + + for (int32_t t = 0; t < num_frames; ++t) { + for (int32_t b = 0; b < batch_size; ++b) { + if (t < p_log_probs_length[b]) { + const float *p_log_probs = log_probs.GetTensorData<float>() + + b * num_frames * vocab_size + t * vocab_size; + cur[b] = StepWorker(p_log_probs, cur[b], blank_id_, vocab_size, + max_active_paths_, context_graphs[b].get()); + // for (auto &x : cur[b]) { + // SHERPA_ONNX_LOGE("step : 
%d, key : %s, ac : %f, lm : %f", t, + // x.Key().c_str(), x.LogProb(true), x.lm_log_prob); + // } + // SHERPA_ONNX_LOGE("\n"); + } + } + } + + for (int32_t b = 0; b != batch_size; ++b) { + Hypotheses hyps(cur[b]); + Hypothesis best_hyp = hyps.GetMostProbable(false, true); + OfflineCtcDecoderResult r; + r.tokens = best_hyp.ys; + ans.push_back(std::move(r)); + } + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h b/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h new file mode 100644 index 000000000..1504bb870 --- /dev/null +++ b/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h @@ -0,0 +1,31 @@ +// sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_PREFIX_BEAM_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_PREFIX_BEAM_SEARCH_DECODER_H_ + +#include + +#include "sherpa-onnx/csrc/offline-ctc-decoder.h" + +namespace sherpa_onnx { + +class OfflineCtcPrefixBeamSearchDecoder : public OfflineCtcDecoder { + public: + OfflineCtcPrefixBeamSearchDecoder(int32_t max_active_paths, int32_t blank_id) + : max_active_paths_(max_active_paths), blank_id_(blank_id) {} + + std::vector Decode(Ort::Value log_probs, + Ort::Value log_probs_length, + OfflineStream **ss = nullptr, + int32_t n = 0) override; + + private: + int32_t max_active_paths_; + int32_t blank_id_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_PREFIX_BEAM_SEARCH_DECODER_H_ diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index 7bbe6938c..4afedeb93 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -7,6 +7,7 @@ #include #include +#include // NOLINT #include #include #include @@ -21,9 +22,12 @@ #include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h" #include 
"sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h" #include "sherpa-onnx/csrc/offline-ctc-model.h" +#include "sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h" #include "sherpa-onnx/csrc/offline-recognizer-impl.h" #include "sherpa-onnx/csrc/pad-sequence.h" #include "sherpa-onnx/csrc/symbol-table.h" +#include "sherpa-onnx/csrc/utils.h" +#include "ssentencepiece/csrc/ssentencepiece.h" namespace sherpa_onnx { @@ -125,7 +129,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { // asset_manager decoder_ = std::make_unique( config_.ctc_fst_decoder_config); - } else if (config_.decoding_method == "greedy_search") { + } else if (config_.decoding_method == "greedy_search" || + config_.decoding_method == "prefix_beam_search") { if (!symbol_table_.Contains("") && !symbol_table_.Contains("") && !symbol_table_.Contains("")) { @@ -146,16 +151,70 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { blank_id = symbol_table_[""]; } - decoder_ = std::make_unique(blank_id); + if (config_.decoding_method == "greedy_search") { + decoder_ = std::make_unique(blank_id); + } else { + if (!config_.model_config.bpe_vocab.empty()) { + bpe_encoder_ = std::make_unique( + config_.model_config.bpe_vocab); + } + + if (!config_.hotwords_file.empty()) { + InitHotwords(); + } + + decoder_ = std::make_unique( + config_.max_active_paths, blank_id); + } } else { - SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", - config_.decoding_method.c_str()); + SHERPA_ONNX_LOGE( + "Only greedy_search and prefix_beam_search are supported at present. 
" + "Given %s", + config_.decoding_method.c_str()); exit(-1); } } + std::unique_ptr CreateStream( + const std::string &hotwords) const override { + auto hws = std::regex_replace(hotwords, std::regex("/"), "\n"); + std::istringstream is(hws); + std::vector> current; + std::vector current_scores; + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, + bpe_encoder_.get(), ¤t, ¤t_scores)) { + SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s", + hotwords.c_str()); + } + + int32_t num_default_hws = hotwords_.size(); + int32_t num_hws = current.size(); + + current.insert(current.end(), hotwords_.begin(), hotwords_.end()); + + if (!current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else if (!current_scores.empty() && boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_default_hws, + config_.hotwords_score); + } else if (current_scores.empty() && !boost_scores_.empty()) { + current_scores.insert(current_scores.end(), num_hws, + config_.hotwords_score); + current_scores.insert(current_scores.end(), boost_scores_.begin(), + boost_scores_.end()); + } else { + // Do nothing. 
+ } + + auto context_graph = std::make_shared( + current, config_.hotwords_score, current_scores); + return std::make_unique(config_.feat_config, context_graph); + } + std::unique_ptr CreateStream() const override { - return std::make_unique(config_.feat_config); + return std::make_unique(config_.feat_config, + hotwords_graph_); } void DecodeStreams(OfflineStream **ss, int32_t n) const override { @@ -209,7 +268,7 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { -23.025850929940457f); auto t = model_->Forward(std::move(x), std::move(x_length)); - auto results = decoder_->Decode(std::move(t[0]), std::move(t[1])); + auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]), ss, n); int32_t frame_shift_ms = 10; for (int32_t i = 0; i != n; ++i) { @@ -246,7 +305,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { x_length_shape.data(), x_length_shape.size()); auto t = model_->Forward(std::move(x), std::move(x_length)); - auto results = decoder_->Decode(std::move(t[0]), std::move(t[1])); + + OfflineStream *ss[1] = {s}; + auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]), ss, 1); int32_t frame_shift_ms = 10; auto r = Convert(results[0], symbol_table_, frame_shift_ms, @@ -255,9 +316,60 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { s->SetResult(r); } + void InitHotwords() { + // each line in hotwords_file contains space-separated words + + std::ifstream is(config_.hotwords_file); + if (!is) { + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", + config_.hotwords_file.c_str()); + exit(-1); + } + + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); + } + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); + } + +#if __ANDROID_API__ >= 9 + void InitHotwords(AAssetManager *mgr) { + 
// each line in hotwords_file contains space-separated words + + auto buf = ReadFile(mgr, config_.hotwords_file); + + std::istringstream is(std::string(buf.begin(), buf.end())); + + if (!is) { + SHERPA_ONNX_LOGE("Open hotwords file failed: %s", + config_.hotwords_file.c_str()); + exit(-1); + } + + if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_, + bpe_encoder_.get(), &hotwords_, &boost_scores_)) { + SHERPA_ONNX_LOGE( + "Failed to encode some hotwords, skip them already, see logs above " + "for details."); + } + hotwords_graph_ = std::make_shared( + hotwords_, config_.hotwords_score, boost_scores_); + } +#endif + private: OfflineRecognizerConfig config_; SymbolTable symbol_table_; + + std::vector> hotwords_; + std::vector boost_scores_; + ContextGraphPtr hotwords_graph_; + std::unique_ptr bpe_encoder_; + std::unique_ptr model_; std::unique_ptr decoder_; }; diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index f73e35ad6..88e4a271d 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -68,7 +68,8 @@ bool OfflineRecognizerConfig::Validate() const { } } - if (!hotwords_file.empty() && decoding_method != "modified_beam_search") { + if (!hotwords_file.empty() && (decoding_method != "modified_beam_search" && + decoding_method != "prefix_beam_search")) { SHERPA_ONNX_LOGE( "Please use --decoding-method=modified_beam_search if you" " provide --hotwords-file. Given --decoding-method='%s'", @@ -157,7 +158,7 @@ void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const { } void OfflineRecognizer::SetConfig(const OfflineRecognizerConfig &config) { - impl_->SetConfig(config); + impl_->SetConfig(config); } OfflineRecognizerConfig OfflineRecognizer::GetConfig() const {