Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
frankyoujian committed Aug 6, 2024
1 parent af071e0 commit 26bb94f
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 90 deletions.
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_
#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_

Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-cnn-bilstm-model.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sherpa-onnx/csrc/online-cnn-bilstm-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h"

Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/online-cnn-bilstm-model.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// sherpa-onnx/csrc/online-cnn-bilstm-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
#include <memory>
Expand Down
157 changes: 81 additions & 76 deletions sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_

Expand Down Expand Up @@ -29,7 +30,7 @@

namespace sherpa_onnx {

static const int32_t MAX_SEQ_LEN = 200;
static const int32_t kMaxSeqLen = 200;

class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
public:
Expand All @@ -54,7 +55,63 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
}
#endif

void encode_sentences(const std::string& text,
std::string AddPunctuationWithCase(const std::string &text) const override {
if (text.empty()) {
return {};
}

std::vector<int32_t> tokens_list; // N * kMaxSeqLen
std::vector<int32_t> valids_list; // N * kMaxSeqLen
std::vector<int32_t> label_len_list; // N

EncodeSentences(text, tokens_list, valids_list, label_len_list);

const auto &meta_data = model_.GetModelMetadata();

auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

int32_t n = label_len_list.size();

std::array<int64_t, 2> token_ids_shape = {n, kMaxSeqLen};
Ort::Value token_ids = Ort::Value::CreateTensor(memory_info, tokens_list.data(), tokens_list.size(),
token_ids_shape.data(), token_ids_shape.size());

std::array<int64_t, 2> valid_ids_shape = {n, kMaxSeqLen};
Ort::Value valid_ids = Ort::Value::CreateTensor(memory_info, valids_list.data(), valids_list.size(),
valid_ids_shape.data(), valid_ids_shape.size());

std::array<int64_t, 1> label_len_shape = {n};
Ort::Value label_len = Ort::Value::CreateTensor(memory_info, label_len_list.data(), label_len_list.size(),
label_len_shape.data(), label_len_shape.size());

auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), std::move(label_len));

std::vector<int32_t> case_pred;
std::vector<int32_t> punct_pred;
const float* active_case_logits = pair.first.GetTensorData<float>();
const float* active_punct_logits = pair.second.GetTensorData<float>();
std::vector<int64_t> case_logits_shape = pair.first.GetTensorTypeAndShapeInfo().GetShape();

for (int32_t i = 0; i < case_logits_shape[0]; ++i) {
const float* p_cur_case = active_case_logits + i * meta_data.num_cases;
auto index_case = static_cast<int32_t>(std::distance(
p_cur_case, std::max_element(p_cur_case, p_cur_case + meta_data.num_cases)));
case_pred.push_back(index_case);

const float* p_cur_punct = active_punct_logits + i * meta_data.num_punctuations;
auto index_punct = static_cast<int32_t>(std::distance(
p_cur_punct, std::max_element(p_cur_punct, p_cur_punct + meta_data.num_punctuations)));
punct_pred.push_back(index_punct);
}

std::string ans = DecodeSentences(text, case_pred, punct_pred);

return ans;
}

