Skip to content

Commit

Permalink
Add tokenize-hotwords option
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Jun 21, 2024
1 parent 96ab843 commit 12e1d4c
Show file tree
Hide file tree
Showing 11 changed files with 95 additions and 41 deletions.
9 changes: 6 additions & 3 deletions sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
std::vector<std::vector<int32_t>> current;
std::vector<float> current_scores;
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
bpe_encoder_.get(), &current, &current_scores)) {
config_.tokenize_hotwords, bpe_encoder_.get(), &current,
&current_scores)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
Expand Down Expand Up @@ -257,7 +258,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}

if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
config_.tokenize_hotwords, bpe_encoder_.get(),
&hotwords_, &boost_scores_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
Expand All @@ -281,7 +283,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}

if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
config_.tokenize_hotwords, bpe_encoder_.get(),
&hotwords_, &boost_scores_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search");

po->Register(
"tokenize-hotwords", &tokenize_hotwords,
"Whether to tokenize hotwords, default true, if false the input hotwords "
"should be tokenized into tokens");

po->Register(
"rule-fsts", &rule_fsts,
"If not empty, it specifies fsts for inverse text normalization. "
Expand Down Expand Up @@ -125,6 +130,7 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "max_active_paths=" << max_active_paths << ", ";
os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "hotwords_score=" << hotwords_score << ", ";
os << "tokenize_hotwords=" << (tokenize_hotwords ? "True" : "False") << ", ";
os << "blank_penalty=" << blank_penalty << ", ";
os << "rule_fsts=\"" << rule_fsts << "\", ";
os << "rule_fars=\"" << rule_fars << "\")";
Expand Down
13 changes: 9 additions & 4 deletions sherpa-onnx/csrc/offline-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ struct OfflineRecognizerConfig {

std::string hotwords_file;
float hotwords_score = 1.5;
/// Whether to tokenize the input hotwords. Normally this should be true;
/// if false, you have to tokenize the hotwords yourself.
bool tokenize_hotwords = true;

float blank_penalty = 0.0;

Expand All @@ -56,7 +59,7 @@ struct OfflineRecognizerConfig {
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
const std::string &decoding_method, int32_t max_active_paths,
const std::string &hotwords_file, float hotwords_score,
float blank_penalty, const std::string &rule_fsts,
bool tokenize_hotwords, float blank_penalty, const std::string &rule_fsts,
const std::string &rule_fars)
: feat_config(feat_config),
model_config(model_config),
Expand All @@ -66,6 +69,7 @@ struct OfflineRecognizerConfig {
max_active_paths(max_active_paths),
hotwords_file(hotwords_file),
hotwords_score(hotwords_score),
tokenize_hotwords(tokenize_hotwords),
blank_penalty(blank_penalty),
rule_fsts(rule_fsts),
rule_fars(rule_fars) {}
Expand Down Expand Up @@ -94,9 +98,10 @@ class OfflineRecognizer {
/** Create a stream for decoding.
*
* @param The hotwords for this string, it might contain several hotwords,
* the hotwords are separated by "/". In each of the hotwords, there
* are cjkchars or bpes, the bpe/cjkchar are separated by space (" ").
* For example, hotwords I LOVE YOU and HELLO WORLD, looks like:
* the hotwords are separated by "/". For example, I LOVE YOU/HELLO
* WORLD. If tokenize_hotwords is false, the hotwords must already be
* tokenized, so hotwords I LOVE YOU and HELLO WORLD should look
* like:
*
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
*/
Expand Down
9 changes: 6 additions & 3 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
std::vector<std::vector<int32_t>> current;
std::vector<float> current_scores;
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &current, &current_scores)) {
config_.tokenize_hotwords, bpe_encoder_.get(), &current,
&current_scores)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
Expand Down Expand Up @@ -401,7 +402,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}

if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
config_.tokenize_hotwords, bpe_encoder_.get(),
&hotwords_, &boost_scores_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
Expand All @@ -425,7 +427,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}

