Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add tokenize-hotwords option #1039

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, "");
recognizer_config.hotwords_score =
SHERPA_ONNX_OR(config->hotwords_score, 1.5);

// NOTE(review): SHERPA_ONNX_OR appears to substitute the default whenever the
// first argument is falsy (see hotwords_file / hotwords_score above). If so,
// an explicit `false` from the caller is replaced by `true` here, and
// tokenization can never be disabled through the C API -- TODO confirm
// against the macro definition.
recognizer_config.tokenize_hotwords =
SHERPA_ONNX_OR(config->tokenize_hotwords, true);

if (config->hotwords_buf && config->hotwords_buf_size > 0) {
recognizer_config.hotwords_buf =
std::string(config->hotwords_buf, config->hotwords_buf_size);
Expand Down Expand Up @@ -467,6 +471,8 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig(
recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, "");
recognizer_config.hotwords_score =
SHERPA_ONNX_OR(config->hotwords_score, 1.5);
// NOTE(review): SHERPA_ONNX_OR appears to substitute the default whenever the
// first argument is falsy (see hotwords_file / hotwords_score above). If so,
// an explicit `false` from the caller is replaced by `true` here, and
// tokenization can never be disabled through the C API -- TODO confirm
// against the macro definition.
recognizer_config.tokenize_hotwords =
SHERPA_ONNX_OR(config->tokenize_hotwords, true);

recognizer_config.blank_penalty = config->blank_penalty;

Expand Down
7 changes: 7 additions & 0 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig {
/// Bonus score for each token in hotwords.
float hotwords_score;

/// Whether to tokenize hotwords
bool tokenize_hotwords;

SherpaOnnxOnlineCtcFstDecoderConfig ctc_fst_decoder_config;
const char *rule_fsts;
const char *rule_fars;
Expand Down Expand Up @@ -438,6 +441,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig {

/// Bonus score for each token in hotwords.
float hotwords_score;

/// Whether to tokenize hotwords
bool tokenize_hotwords;

const char *rule_fsts;
const char *rule_fars;
float blank_penalty;
Expand Down
9 changes: 6 additions & 3 deletions sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
std::vector<std::vector<int32_t>> current;
std::vector<float> current_scores;
if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
bpe_encoder_.get(), &current, &current_scores)) {
config_.tokenize_hotwords, bpe_encoder_.get(), &current,
&current_scores)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
Expand Down Expand Up @@ -262,7 +263,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}

if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
config_.tokenize_hotwords, bpe_encoder_.get(),
&hotwords_, &boost_scores_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
Expand All @@ -286,7 +288,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
}

