-
Notifications
You must be signed in to change notification settings - Fork 508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use log probs for paraformer #120
Conversation
sherpa-onnx/csrc/math.h
Outdated
@@ -103,5 +103,18 @@ std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) { | |||
return index; | |||
} | |||
|
|||
template <class T> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please follow
sherpa-onnx/sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc
Lines 35 to 39 in 80060c2
auto y = static_cast<int64_t>(std::distance( | |
static_cast<const float *>(p_log_probs), | |
std::max_element( | |
static_cast<const float *>(p_log_probs), | |
static_cast<const float *>(p_log_probs) + vocab_size))); |
to use std::max_element
to replace ArgMax()
.
You don't need to reimplement the wheel.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
p_log_probs
is const float *
, why do we use static_cast
again?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
p_log_probs
isconst float *
, why do we usestatic_cast
again?
yes, you are right. please remove the cast.
const float *p = log_probs.GetTensorData<float>(); | ||
for (int32_t i = 0; i != batch_size; ++i) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const float *p = log_probs.GetTensorData<float>(); | |
for (int32_t i = 0; i != batch_size; ++i) { | |
for (int32_t i = 0; i != batch_size; ++i) { | |
const float *p = log_probs.GetTensorData<float>() + i * num_tokens * vocab_size; |
Thanks! |
* Use log probs for paraformer * Fix
No description provided.