if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
config_.tokenize_hotwords, bpe_encoder_.get(),
&hotwords_, &boost_scores_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"The file containing hotwords, one words/phrases per line, For example: "
"HELLO WORLD"
"你好世界");
po->Register(
"tokenize-hotwords", &tokenize_hotwords,
"Whether to tokenize hotwords, default true, if false the input hotwords "
"should be tokenized into tokens");
po->Register("decoding-method", &decoding_method,
"decoding method,"
"now support greedy_search and modified_beam_search.");
Expand Down Expand Up @@ -183,6 +187,7 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "max_active_paths=" << max_active_paths << ", ";
os << "hotwords_score=" << hotwords_score << ", ";
os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "tokenize_hotwords=" << (tokenize_hotwords ? "True" : "False") << ", ";
os << "decoding_method=\"" << decoding_method << "\", ";
os << "blank_penalty=" << blank_penalty << ", ";
os << "temperature_scale=" << temperature_scale << ", ";
Expand Down
16 changes: 11 additions & 5 deletions sherpa-onnx/csrc/online-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ struct OnlineRecognizerConfig {
/// used only for modified_beam_search
std::string hotwords_file;
float hotwords_score = 1.5;
/// Whether to tokenize the input hotwords. Normally this should be true;
/// if false, you have to tokenize the hotwords yourself.
bool tokenize_hotwords = true;

float blank_penalty = 0.0;

Expand All @@ -115,8 +118,9 @@ struct OnlineRecognizerConfig {
const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
bool enable_endpoint, const std::string &decoding_method,
int32_t max_active_paths, const std::string &hotwords_file,
float hotwords_score, float blank_penalty, float temperature_scale,
const std::string &rule_fsts, const std::string &rule_fars)
float hotwords_score, bool tokenize_hotwords, float blank_penalty,
float temperature_scale, const std::string &rule_fsts,
const std::string &rule_fars)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
Expand All @@ -127,6 +131,7 @@ struct OnlineRecognizerConfig {
max_active_paths(max_active_paths),
hotwords_file(hotwords_file),
hotwords_score(hotwords_score),
tokenize_hotwords(tokenize_hotwords),
blank_penalty(blank_penalty),
temperature_scale(temperature_scale),
rule_fsts(rule_fsts),
Expand Down Expand Up @@ -156,9 +161,10 @@ class OnlineRecognizer {
/** Create a stream for decoding.
*
* @param The hotwords for this string, it might contain several hotwords,
* the hotwords are separated by "/". In each of the hotwords, there
* are cjkchars or bpes, the bpe/cjkchar are separated by space (" ").
* For example, hotwords I LOVE YOU and HELLO WORLD, looks like:
* the hotwords are separated by "/". For example, I LOVE YOU/HELLO
* WORLD. If tokenize_hotwords is false, the hotwords must already be
* tokenized, so hotwords I LOVE YOU and HELLO WORLD should look
* like:
*
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
*/
Expand Down
25 changes: 18 additions & 7 deletions sherpa-onnx/csrc/text2token-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,23 @@ TEST(TEXT2TOKEN, TEST_cjkchar) {
std::vector<std::vector<int32_t>> ids;
std::vector<float> scores;

auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids, &scores);
auto r =
EncodeHotwords(iss, "cjkchar", sym_table, true, nullptr, &ids, &scores);

std::vector<std::vector<int32_t>> expected_ids(
{{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}});
EXPECT_EQ(ids, expected_ids);
EXPECT_EQ(scores.size(), 0);

// tokenize_hotwords = false
text = "世 界 人 民 大 团 结\n中 国 V S 美 国\n\n"; // Test blank lines also

iss.clear();
iss.str(text);

r = EncodeHotwords(iss, "cjkchar", sym_table, false, nullptr, &ids, &scores);

EXPECT_EQ(ids, expected_ids);
EXPECT_EQ(scores.size(), 0);
}

Expand Down Expand Up @@ -79,8 +90,8 @@ TEST(TEXT2TOKEN, TEST_bpe) {
std::vector<std::vector<int32_t>> ids;
std::vector<float> scores;

auto r =
EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores);
auto r = EncodeHotwords(iss, "bpe", sym_table, true, bpe_processor.get(),
&ids, &scores);

std::vector<std::vector<int32_t>> expected_ids(
{{22, 58, 24, 425}, {19, 370, 47}});
Expand Down Expand Up @@ -117,8 +128,8 @@ TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
std::vector<std::vector<int32_t>> ids;
std::vector<float> scores;

auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(),
&ids, &scores);
auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, true,
bpe_processor.get(), &ids, &scores);