if (!EncodeHotwords(is, config_.model_config.modeling_unit, symbol_table_,
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
config_.tokenize_hotwords, bpe_encoder_.get(),
&hotwords_, &boost_scores_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
"The bonus score for each token in context word/phrase. "
"Used only when decoding_method is modified_beam_search");

po->Register(
"tokenize-hotwords", &tokenize_hotwords,
"Whether to tokenize hotwords, default true, if false the input hotwords "
"should be tokenized into tokens");

po->Register(
"rule-fsts", &rule_fsts,
"If not empty, it specifies fsts for inverse text normalization. "
Expand Down Expand Up @@ -125,6 +130,7 @@ std::string OfflineRecognizerConfig::ToString() const {
os << "max_active_paths=" << max_active_paths << ", ";
os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "hotwords_score=" << hotwords_score << ", ";
os << "tokenize_hotwords=" << (tokenize_hotwords ? "True" : "False") << ", ";
os << "blank_penalty=" << blank_penalty << ", ";
os << "rule_fsts=\"" << rule_fsts << "\", ";
os << "rule_fars=\"" << rule_fars << "\")";
Expand Down
13 changes: 9 additions & 4 deletions sherpa-onnx/csrc/offline-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ struct OfflineRecognizerConfig {

std::string hotwords_file;
float hotwords_score = 1.5;
/// Whether to tokenize the input hotwords. Normally this should be true;
/// if false, you have to tokenize the hotwords yourself.
bool tokenize_hotwords = true;

float blank_penalty = 0.0;

Expand All @@ -56,7 +59,7 @@ struct OfflineRecognizerConfig {
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
const std::string &decoding_method, int32_t max_active_paths,
const std::string &hotwords_file, float hotwords_score,
float blank_penalty, const std::string &rule_fsts,
bool tokenize_hotwords, float blank_penalty, const std::string &rule_fsts,
const std::string &rule_fars)
: feat_config(feat_config),
model_config(model_config),
Expand All @@ -66,6 +69,7 @@ struct OfflineRecognizerConfig {
max_active_paths(max_active_paths),
hotwords_file(hotwords_file),
hotwords_score(hotwords_score),
tokenize_hotwords(tokenize_hotwords),
blank_penalty(blank_penalty),
rule_fsts(rule_fsts),
rule_fars(rule_fars) {}
Expand Down Expand Up @@ -94,9 +98,10 @@ class OfflineRecognizer {
/** Create a stream for decoding.
*
* @param The hotwords for this string, it might contain several hotwords,
* the hotwords are separated by "/". In each of the hotwords, there
* are cjkchars or bpes, the bpe/cjkchar are separated by space (" ").
* For example, hotwords I LOVE YOU and HELLO WORLD, looks like:
* the hotwords are separated by "/". For example, I LOVE YOU/HELLO
* WORLD. If tokenize_hotwords is false, the hotwords should already be
* tokenized, so the hotwords I LOVE YOU and HELLO WORLD should look
* like:
*
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
*/
Expand Down
9 changes: 6 additions & 3 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
std::vector<std::vector<int32_t>> current;
std::vector<float> current_scores;
if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &current, &current_scores)) {
config_.tokenize_hotwords, bpe_encoder_.get(), &current,
&current_scores)) {
SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
hotwords.c_str());
}
Expand Down Expand Up @@ -420,7 +421,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}

if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
config_.tokenize_hotwords, bpe_encoder_.get(),
&hotwords_, &boost_scores_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
Expand All @@ -444,7 +446,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}