private:
void EncodeSentences(const std::string& text,
std::vector<int32_t>& tokens_list,
std::vector<int32_t>& valids_list,
std::vector<int32_t>& label_len_list) const {
Expand All @@ -71,19 +128,20 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
std::vector<int32_t> word_tokens;
bpe_encoder_->Encode(word, &word_tokens);

if (tokens.size() + word_tokens.size() > MAX_SEQ_LEN - 1) {
int32_t seq_len = tokens.size() + word_tokens.size();
if (seq_len > kMaxSeqLen - 1) {
tokens.push_back(2); // hardcode 2 now, 2 - </s>
valids.push_back(1);

label_len = std::count(valids.begin(), valids.end(), 1);

while (tokens.size() < MAX_SEQ_LEN) {
tokens.push_back(0);
valids.push_back(0);
if (tokens.size() < kMaxSeqLen) {
tokens.resize(kMaxSeqLen, 0);
valids.resize(kMaxSeqLen, 0);
}

assert(tokens.size() == MAX_SEQ_LEN);
assert(valids.size() == MAX_SEQ_LEN);
assert(tokens.size() == kMaxSeqLen);
assert(valids.size() == kMaxSeqLen);

tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end());
valids_list.insert(valids_list.end(), valids.begin(), valids.end());
Expand All @@ -98,10 +156,10 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {

tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end());
valids.push_back(1); // only the first sub word is valid
int32_t remaining_size = word_tokens.size() - 1;
while (remaining_size > 0) {
valids.push_back(0);
remaining_size--;
int32_t remaining_size = static_cast<int32_t>(word_tokens.size()) - 1;
if (remaining_size > 0) {
int32_t valids_cur_size = static_cast<int32_t>(valids.size());
valids.resize(valids_cur_size + remaining_size, 0);
}
}

Expand All @@ -111,21 +169,21 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {

label_len = std::count(valids.begin(), valids.end(), 1);

while (tokens.size() < MAX_SEQ_LEN) {
tokens.push_back(0);
valids.push_back(0);
if (tokens.size() < kMaxSeqLen) {
tokens.resize(kMaxSeqLen, 0);
valids.resize(kMaxSeqLen, 0);
}

assert(tokens.size() == MAX_SEQ_LEN);
assert(valids.size() == MAX_SEQ_LEN);
assert(tokens.size() == kMaxSeqLen);
assert(valids.size() == kMaxSeqLen);

tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end());
valids_list.insert(valids_list.end(), valids.begin(), valids.end());
label_len_list.push_back(label_len);
}
}

std::string decode_sentences(const std::string& raw_text,
std::string DecodeSentences(const std::string& raw_text,
const std::vector<int32_t>& case_pred,
const std::vector<int32_t>& punct_pred) const {
std::string result_text;
Expand All @@ -140,7 +198,7 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
assert(words.size() == case_pred.size());
assert(words.size() == punct_pred.size());

for (int i=0; i<words.size(); i++) {
for (int32_t i = 0; i < words.size(); ++i) {
std::string prefix = ((i != 0) ? " " : "");
result_text += prefix;
switch (case_pred[i]) {
Expand All @@ -158,7 +216,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
}
case 3: // mix case
{
//TODO
// TODO:
// Need to add a map containing supported mix case words so that we can fetch the predicted word from the map
// e.g. mcdonald's -> McDonald's
result_text += words[i];
break;
}
Expand Down Expand Up @@ -196,61 +256,6 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
return result_text;
}

// Adds punctuation and restores casing for `text`.
//
// Encodes the text into fixed-length rows with encode_sentences(), runs the
// model once on all rows, takes the arg-max class of the case logits and of
// the punctuation logits for every output row, and hands those predictions to
// decode_sentences() to build the returned string.
//
// @param text Input text. An empty string yields an empty result.
// @return The string produced by decode_sentences().
std::string AddPunctuationWithCase(const std::string &text) const override {
if (text.empty()) {
return {};
}

// Flattened model inputs, filled in by encode_sentences() below.
std::vector<int32_t> tokens_list; // N * MAX_SEQ_LEN
std::vector<int32_t> valids_list; // N * MAX_SEQ_LEN
std::vector<int32_t> label_len_list; // N

encode_sentences(text, tokens_list, valids_list, label_len_list);

const auto &meta_data = model_.GetModelMetadata();

auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

// N: one label length is stored per encoded row.
int32_t n = label_len_list.size();

// NOTE(review): the tensors below are built from the vectors' data()
// pointers -- presumably without copying, so the vectors must stay alive
// until Forward() has returned; confirm against the ONNX Runtime API.
std::array<int64_t, 2> token_ids_shape = {n, MAX_SEQ_LEN};
Ort::Value token_ids = Ort::Value::CreateTensor(memory_info, tokens_list.data(), tokens_list.size(),
token_ids_shape.data(), token_ids_shape.size());

std::array<int64_t, 2> valid_ids_shape = {n, MAX_SEQ_LEN};
Ort::Value valid_ids = Ort::Value::CreateTensor(memory_info, valids_list.data(), valids_list.size(),
valid_ids_shape.data(), valid_ids_shape.size());

std::array<int64_t, 1> label_len_shape = {n};
Ort::Value label_len = Ort::Value::CreateTensor(memory_info, label_len_list.data(), label_len_list.size(),
label_len_shape.data(), label_len_shape.size());

// pair.first is read as the case logits, pair.second as the punctuation
// logits.
auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), std::move(label_len));

std::vector<int32_t> case_pred;
std::vector<int32_t> punct_pred;
const float* active_case_logits = pair.first.GetTensorData<float>();
const float* active_punct_logits = pair.second.GetTensorData<float>();
std::vector<int64_t> case_logits_shape = pair.first.GetTensorTypeAndShapeInfo().GetShape();

// The row count is taken from the first dimension of the case logits; the
// punctuation logits are indexed with the same row index.
for (int i=0; i<case_logits_shape[0]; i++) {
// Row i of the case logits spans num_cases floats; keep the index of the
// largest one.
const float* p_cur_case = active_case_logits + i * meta_data.num_cases;
auto index_case = static_cast<int32_t>(std::distance(
p_cur_case, std::max_element(p_cur_case, p_cur_case + meta_data.num_cases)));
case_pred.push_back(index_case);

// Same arg-max over the num_punctuations floats of row i.
const float* p_cur_punct = active_punct_logits + i * meta_data.num_punctuations;
auto index_punct = static_cast<int32_t>(std::distance(
p_cur_punct, std::max_element(p_cur_punct, p_cur_punct + meta_data.num_punctuations)));
punct_pred.push_back(index_punct);
}

// Combine the original text with the per-row predictions.
std::string ans = decode_sentences(text, case_pred, punct_pred);

return ans;
}

private:
OnlinePunctuationConfig config_;
OnlineCNNBiLSTMModel model_;
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-punctuation-impl.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sherpa-onnx/csrc/online-punctuation-impl.cc
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#include "sherpa-onnx/csrc/online-punctuation-impl.h"

Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/online-punctuation-impl.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// sherpa-onnx/csrc/online-punctuation-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_

Expand Down
6 changes: 3 additions & 3 deletions sherpa-onnx/csrc/online-punctuation-model-config.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sherpa-onnx/csrc/online-punctuation-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#include "sherpa-onnx/csrc/online-punctuation-model-config.h"

Expand Down Expand Up @@ -33,7 +33,7 @@ bool OnlinePunctuationModelConfig::Validate() const {
}

if (!FileExists(cnn_bilstm)) {
SHERPA_ONNX_LOGE("--cnn-bilstm %s does not exist",
SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist",
cnn_bilstm.c_str());
return false;
}
Expand All @@ -44,7 +44,7 @@ bool OnlinePunctuationModelConfig::Validate() const {
}

if (!FileExists(bpe_vocab)) {
SHERPA_ONNX_LOGE("--bpe-vocab %s does not exist",
SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist",
bpe_vocab.c_str());
return false;
}
Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/online-punctuation-model-config.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// sherpa-onnx/csrc/online-punctuation-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_

Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-punctuation.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// sherpa-onnx/csrc/online-punctuation.cc
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#include "sherpa-onnx/csrc/online-punctuation.h"

Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/online-punctuation.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// sherpa-onnx/csrc/online-punctuation.h
//
// Copyright (c) 2024 Xiaomi Corporation
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_
#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_

Expand Down
9 changes: 6 additions & 3 deletions sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc
//
// Copyright (c) 2022-2024 Xiaomi Corporation
// Copyright (c) 2024 Jian You ([email protected], Cisco Systems)

#include <stdio.h>
#include <iostream>

Expand All @@ -26,6 +27,7 @@ Please download the model from:
"how are you i am fine thank you"
The output text should look like below:
"How are you? I am fine. Thank you."
)usage";

sherpa_onnx::ParseOptions po(kUsageMessage);
Expand All @@ -34,7 +36,7 @@ The output text should look like below:
po.Read(argc, argv);
if (po.NumArgs() != 1) {
fprintf(stderr,
"Error: Please provide only 1 position argument containing the "
"Error: Please provide only 1 positional argument containing the "
"input text.\n\n");
po.PrintUsage();
exit(EXIT_FAILURE);
Expand All @@ -55,8 +57,9 @@ The output text should look like below:
std::string text = po.GetArg(1);

std::string text_with_punct_case = punct.AddPunctuationWithCase(text);
fprintf(stderr, "Done\n");

const auto end = std::chrono::steady_clock::now();
fprintf(stderr, "Done\n");

float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
Expand Down

0 comments on commit 26bb94f

Please sign in to comment.