std::vector<std::vector<int32_t>> expected_ids(
{{1368, 1392, 557, 680, 275, 178, 475},
Expand Down Expand Up @@ -156,8 +167,8 @@ TEST(TEXT2TOKEN, TEST_bbpe) {
std::vector<std::vector<int32_t>> ids;
std::vector<float> scores;

auto r =
EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores);
auto r = EncodeHotwords(iss, "bpe", sym_table, true, bpe_processor.get(),
&ids, &scores);

std::vector<std::vector<int32_t>> expected_ids(
{{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}});
Expand Down
13 changes: 12 additions & 1 deletion sherpa-onnx/csrc/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ static bool EncodeBase(const std::vector<std::string> &lines,
}
}
}
if (tmp_ids.empty()) {
continue;
}
ids->push_back(std::move(tmp_ids));
tmp_ids = {};
tmp_scores.push_back(score);
Expand Down Expand Up @@ -101,14 +104,22 @@ static bool EncodeBase(const std::vector<std::string> &lines,
}

bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
const SymbolTable &symbol_table,
const SymbolTable &symbol_table, bool tokenize_hotwords,
const ssentencepiece::Ssentencepiece *bpe_encoder,
std::vector<std::vector<int32_t>> *hotwords,
std::vector<float> *boost_scores) {
std::vector<std::string> lines;
std::string line;
std::string word;

if (!tokenize_hotwords) {
while (std::getline(is, line)) {
lines.push_back(line);
}
return EncodeBase(lines, symbol_table, hotwords, nullptr, boost_scores,
nullptr);
}

while (std::getline(is, line)) {
std::string score;
std::string phrase;
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace sherpa_onnx {
* otherwise returns false.
*/
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
const SymbolTable &symbol_table,
const SymbolTable &symbol_table, bool tokenize_hotwords,
const ssentencepiece::Ssentencepiece *bpe_encoder_,
std::vector<std::vector<int32_t>> *hotwords_id,
std::vector<float> *boost_scores);
Expand Down
8 changes: 5 additions & 3 deletions sherpa-onnx/python/csrc/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
const std::string &, int32_t, const std::string &, float,
float, const std::string &, const std::string &>(),
bool, float, const std::string &, const std::string &>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OfflineLMConfig(),
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
py::arg("decoding_method") = "greedy_search",
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0,
py::arg("rule_fsts") = "", py::arg("rule_fars") = "")
py::arg("hotwords_score") = 1.5, py::arg("tokenize_hotwords") = true,
py::arg("blank_penalty") = 0.0, py::arg("rule_fsts") = "",
py::arg("rule_fars") = "")
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
Expand All @@ -33,6 +34,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
.def_readwrite("tokenize_hotwords", &PyClass::tokenize_hotwords)
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def_readwrite("rule_fars", &PyClass::rule_fars)
Expand Down
30 changes: 16 additions & 14 deletions sherpa-onnx/python/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,21 @@ static void PybindOnlineRecognizerResult(py::module *m) {
static void PybindOnlineRecognizerConfig(py::module *m) {
using PyClass = OnlineRecognizerConfig;
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
const OnlineLMConfig &, const EndpointConfig &,
const OnlineCtcFstDecoderConfig &, bool,
const std::string &, int32_t, const std::string &, float,
float, float, const std::string &, const std::string &>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(),
py::arg("endpoint_config") = EndpointConfig(),
py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0,
py::arg("temperature_scale") = 2.0, py::arg("rule_fsts") = "",
py::arg("rule_fars") = "")
.def(
py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
const OnlineLMConfig &, const EndpointConfig &,
const OnlineCtcFstDecoderConfig &, bool, const std::string &,
int32_t, const std::string &, float, bool, float, float,
const std::string &, const std::string &>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OnlineLMConfig(),
py::arg("endpoint_config") = EndpointConfig(),
py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
py::arg("enable_endpoint"), py::arg("decoding_method"),
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 0, py::arg("tokenize_hotwords") = true,
py::arg("blank_penalty") = 0.0, py::arg("temperature_scale") = 2.0,
py::arg("rule_fsts") = "", py::arg("rule_fars") = "")
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
Expand All @@ -78,6 +79,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
.def_readwrite("tokenize_hotwords", &PyClass::tokenize_hotwords)
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
.def_readwrite("temperature_scale", &PyClass::temperature_scale)
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
Expand Down

0 comments on commit 12e1d4c

Please sign in to comment.