Skip to content

Commit

Permalink
Fix a punctuation bug (#764)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Apr 13, 2024
1 parent b6ad043 commit 983df28
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 14 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)

set(SHERPA_ONNX_VERSION "1.9.18")
set(SHERPA_ONNX_VERSION "1.9.19")

# Disable warning about
#
Expand Down
22 changes: 9 additions & 13 deletions sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
int32_t dot_index = -1;
int32_t comma_index = -1;

for (int32_t m = this_punctuations.size() - 1; m >= 1; --m) {
for (int32_t m = this_punctuations.size() - 2; m >= 1; --m) {
int32_t punct_id = this_punctuations[m];

if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
Expand Down Expand Up @@ -126,27 +126,20 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
}
} else {
last = this_start + dot_index + 1;
}

if (dot_index != 1) {
punctuations.insert(punctuations.end(), this_punctuations.begin(),
this_punctuations.begin() + (dot_index + 1));
}
} // for (int32_t i = 0; i != num_segments; ++i)

if (punctuations.size() != token_ids.size() &&
punctuations.size() + 1 == token_ids.size()) {
punctuations.push_back(meta_data.dot_id);
}

if (punctuations.size() != token_ids.size()) {
SHERPA_ONNX_LOGE("%s, %d, %d. Some unexpected things happened",
text.c_str(), static_cast<int32_t>(punctuations.size()),
static_cast<int32_t>(token_ids.size()));
return text;
}

std::string ans;

for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) {
if (i > tokens.size()) {
break;
}
const std::string &w = tokens[i];
if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) {
ans.push_back(' ');
Expand All @@ -156,6 +149,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
ans.append(meta_data.id2punct[punctuations[i]]);
}
}
if (ans.back() != meta_data.dot_id && ans.back() != meta_data.quest_id) {
ans.push_back(meta_data.dot_id);
}

return ans;
}
Expand Down

0 comments on commit 983df28

Please sign in to comment.