diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 2d0118833..ba6f9d162 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -104,6 +104,8 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer( recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); recognizer_config.hotwords_score = SHERPA_ONNX_OR(config->hotwords_score, 1.5); + recognizer_config.tokenize_hotwords = + SHERPA_ONNX_OR(config->tokenize_hotwords, true); recognizer_config.ctc_fst_decoder_config.graph = SHERPA_ONNX_OR(config->ctc_fst_decoder_config.graph, ""); @@ -390,6 +392,8 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer( recognizer_config.hotwords_file = SHERPA_ONNX_OR(config->hotwords_file, ""); recognizer_config.hotwords_score = SHERPA_ONNX_OR(config->hotwords_score, 1.5); + recognizer_config.tokenize_hotwords = + SHERPA_ONNX_OR(config->tokenize_hotwords, true); recognizer_config.rule_fsts = SHERPA_ONNX_OR(config->rule_fsts, ""); recognizer_config.rule_fars = SHERPA_ONNX_OR(config->rule_fars, ""); diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index e9637ae7c..14fa03770 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -143,6 +143,9 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineRecognizerConfig { /// Bonus score for each token in hotwords. float hotwords_score; + /// Whether to tokenize hotwords + bool tokenize_hotwords; + SherpaOnnxOnlineCtcFstDecoderConfig ctc_fst_decoder_config; const char *rule_fsts; const char *rule_fars; @@ -413,6 +416,10 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerConfig { /// Bonus score for each token in hotwords. 
float hotwords_score; + + /// Whether to tokenize hotwords + bool tokenize_hotwords; + const char *rule_fsts; const char *rule_fars; } SherpaOnnxOfflineRecognizerConfig; diff --git a/sherpa-onnx/jni/offline-recognizer.cc b/sherpa-onnx/jni/offline-recognizer.cc index 070d46f08..21a252bfd 100644 --- a/sherpa-onnx/jni/offline-recognizer.cc +++ b/sherpa-onnx/jni/offline-recognizer.cc @@ -34,6 +34,9 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(cls, "hotwordsScore", "F"); ans.hotwords_score = env->GetFloatField(config, fid); + fid = env->GetFieldID(cls, "tokenizeHotwords", "Z"); + ans.tokenize_hotwords = env->GetBooleanField(config, fid); + fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(config, fid); p = env->GetStringUTFChars(s, nullptr); diff --git a/sherpa-onnx/jni/online-recognizer.cc b/sherpa-onnx/jni/online-recognizer.cc index d8acd0fed..03d92e48b 100644 --- a/sherpa-onnx/jni/online-recognizer.cc +++ b/sherpa-onnx/jni/online-recognizer.cc @@ -37,6 +37,9 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(cls, "hotwordsScore", "F"); ans.hotwords_score = env->GetFloatField(config, fid); + fid = env->GetFieldID(cls, "tokenizeHotwords", "Z"); + ans.tokenize_hotwords = env->GetBooleanField(config, fid); + fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(config, fid); p = env->GetStringUTFChars(s, nullptr); diff --git a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt index 7163d3d10..058873259 100644 --- a/sherpa-onnx/kotlin-api/OfflineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OfflineRecognizer.kt @@ -53,6 +53,7 @@ data class OfflineRecognizerConfig( var maxActivePaths: Int = 4, var hotwordsFile: String = "", var hotwordsScore: Float = 1.5f, + var tokenizeHotwords: Boolean = true, var ruleFsts: String = "", var ruleFars: String = "", ) 
diff --git a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt index de47a5ebd..5a6f1b214 100644 --- a/sherpa-onnx/kotlin-api/OnlineRecognizer.kt +++ b/sherpa-onnx/kotlin-api/OnlineRecognizer.kt @@ -69,6 +69,7 @@ data class OnlineRecognizerConfig( var maxActivePaths: Int = 4, var hotwordsFile: String = "", var hotwordsScore: Float = 1.5f, + var tokenizeHotwords: Boolean = true, var ruleFsts: String = "", var ruleFars: String = "", ) diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index f0e9a45f2..1f38e056f 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -49,6 +49,7 @@ def from_transducer( max_active_paths: int = 4, hotwords_file: str = "", hotwords_score: float = 1.5, + tokenize_hotwords: bool = True, blank_penalty: float = 0.0, modeling_unit: str = "cjkchar", bpe_vocab: str = "", @@ -96,6 +97,9 @@ def from_transducer( hotwords_score: The hotword score of each token for biasing word/phrase. Used only if hotwords_file is given with modified_beam_search as decoding method. + tokenize_hotwords: + Whether to tokenize hotwords, true will tokenize hotwords in the engine + if false, you have to tokenize hotwords by yourself. blank_penalty: The penalty applied on blank symbol during decoding. 
modeling_unit: @@ -165,6 +169,7 @@ def from_transducer( max_active_paths=max_active_paths, hotwords_file=hotwords_file, hotwords_score=hotwords_score, + tokenize_hotwords=tokenize_hotwords, blank_penalty=blank_penalty, rule_fsts=rule_fsts, rule_fars=rule_fars, diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 82b2e3b42..f8936dfd4 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -54,6 +54,7 @@ def from_transducer( decoding_method: str = "greedy_search", max_active_paths: int = 4, hotwords_score: float = 1.5, + tokenize_hotwords: bool = True, blank_penalty: float = 0.0, hotwords_file: str = "", provider: str = "cpu", @@ -131,6 +132,9 @@ def from_transducer( hotwords_score: The hotword score of each token for biasing word/phrase. Used only if hotwords_file is given with modified_beam_search as decoding method. + tokenize_hotwords: + Whether to tokenize hotwords, true will tokenize hotwords in the engine + if false, you have to tokenize hotwords by yourself. temperature_scale: Temperature scaling for output symbol confidence estiamation. It affects only confidence values, the decoding uses the original @@ -222,6 +226,7 @@ def from_transducer( decoding_method=decoding_method, max_active_paths=max_active_paths, hotwords_score=hotwords_score, + tokenize_hotwords=tokenize_hotwords, hotwords_file=hotwords_file, blank_penalty=blank_penalty, temperature_scale=temperature_scale,