diff --git a/sherpa-onnx/csrc/offline-paraformer-decoder.h b/sherpa-onnx/csrc/offline-paraformer-decoder.h index 65781324b..1b783e88d 100644 --- a/sherpa-onnx/csrc/offline-paraformer-decoder.h +++ b/sherpa-onnx/csrc/offline-paraformer-decoder.h @@ -23,8 +23,7 @@ class OfflineParaformerDecoder { /** Run beam search given the output from the paraformer model. * * @param log_probs A 3-D tensor of shape (N, T, vocab_size) - * @param token_num A 2-D tensor of shape (N, T). Its dtype is int64_t. - * log_probs[i].argmax(axis=-1) equals to token_num[i] + * @param token_num A 1-D tensor of shape (N). token_num equals to T. * * @return Return a vector of size `N` containing the decoded results. */ diff --git a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc index 54ce545ac..619b33495 100644 --- a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc @@ -4,28 +4,33 @@ #include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h" +#include #include namespace sherpa_onnx { std::vector -OfflineParaformerGreedySearchDecoder::Decode(Ort::Value /*log_probs*/, - Ort::Value token_num) { - std::vector shape = token_num.GetTensorTypeAndShapeInfo().GetShape(); +OfflineParaformerGreedySearchDecoder::Decode(Ort::Value log_probs, + Ort::Value /*token_num*/) { + std::vector shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); int32_t batch_size = shape[0]; int32_t num_tokens = shape[1]; + int32_t vocab_size = shape[2]; std::vector results(batch_size); - const int64_t *p = token_num.GetTensorData(); for (int32_t i = 0; i != batch_size; ++i) { + const float *p = + log_probs.GetTensorData() + i * num_tokens * vocab_size; for (int32_t k = 0; k != num_tokens; ++k) { - if (p[k] == eos_id_) break; + auto max_idx = static_cast( + std::distance(p, std::max_element(p, p + vocab_size))); + if (max_idx == eos_id_) break; - results[i].tokens.push_back(p[k]); - } + results[i].tokens.push_back(max_idx); - p += num_tokens; + p += vocab_size; + } } return results; diff --git a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h index 9ba177c91..1f48e8c84 100644 --- a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h @@ -17,7 +17,7 @@ class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder { : eos_id_(eos_id) {} std::vector Decode( - Ort::Value /*log_probs*/, Ort::Value token_num) override; + Ort::Value log_probs, Ort::Value /*token_num*/) override; private: int32_t eos_id_;