-
Notifications
You must be signed in to change notification settings - Fork 470
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
af071e0
commit 26bb94f
Showing 11 changed files with 103 additions and 90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
// sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_ | ||
#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_ | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
// sherpa-onnx/csrc/online-cnn-bilstm-model.cc | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h" | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
// sherpa-onnx/csrc/online-cnn-bilstm-model.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_ | ||
#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_ | ||
#include <memory> | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
// sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_ | ||
#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_ | ||
|
||
|
@@ -29,7 +30,7 @@ | |
|
||
namespace sherpa_onnx { | ||
|
||
static const int32_t MAX_SEQ_LEN = 200; | ||
static const int32_t kMaxSeqLen = 200; | ||
|
||
class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | ||
public: | ||
|
@@ -54,7 +55,63 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | |
} | ||
#endif | ||
|
||
void encode_sentences(const std::string& text, | ||
std::string AddPunctuationWithCase(const std::string &text) const override { | ||
if (text.empty()) { | ||
return {}; | ||
} | ||
|
||
std::vector<int32_t> tokens_list; // N * kMaxSeqLen | ||
std::vector<int32_t> valids_list; // N * kMaxSeqLen | ||
std::vector<int32_t> label_len_list; // N | ||
|
||
EncodeSentences(text, tokens_list, valids_list, label_len_list); | ||
|
||
const auto &meta_data = model_.GetModelMetadata(); | ||
|
||
auto memory_info = | ||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
|
||
int32_t n = label_len_list.size(); | ||
|
||
std::array<int64_t, 2> token_ids_shape = {n, kMaxSeqLen}; | ||
Ort::Value token_ids = Ort::Value::CreateTensor(memory_info, tokens_list.data(), tokens_list.size(), | ||
token_ids_shape.data(), token_ids_shape.size()); | ||
|
||
std::array<int64_t, 2> valid_ids_shape = {n, kMaxSeqLen}; | ||
Ort::Value valid_ids = Ort::Value::CreateTensor(memory_info, valids_list.data(), valids_list.size(), | ||
valid_ids_shape.data(), valid_ids_shape.size()); | ||
|
||
std::array<int64_t, 1> label_len_shape = {n}; | ||
Ort::Value label_len = Ort::Value::CreateTensor(memory_info, label_len_list.data(), label_len_list.size(), | ||
label_len_shape.data(), label_len_shape.size()); | ||
|
||
auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), std::move(label_len)); | ||
|
||
std::vector<int32_t> case_pred; | ||
std::vector<int32_t> punct_pred; | ||
const float* active_case_logits = pair.first.GetTensorData<float>(); | ||
const float* active_punct_logits = pair.second.GetTensorData<float>(); | ||
std::vector<int64_t> case_logits_shape = pair.first.GetTensorTypeAndShapeInfo().GetShape(); | ||
|
||
for (int32_t i = 0; i < case_logits_shape[0]; ++i) { | ||
const float* p_cur_case = active_case_logits + i * meta_data.num_cases; | ||
auto index_case = static_cast<int32_t>(std::distance( | ||
p_cur_case, std::max_element(p_cur_case, p_cur_case + meta_data.num_cases))); | ||
case_pred.push_back(index_case); | ||
|
||
const float* p_cur_punct = active_punct_logits + i * meta_data.num_punctuations; | ||
auto index_punct = static_cast<int32_t>(std::distance( | ||
p_cur_punct, std::max_element(p_cur_punct, p_cur_punct + meta_data.num_punctuations))); | ||
punct_pred.push_back(index_punct); | ||
} | ||
|
||
std::string ans = DecodeSentences(text, case_pred, punct_pred); | ||
|
||
return ans; | ||
} | ||
|
||
private: | ||
void EncodeSentences(const std::string& text, | ||
std::vector<int32_t>& tokens_list, | ||
std::vector<int32_t>& valids_list, | ||
std::vector<int32_t>& label_len_list) const { | ||
|
@@ -71,19 +128,20 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | |
std::vector<int32_t> word_tokens; | ||
bpe_encoder_->Encode(word, &word_tokens); | ||
|
||
if (tokens.size() + word_tokens.size() > MAX_SEQ_LEN - 1) { | ||
int32_t seq_len = tokens.size() + word_tokens.size(); | ||
if (seq_len > kMaxSeqLen - 1) { | ||
tokens.push_back(2); // hardcode 2 now, 2 - </s> | ||
valids.push_back(1); | ||
|
||
label_len = std::count(valids.begin(), valids.end(), 1); | ||
|
||
while (tokens.size() < MAX_SEQ_LEN) { | ||
tokens.push_back(0); | ||
valids.push_back(0); | ||
if (tokens.size() < kMaxSeqLen) { | ||
tokens.resize(kMaxSeqLen, 0); | ||
valids.resize(kMaxSeqLen, 0); | ||
} | ||
|
||
assert(tokens.size() == MAX_SEQ_LEN); | ||
assert(valids.size() == MAX_SEQ_LEN); | ||
assert(tokens.size() == kMaxSeqLen); | ||
assert(valids.size() == kMaxSeqLen); | ||
|
||
tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); | ||
valids_list.insert(valids_list.end(), valids.begin(), valids.end()); | ||
|
@@ -98,10 +156,10 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | |
|
||
tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end()); | ||
valids.push_back(1); // only the first sub word is valid | ||
int32_t remaining_size = word_tokens.size() - 1; | ||
while (remaining_size > 0) { | ||
valids.push_back(0); | ||
remaining_size--; | ||
int32_t remaining_size = static_cast<int32_t>(word_tokens.size()) - 1; | ||
if (remaining_size > 0) { | ||
int32_t valids_cur_size = static_cast<int32_t>(valids.size()); | ||
valids.resize(valids_cur_size + remaining_size, 0); | ||
} | ||
} | ||
|
||
|
@@ -111,21 +169,21 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | |
|
||
label_len = std::count(valids.begin(), valids.end(), 1); | ||
|
||
while (tokens.size() < MAX_SEQ_LEN) { | ||
tokens.push_back(0); | ||
valids.push_back(0); | ||
if (tokens.size() < kMaxSeqLen) { | ||
tokens.resize(kMaxSeqLen, 0); | ||
valids.resize(kMaxSeqLen, 0); | ||
} | ||
|
||
assert(tokens.size() == MAX_SEQ_LEN); | ||
assert(valids.size() == MAX_SEQ_LEN); | ||
assert(tokens.size() == kMaxSeqLen); | ||
assert(valids.size() == kMaxSeqLen); | ||
|
||
tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); | ||
valids_list.insert(valids_list.end(), valids.begin(), valids.end()); | ||
label_len_list.push_back(label_len); | ||
} | ||
} | ||
|
||
std::string decode_sentences(const std::string& raw_text, | ||
std::string DecodeSentences(const std::string& raw_text, | ||
const std::vector<int32_t>& case_pred, | ||
const std::vector<int32_t>& punct_pred) const { | ||
std::string result_text; | ||
|
@@ -140,7 +198,7 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | |
assert(words.size() == case_pred.size()); | ||
assert(words.size() == punct_pred.size()); | ||
|
||
for (int i=0; i<words.size(); i++) { | ||
for (int32_t i = 0; i < words.size(); ++i) { | ||
std::string prefix = ((i != 0) ? " " : ""); | ||
result_text += prefix; | ||
switch (case_pred[i]) { | ||
|
@@ -158,7 +216,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | |
} | ||
case 3: // mix case | ||
{ | ||
//TODO | ||
// TODO: | ||
// Need to add a map containing supported mixed-case words so that we can fetch the predicted word from the map | ||
// e.g. mcdonald's -> McDonald's | ||
result_text += words[i]; | ||
break; | ||
} | ||
|
@@ -196,61 +256,6 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { | |
return result_text; | ||
} | ||
|
||
std::string AddPunctuationWithCase(const std::string &text) const override { | ||
if (text.empty()) { | ||
return {}; | ||
} | ||
|
||
std::vector<int32_t> tokens_list; // N * MAX_SEQ_LEN | ||
std::vector<int32_t> valids_list; // N * MAX_SEQ_LEN | ||
std::vector<int32_t> label_len_list; // N | ||
|
||
encode_sentences(text, tokens_list, valids_list, label_len_list); | ||
|
||
const auto &meta_data = model_.GetModelMetadata(); | ||
|
||
auto memory_info = | ||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); | ||
|
||
int32_t n = label_len_list.size(); | ||
|
||
std::array<int64_t, 2> token_ids_shape = {n, MAX_SEQ_LEN}; | ||
Ort::Value token_ids = Ort::Value::CreateTensor(memory_info, tokens_list.data(), tokens_list.size(), | ||
token_ids_shape.data(), token_ids_shape.size()); | ||
|
||
std::array<int64_t, 2> valid_ids_shape = {n, MAX_SEQ_LEN}; | ||
Ort::Value valid_ids = Ort::Value::CreateTensor(memory_info, valids_list.data(), valids_list.size(), | ||
valid_ids_shape.data(), valid_ids_shape.size()); | ||
|
||
std::array<int64_t, 1> label_len_shape = {n}; | ||
Ort::Value label_len = Ort::Value::CreateTensor(memory_info, label_len_list.data(), label_len_list.size(), | ||
label_len_shape.data(), label_len_shape.size()); | ||
|
||
auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), std::move(label_len)); | ||
|
||
std::vector<int32_t> case_pred; | ||
std::vector<int32_t> punct_pred; | ||
const float* active_case_logits = pair.first.GetTensorData<float>(); | ||
const float* active_punct_logits = pair.second.GetTensorData<float>(); | ||
std::vector<int64_t> case_logits_shape = pair.first.GetTensorTypeAndShapeInfo().GetShape(); | ||
|
||
for (int i=0; i<case_logits_shape[0]; i++) { | ||
const float* p_cur_case = active_case_logits + i * meta_data.num_cases; | ||
auto index_case = static_cast<int32_t>(std::distance( | ||
p_cur_case, std::max_element(p_cur_case, p_cur_case + meta_data.num_cases))); | ||
case_pred.push_back(index_case); | ||
|
||
const float* p_cur_punct = active_punct_logits + i * meta_data.num_punctuations; | ||
auto index_punct = static_cast<int32_t>(std::distance( | ||
p_cur_punct, std::max_element(p_cur_punct, p_cur_punct + meta_data.num_punctuations))); | ||
punct_pred.push_back(index_punct); | ||
} | ||
|
||
std::string ans = decode_sentences(text, case_pred, punct_pred); | ||
|
||
return ans; | ||
} | ||
|
||
private: | ||
OnlinePunctuationConfig config_; | ||
OnlineCNNBiLSTMModel model_; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
// sherpa-onnx/csrc/online-punctuation-impl.cc | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#include "sherpa-onnx/csrc/online-punctuation-impl.h" | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
// sherpa-onnx/csrc/online-punctuation-impl.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_ | ||
#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_ | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
// sherpa-onnx/csrc/online-punctuation-model-config.cc | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#include "sherpa-onnx/csrc/online-punctuation-model-config.h" | ||
|
||
|
@@ -33,7 +33,7 @@ bool OnlinePunctuationModelConfig::Validate() const { | |
} | ||
|
||
if (!FileExists(cnn_bilstm)) { | ||
SHERPA_ONNX_LOGE("--cnn-bilstm %s does not exist", | ||
SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist", | ||
cnn_bilstm.c_str()); | ||
return false; | ||
} | ||
|
@@ -44,7 +44,7 @@ bool OnlinePunctuationModelConfig::Validate() const { | |
} | ||
|
||
if (!FileExists(bpe_vocab)) { | ||
SHERPA_ONNX_LOGE("--bpe-vocab %s does not exist", | ||
SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist", | ||
bpe_vocab.c_str()); | ||
return false; | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
// sherpa-onnx/csrc/online-punctuation-model-config.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_ | ||
#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_ | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
// sherpa-onnx/csrc/online-punctuation.cc | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#include "sherpa-onnx/csrc/online-punctuation.h" | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
// sherpa-onnx/csrc/online-punctuation.h | ||
// | ||
// Copyright (c) 2024 Xiaomi Corporation | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_ | ||
#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_ | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
// sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc | ||
// | ||
// Copyright (c) 2022-2024 Xiaomi Corporation | ||
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems) | ||
|
||
#include <stdio.h> | ||
#include <iostream> | ||
|
||
|
@@ -26,6 +27,7 @@ Please download the model from: | |
"how are you i am fine thank you" | ||
The output text should look like below: | ||
"How are you? I am fine. Thank you." | ||
)usage"; | ||
|
||
sherpa_onnx::ParseOptions po(kUsageMessage); | ||
|
@@ -34,7 +36,7 @@ The output text should look like below: | |
po.Read(argc, argv); | ||
if (po.NumArgs() != 1) { | ||
fprintf(stderr, | ||
"Error: Please provide only 1 position argument containing the " | ||
"Error: Please provide only 1 positional argument containing the " | ||
"input text.\n\n"); | ||
po.PrintUsage(); | ||
exit(EXIT_FAILURE); | ||
|
@@ -55,8 +57,9 @@ The output text should look like below: | |
std::string text = po.GetArg(1); | ||
|
||
std::string text_with_punct_case = punct.AddPunctuationWithCase(text); | ||
fprintf(stderr, "Done\n"); | ||
|
||
const auto end = std::chrono::steady_clock::now(); | ||
fprintf(stderr, "Done\n"); | ||
|
||
float elapsed_seconds = | ||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||
|