From 01379020c99f0bd1549f68cc32adb6daa28e1979 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 17 Oct 2024 14:52:51 +0800 Subject: [PATCH] Add ctc prefix beam search --- sherpa-onnx/csrc/CMakeLists.txt | 1 + sherpa-onnx/csrc/hypothesis.cc | 49 ++++++--- sherpa-onnx/csrc/hypothesis.h | 30 ++++- .../csrc/offline-ctc-greedy-search-decoder.cc | 2 +- .../offline-ctc-prefix-beam-search-decoder.cc | 103 ++++++++++++++++++ .../offline-ctc-prefix-beam-search-decoder.h | 29 +++++ .../csrc/offline-recognizer-ctc-impl.h | 17 ++- 7 files changed, 203 insertions(+), 28 deletions(-) create mode 100644 sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc create mode 100644 sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 3e6526563..aacabcf47 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -25,6 +25,7 @@ set(sources offline-ctc-fst-decoder.cc offline-ctc-greedy-search-decoder.cc offline-ctc-model.cc + offline-ctc-prefix-beam-search-decoder.cc offline-lm-config.cc offline-lm.cc offline-model-config.cc diff --git a/sherpa-onnx/csrc/hypothesis.cc b/sherpa-onnx/csrc/hypothesis.cc index ea332bcb5..848899acf 100644 --- a/sherpa-onnx/csrc/hypothesis.cc +++ b/sherpa-onnx/csrc/hypothesis.cc @@ -1,5 +1,5 @@ /** - * Copyright (c) 2023 Xiaomi Corporation + * Copyright (c) 2023-2024 Xiaomi Corporation * Copyright (c) 2023 Pingfeng Luo */ @@ -10,37 +10,49 @@ namespace sherpa_onnx { -void Hypotheses::Add(Hypothesis hyp) { +void Hypotheses::Add(Hypothesis hyp, bool use_ctc /*= false */) { auto key = hyp.Key(); auto it = hyps_dict_.find(key); if (it == hyps_dict_.end()) { hyps_dict_[key] = std::move(hyp); } else { - it->second.log_prob = LogAdd()(it->second.log_prob, hyp.log_prob); + if (use_ctc) { + it->second.log_prob_b = + LogAdd()(it->second.log_prob_b, hyp.log_prob_b); + it->second.log_prob_nb = + LogAdd()(it->second.log_prob_nb, hyp.log_prob_nb); + } else { + it->second.log_prob = LogAdd()(it->second.log_prob, hyp.log_prob); + } } } -Hypothesis Hypotheses::GetMostProbable(bool length_norm) const { +Hypothesis Hypotheses::GetMostProbable(bool length_norm, + bool use_ctc /*= false */) const { if (length_norm == false) { - return std::max_element(hyps_dict_.begin(), hyps_dict_.end(), - [](const auto &left, auto &right) -> bool { - return left.second.TotalLogProb() < - right.second.TotalLogProb(); - }) + return std::max_element( + hyps_dict_.begin(), hyps_dict_.end(), + [use_ctc](const auto &left, const auto &right) -> bool { + return left.second.TotalLogProb(use_ctc) < + right.second.TotalLogProb(use_ctc); + }) ->second; } else { // for length_norm is true return std::max_element( hyps_dict_.begin(), hyps_dict_.end(), - [](const auto &left, const auto &right) -> bool { - return left.second.TotalLogProb() / left.second.ys.size() < - right.second.TotalLogProb() / right.second.ys.size(); + [use_ctc](const auto &left, const auto &right) -> bool { + return left.second.TotalLogProb(use_ctc) / + left.second.ys.size() < + right.second.TotalLogProb(use_ctc) / + right.second.ys.size(); }) ->second; } } -std::vector Hypotheses::GetTopK(int32_t k, bool length_norm) const { +std::vector Hypotheses::GetTopK(int32_t k, bool length_norm, + bool use_ctc /*= false*/) const { k = std::max(k, 1); k = std::min(k, Size()); @@ -48,15 +60,16 @@ std::vector Hypotheses::GetTopK(int32_t k, bool length_norm) const { if (length_norm == false) { std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), - [](const auto &a, const auto &b) { - return a.TotalLogProb() > b.TotalLogProb(); + [use_ctc](const auto &a, const auto &b) { + return a.TotalLogProb(use_ctc) > + b.TotalLogProb(use_ctc); }); } else { // for length_norm is true std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(), - [](const auto &a, const auto &b) { - return a.TotalLogProb() / a.ys.size() > - b.TotalLogProb() / b.ys.size(); + [use_ctc](const auto &a, const auto &b) { + return a.TotalLogProb(use_ctc) / a.ys.size() > + b.TotalLogProb(use_ctc) / b.ys.size(); }); } diff --git a/sherpa-onnx/csrc/hypothesis.h b/sherpa-onnx/csrc/hypothesis.h index 6a49bad35..c9ccf80a8 100644 --- a/sherpa-onnx/csrc/hypothesis.h +++ b/sherpa-onnx/csrc/hypothesis.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2023 Xiaomi Corporation + * Copyright (c) 2023-2024 Xiaomi Corporation * Copyright (c) 2023 Pingfeng Luo * */ @@ -48,6 +48,12 @@ struct Hypothesis { // It contains only acoustic scores double log_prob = 0; + // The total score of ys which ends with blank token in log space + double log_prob_b = 0; + + // The total score of ys which ends with non blank token in log space + double log_prob_nb = -std::numeric_limits::infinity(); + // LM log prob if any. double lm_log_prob = 0; @@ -74,7 +80,18 @@ struct Hypothesis { const ContextState *context_state = nullptr) : ys(ys), log_prob(log_prob), context_state(context_state) {} - double TotalLogProb() const { return log_prob + lm_log_prob; } + double TotalLogProb(bool use_ctc = false) const { + return LogProb(use_ctc) + lm_log_prob; + } + + // The acoustic log probability + double LogProb(bool use_ctc = false) const { + if (use_ctc) { + return LogAdd()(log_prob_b, log_prob_nb); + } else { + return log_prob; + } + } // If two Hypotheses have the same `Key`, then they contain // the same token sequence. @@ -112,20 +129,23 @@ class Hypotheses { // Add hyp to this object. If it already exists, its log_prob // is updated with the given hyp using log-sum-exp. - void Add(Hypothesis hyp); + void Add(Hypothesis hyp, bool use_ctc = false); // Get the hyp that has the largest log_prob. // If length_norm is true, hyp's log_prob is divided by // len(hyp.ys) before comparison. - Hypothesis GetMostProbable(bool length_norm) const; + Hypothesis GetMostProbable(bool length_norm, bool use_ctc = false) const; // Get the k hyps that have the largest log_prob. // If length_norm is true, hyp's log_prob is divided by // len(hyp.ys) before comparison. - std::vector GetTopK(int32_t k, bool length_norm) const; + std::vector GetTopK(int32_t k, bool length_norm, + bool use_ctc = false) const; int32_t Size() const { return hyps_dict_.size(); } + std::vector ToList() const { return Vec(); } + std::string ToString() const { std::ostringstream os; for (const auto &p : hyps_dict_) { diff --git a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc index 59d16f5d3..8196e28b3 100644 --- a/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc @@ -1,4 +1,4 @@ -// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h +// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc // // Copyright (c) 2023 Xiaomi Corporation diff --git a/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc b/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc new file mode 100644 index 000000000..10937e81b --- /dev/null +++ b/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc @@ -0,0 +1,103 @@ +// sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/hypothesis.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +static std::vector StepWorker(const float *p_log_probs, + std::vector &hyps, + int32_t blank_id, int32_t vocab_size, + int32_t max_active_paths) { + auto topk = TopkIndex(p_log_probs, vocab_size, max_active_paths); + Hypotheses next_hyps; + for (auto &hyp : hyps) { + Hypothesis new_hyp = hyp; + for (auto k : topk) { + int32_t new_token = k; + float log_prob = p_log_probs[k]; + if (new_token == blank_id) { + // Case 0: *a + ε => *a + // *aε + ε => *a + // Prefix does not change, update log_prob of blank + new_hyp.log_prob_nb = -std::numeric_limits::infinity(); + new_hyp.log_prob_b = hyp.LogProb(true) + log_prob; + next_hyps.Add(std::move(new_hyp)); + } else if (hyp.ys.size() > 0 && hyp.ys.back() == new_token) { + // Case 1: *a + a => *a + // Prefix does not change, update log_prob of non_blank + new_hyp.log_prob_nb = hyp.log_prob_nb + log_prob; + new_hyp.log_prob_b = -std::numeric_limits::infinity(); + + next_hyps.Add(std::move(new_hyp)); + + // Case 2: *aε + a => *aa + // Prefix changes, update log_prob of blank + new_hyp = hyp; + new_hyp.ys.push_back(new_token); + new_hyp.log_prob_nb = hyp.log_prob_b + log_prob; + new_hyp.log_prob_b = -std::numeric_limits::infinity(); + next_hyps.Add(std::move(new_hyp)); + } else { + // Case 3: *a + b => *ab, *aε + b => *ab + // Prefix changes, update log_prob of non_blank + // Caution: DO NOT use append, as clone is shallow copy + new_hyp.ys.push_back(new_token); + new_hyp.log_prob_nb = hyp.LogProb(true) + log_prob; + new_hyp.log_prob_b = -std::numeric_limits::infinity(); + next_hyps.Add(std::move(new_hyp)); + } + } + } + return next_hyps.GetTopK(max_active_paths, false, true); +} + +std::vector OfflineCtcPrefixBeamSearchDecoder::Decode( + Ort::Value log_probs, Ort::Value log_probs_length) { + std::vector shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); + int32_t batch_size = static_cast(shape[0]); + int32_t num_frames = static_cast(shape[1]); + int32_t vocab_size = static_cast(shape[2]); + + const int64_t *p_log_probs_length = log_probs_length.GetTensorData(); + + std::vector ans; + ans.reserve(batch_size); + + std::vector> cur; + cur.reserve(batch_size); + + for (int32_t i = 0; i < batch_size; ++i) { + cur.emplace_back(std::vector({Hypothesis()})); + } + + for (int32_t t = 0; t < num_frames; ++t) { + for (int32_t b = 0; b < batch_size; ++b) { + if (t < p_log_probs_length[b]) { + const float *p_log_probs = log_probs.GetTensorData() + + b * num_frames * vocab_size + t * vocab_size; + cur[b] = StepWorker(p_log_probs, cur[b], blank_id_, vocab_size, + max_active_paths_); + } + } + } + + for (int32_t b = 0; b != batch_size; ++b) { + Hypotheses hyps(cur[b]); + Hypothesis best_hyp = hyps.GetMostProbable(false, true); + OfflineCtcDecoderResult r; + r.tokens = best_hyp.ys; + ans.push_back(std::move(r)); + } + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h b/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h new file mode 100644 index 000000000..97449f572 --- /dev/null +++ b/sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h @@ -0,0 +1,29 @@ +// sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_PREFIX_BEAM_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_PREFIX_BEAM_SEARCH_DECODER_H_ + +#include + +#include "sherpa-onnx/csrc/offline-ctc-decoder.h" + +namespace sherpa_onnx { + +class OfflineCtcPrefixBeamSearchDecoder : public OfflineCtcDecoder { + public: + OfflineCtcPrefixBeamSearchDecoder(int32_t max_active_paths, int32_t blank_id) + : max_active_paths_(max_active_paths), blank_id_(blank_id) {} + + std::vector Decode( + Ort::Value log_probs, Ort::Value log_probs_length) override; + + private: + int32_t max_active_paths_; + int32_t blank_id_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_PREFIX_BEAM_SEARCH_DECODER_H_ diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index 7bbe6938c..ebf7b25b1 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -21,6 +21,7 @@ #include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h" #include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h" #include "sherpa-onnx/csrc/offline-ctc-model.h" +#include "sherpa-onnx/csrc/offline-ctc-prefix-beam-search-decoder.h" #include "sherpa-onnx/csrc/offline-recognizer-impl.h" #include "sherpa-onnx/csrc/pad-sequence.h" #include "sherpa-onnx/csrc/symbol-table.h" @@ -125,7 +126,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { // asset_manager decoder_ = std::make_unique( config_.ctc_fst_decoder_config); - } else if (config_.decoding_method == "greedy_search") { + } else if (config_.decoding_method == "greedy_search" || + config_.decoding_method == "prefix_beam_search") { if (!symbol_table_.Contains("") && !symbol_table_.Contains("") && !symbol_table_.Contains("")) { @@ -146,10 +148,17 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { blank_id = symbol_table_[""]; } - decoder_ = std::make_unique(blank_id); + if (config_.decoding_method == "greedy_search") { + decoder_ = std::make_unique(blank_id); + } else { + decoder_ = std::make_unique( + config_.max_active_paths, blank_id); + } } else { - SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", - config_.decoding_method.c_str()); + SHERPA_ONNX_LOGE( + "Only greedy_search and prefix_beam_search are supported at present. " + "Given %s", + config_.decoding_method.c_str()); exit(-1); } }