if (!EncodeHotwords(is, config_.model_config.modeling_unit, sym_,
bpe_encoder_.get(), &hotwords_, &boost_scores_)) {
config_.tokenize_hotwords, bpe_encoder_.get(),
&hotwords_, &boost_scores_)) {
SHERPA_ONNX_LOGE(
"Failed to encode some hotwords, skip them already, see logs above "
"for details.");
Expand Down
5 changes: 5 additions & 0 deletions sherpa-onnx/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"The file containing hotwords, one words/phrases per line, For example: "
"HELLO WORLD"
"你好世界");
po->Register(
"tokenize-hotwords", &tokenize_hotwords,
"Whether to tokenize hotwords, default true, if false the input hotwords "
"should be tokenized into tokens");
po->Register("decoding-method", &decoding_method,
"decoding method,"
"now support greedy_search and modified_beam_search.");
Expand Down Expand Up @@ -181,6 +185,7 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "max_active_paths=" << max_active_paths << ", ";
os << "hotwords_score=" << hotwords_score << ", ";
os << "hotwords_file=\"" << hotwords_file << "\", ";
os << "tokenize_hotwords=" << (tokenize_hotwords ? "True" : "False") << ", ";
os << "decoding_method=\"" << decoding_method << "\", ";
os << "blank_penalty=" << blank_penalty << ", ";
os << "temperature_scale=" << temperature_scale << ", ";
Expand Down
16 changes: 11 additions & 5 deletions sherpa-onnx/csrc/online-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ struct OnlineRecognizerConfig {
/// used only for modified_beam_search
std::string hotwords_file;
float hotwords_score = 1.5;
/// Whether to tokenize the input hotwords. Normally this should be true;
/// if false, you have to tokenize the hotwords yourself.
bool tokenize_hotwords = true;

float blank_penalty = 0.0;

Expand All @@ -120,8 +123,9 @@ struct OnlineRecognizerConfig {
const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
bool enable_endpoint, const std::string &decoding_method,
int32_t max_active_paths, const std::string &hotwords_file,
float hotwords_score, float blank_penalty, float temperature_scale,
const std::string &rule_fsts, const std::string &rule_fars)
float hotwords_score, bool tokenize_hotwords, float blank_penalty,
float temperature_scale, const std::string &rule_fsts,
const std::string &rule_fars)
: feat_config(feat_config),
model_config(model_config),
lm_config(lm_config),
Expand All @@ -132,6 +136,7 @@ struct OnlineRecognizerConfig {
max_active_paths(max_active_paths),
hotwords_file(hotwords_file),
hotwords_score(hotwords_score),
tokenize_hotwords(tokenize_hotwords),
blank_penalty(blank_penalty),
temperature_scale(temperature_scale),
rule_fsts(rule_fsts),
Expand Down Expand Up @@ -161,9 +166,10 @@ class OnlineRecognizer {
/** Create a stream for decoding.
*
* @param The hotwords for this string, it might contain several hotwords,
* the hotwords are separated by "/". In each of the hotwords, there
* are cjkchars or bpes, the bpe/cjkchar are separated by space (" ").
* For example, hotwords I LOVE YOU and HELLO WORLD, looks like:
* the hotwords are separated by "/". For example, I LOVE YOU/HELLO
* WORLD. If tokenize_hotwords is false, the hotwords should already be
* tokenized, so the hotwords I LOVE YOU and HELLO WORLD should look
* like:
*
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
*/
Expand Down
25 changes: 18 additions & 7 deletions sherpa-onnx/csrc/text2token-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,23 @@ TEST(TEXT2TOKEN, TEST_cjkchar) {
std::vector<std::vector<int32_t>> ids;
std::vector<float> scores;

auto r = EncodeHotwords(iss, "cjkchar", sym_table, nullptr, &ids, &scores);
auto r =
EncodeHotwords(iss, "cjkchar", sym_table, true, nullptr, &ids, &scores);

std::vector<std::vector<int32_t>> expected_ids(
{{379, 380, 72, 874, 93, 1251, 489}, {262, 147, 3423, 2476, 21, 147}});
EXPECT_EQ(ids, expected_ids);
EXPECT_EQ(scores.size(), 0);

// tokenize_hotwords = false
text = "世 界 人 民 大 团 结\n中 国 V S 美 国\n\n"; // Test blank lines also

iss.clear();
iss.str(text);

r = EncodeHotwords(iss, "cjkchar", sym_table, false, nullptr, &ids, &scores);

EXPECT_EQ(ids, expected_ids);
EXPECT_EQ(scores.size(), 0);
}

Expand Down Expand Up @@ -79,8 +90,8 @@ TEST(TEXT2TOKEN, TEST_bpe) {
std::vector<std::vector<int32_t>> ids;
std::vector<float> scores;

auto r =
EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores);
auto r = EncodeHotwords(iss, "bpe", sym_table, true, bpe_processor.get(),
&ids, &scores);

std::vector<std::vector<int32_t>> expected_ids(
{{22, 58, 24, 425}, {19, 370, 47}});
Expand Down Expand Up @@ -117,8 +128,8 @@ TEST(TEXT2TOKEN, TEST_cjkchar_bpe) {
std::vector<std::vector<int32_t>> ids;
std::vector<float> scores;

auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, bpe_processor.get(),
&ids, &scores);
auto r = EncodeHotwords(iss, "cjkchar+bpe", sym_table, true,
bpe_processor.get(), &ids, &scores);

std::vector<std::vector<int32_t>> expected_ids(
{{1368, 1392, 557, 680, 275, 178, 475},
Expand Down Expand Up @@ -156,8 +167,8 @@ TEST(TEXT2TOKEN, TEST_bbpe) {
std::vector<std::vector<int32_t>> ids;
std::vector<float> scores;

auto r =
EncodeHotwords(iss, "bpe", sym_table, bpe_processor.get(), &ids, &scores);
auto r = EncodeHotwords(iss, "bpe", sym_table, true, bpe_processor.get(),
&ids, &scores);

std::vector<std::vector<int32_t>> expected_ids(
{{259, 1118, 234, 188, 132}, {259, 1585, 236, 161, 148, 236, 160, 191}});
Expand Down
13 changes: 12 additions & 1 deletion sherpa-onnx/csrc/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ static bool EncodeBase(const std::vector<std::string> &lines,
}
}
}
if (tmp_ids.empty()) {
continue;
}
ids->push_back(std::move(tmp_ids));
tmp_ids = {};
tmp_scores.push_back(score);
Expand Down Expand Up @@ -101,14 +104,22 @@ static bool EncodeBase(const std::vector<std::string> &lines,
}

bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
const SymbolTable &symbol_table,
const SymbolTable &symbol_table, bool tokenize_hotwords,
const ssentencepiece::Ssentencepiece *bpe_encoder,
std::vector<std::vector<int32_t>> *hotwords,
std::vector<float> *boost_scores) {
std::vector<std::string> lines;
std::string line;
std::string word;

if (!tokenize_hotwords) {
while (std::getline(is, line)) {
lines.push_back(line);
}
return EncodeBase(lines, symbol_table, hotwords, nullptr, boost_scores,
nullptr);
}

while (std::getline(is, line)) {
std::string score;
std::string phrase;
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace sherpa_onnx {
* otherwise returns false.
*/
bool EncodeHotwords(std::istream &is, const std::string &modeling_unit,
const SymbolTable &symbol_table,
const SymbolTable &symbol_table, bool tokenize_hotwords,
const ssentencepiece::Ssentencepiece *bpe_encoder_,
std::vector<std::vector<int32_t>> *hotwords_id,
std::vector<float> *boost_scores);
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/jni/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(cls, "hotwordsScore", "F");
ans.hotwords_score = env->GetFloatField(config, fid);

// tokenizeHotwords is looked up with the JNI signature "Z" (boolean), so it
// must be read with GetBooleanField. The Get<Type>Field accessor has to match
// the field's declared type; reading a boolean field via GetFloatField is
// invalid and yields a garbage value (or aborts under CheckJNI).
fid = env->GetFieldID(cls, "tokenizeHotwords", "Z");
ans.tokenize_hotwords = env->GetBooleanField(config, fid);

fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
Expand Down
3 changes: 3 additions & 0 deletions sherpa-onnx/jni/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(cls, "hotwordsScore", "F");
ans.hotwords_score = env->GetFloatField(config, fid);

// tokenizeHotwords is looked up with the JNI signature "Z" (boolean), so it
// must be read with GetBooleanField. The Get<Type>Field accessor has to match
// the field's declared type; reading a boolean field via GetFloatField is
// invalid and yields a garbage value (or aborts under CheckJNI).
fid = env->GetFieldID(cls, "tokenizeHotwords", "Z");
ans.tokenize_hotwords = env->GetBooleanField(config, fid);

fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/kotlin-api/OfflineRecognizer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ data class OfflineRecognizerConfig(
var maxActivePaths: Int = 4,
var hotwordsFile: String = "",
var hotwordsScore: Float = 1.5f,
var tokenizeHotwords: Boolean = true,
var ruleFsts: String = "",
var ruleFars: String = "",
var blankPenalty: Float = 0.0f,
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/kotlin-api/OnlineRecognizer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ data class OnlineRecognizerConfig(
var maxActivePaths: Int = 4,
var hotwordsFile: String = "",
var hotwordsScore: Float = 1.5f,
var tokenizeHotwords: Boolean = true,
var ruleFsts: String = "",
var ruleFars: String = "",
var blankPenalty: Float = 0.0f,
Expand Down
8 changes: 5 additions & 3 deletions sherpa-onnx/python/csrc/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.def(py::init<const FeatureExtractorConfig &, const OfflineModelConfig &,
const OfflineLMConfig &, const OfflineCtcFstDecoderConfig &,
const std::string &, int32_t, const std::string &, float,
float, const std::string &, const std::string &>(),
bool, float, const std::string &, const std::string &>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("lm_config") = OfflineLMConfig(),
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
py::arg("decoding_method") = "greedy_search",
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0,
py::arg("rule_fsts") = "", py::arg("rule_fars") = "")
py::arg("hotwords_score") = 1.5, py::arg("tokenize_hotwords") = true,
py::arg("blank_penalty") = 0.0, py::arg("rule_fsts") = "",
py::arg("rule_fars") = "")
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
Expand All @@ -33,6 +34,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
.def_readwrite("hotwords_score", &PyClass::hotwords_score)
.def_readwrite("tokenize_hotwords", &PyClass::tokenize_hotwords)
.def_readwrite("blank_penalty", &PyClass::blank_penalty)
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def_readwrite("rule_fars", &PyClass::rule_fars)
Expand Down
Loading
Loading