From 960eb7529e9212f65a482ba248fa7c2a16926aeb Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 16 Jul 2024 15:55:02 +0800 Subject: [PATCH] Add C++ runtime for MeloTTS (#1138) --- .../workflows/export-melo-tts-to-onnx.yaml | 10 + .github/workflows/windows-x64-jni.yaml | 10 +- CHANGELOG.md | 4 + CMakeLists.txt | 2 +- .../non-streaming-asr/pubspec.yaml | 2 +- dart-api-examples/streaming-asr/pubspec.yaml | 2 +- dart-api-examples/tts/pubspec.yaml | 2 +- dart-api-examples/vad/pubspec.yaml | 2 +- flutter-examples/streaming_asr/pubspec.yaml | 4 +- flutter-examples/tts/pubspec.yaml | 4 +- flutter/sherpa_onnx/pubspec.yaml | 12 +- .../ios/sherpa_onnx_ios.podspec | 2 +- .../macos/sherpa_onnx_macos.podspec | 2 +- nodejs-addon-examples/package.json | 2 +- scripts/apk/build-apk-tts-engine.sh.in | 4 + scripts/apk/build-apk-tts.sh.in | 4 + scripts/apk/generate-tts-apk-script.py | 18 +- scripts/dart/sherpa-onnx-pubspec.yaml | 2 +- scripts/flutter/generate-tts.py | 35 +-- scripts/melo-tts/README.md | 6 + scripts/melo-tts/export-onnx.py | 4 +- scripts/melo-tts/run.sh | 2 +- scripts/melo-tts/test.py | 19 +- sherpa-onnx/c-api/c-api.cc | 8 +- sherpa-onnx/c-api/c-api.h | 5 +- sherpa-onnx/csrc/CMakeLists.txt | 2 + sherpa-onnx/csrc/cppjieba-test.cc | 2 +- sherpa-onnx/csrc/jieba-lexicon.cc | 8 +- sherpa-onnx/csrc/jieba-lexicon.h | 13 +- sherpa-onnx/csrc/lexicon.cc | 10 +- sherpa-onnx/csrc/lexicon.h | 6 +- sherpa-onnx/csrc/melo-tts-lexicon.cc | 266 ++++++++++++++++++ sherpa-onnx/csrc/melo-tts-lexicon.h | 36 +++ .../csrc/offline-tts-character-frontend.cc | 5 +- .../csrc/offline-tts-character-frontend.h | 2 +- sherpa-onnx/csrc/offline-tts-frontend.cc | 34 +++ sherpa-onnx/csrc/offline-tts-frontend.h | 22 +- sherpa-onnx/csrc/offline-tts-vits-impl.h | 89 +++++- .../csrc/offline-tts-vits-model-metadata.h | 5 + sherpa-onnx/csrc/offline-tts-vits-model.cc | 84 ++++++ sherpa-onnx/csrc/offline-tts-vits-model.h | 4 + sherpa-onnx/csrc/offline-whisper-decoder.h | 3 +- 
.../offline-whisper-greedy-search-decoder.cc | 7 +- sherpa-onnx/csrc/onnx-utils.cc | 12 +- sherpa-onnx/csrc/onnx-utils.h | 1 + sherpa-onnx/csrc/piper-phonemize-lexicon.cc | 4 +- sherpa-onnx/csrc/piper-phonemize-lexicon.h | 2 +- sherpa-onnx/csrc/session.cc | 58 ++-- sherpa-onnx/csrc/session.h | 4 +- sherpa-onnx/csrc/speaker-embedding-manager.cc | 1 + sherpa-onnx/csrc/utfcpp-test.cc | 2 +- 51 files changed, 693 insertions(+), 156 deletions(-) create mode 100644 scripts/melo-tts/README.md create mode 100644 sherpa-onnx/csrc/melo-tts-lexicon.cc create mode 100644 sherpa-onnx/csrc/melo-tts-lexicon.h create mode 100644 sherpa-onnx/csrc/offline-tts-frontend.cc diff --git a/.github/workflows/export-melo-tts-to-onnx.yaml b/.github/workflows/export-melo-tts-to-onnx.yaml index 654949c79..4e561688f 100644 --- a/.github/workflows/export-melo-tts-to-onnx.yaml +++ b/.github/workflows/export-melo-tts-to-onnx.yaml @@ -63,10 +63,16 @@ jobs: echo "pwd: $PWD" ls -lh ../scripts/melo-tts + rm -rf ./* + cp -v ../scripts/melo-tts/*.onnx . cp -v ../scripts/melo-tts/lexicon.txt . cp -v ../scripts/melo-tts/tokens.txt . + cp -v ../scripts/melo-tts/README.md . + + curl -SL -O https://raw.githubusercontent.com/myshell-ai/MeloTTS/main/LICENSE + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/new_heteronym.fst curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst @@ -77,6 +83,10 @@ jobs: git lfs track "*.onnx" git add . 
+ ls -lh + + git status + git commit -m "add models" git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-zh_en main || true diff --git a/.github/workflows/windows-x64-jni.yaml b/.github/workflows/windows-x64-jni.yaml index 481edbb58..a6ab2a0f6 100644 --- a/.github/workflows/windows-x64-jni.yaml +++ b/.github/workflows/windows-x64-jni.yaml @@ -39,10 +39,14 @@ jobs: cd build cmake \ -A x64 \ - -D CMAKE_BUILD_TYPE=Release \ - -D BUILD_SHARED_LIBS=ON \ + -DBUILD_SHARED_LIBS=ON \ -D SHERPA_ONNX_ENABLE_JNI=ON \ - -D CMAKE_INSTALL_PREFIX=./install \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DCMAKE_BUILD_TYPE=Release \ + -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \ + -DBUILD_ESPEAK_NG_EXE=OFF \ + -DSHERPA_ONNX_BUILD_C_API_EXAMPLES=OFF \ + -DSHERPA_ONNX_ENABLE_BINARY=ON \ .. - name: Build sherpa-onnx for windows diff --git a/CHANGELOG.md b/CHANGELOG.md index 6beff725b..b685f7b66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 1.10.16 + +* Support zh-en TTS model from MeloTTS. + ## 1.10.15 * Downgrade onnxruntime from v1.18.1 to v1.17.1 diff --git a/CMakeLists.txt b/CMakeLists.txt index 231b62b1e..2e2d1b2b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ project(sherpa-onnx) # ./nodejs-addon-examples # ./dart-api-examples/ # ./CHANGELOG.md -set(SHERPA_ONNX_VERSION "1.10.15") +set(SHERPA_ONNX_VERSION "1.10.16") # Disable warning about # diff --git a/dart-api-examples/non-streaming-asr/pubspec.yaml b/dart-api-examples/non-streaming-asr/pubspec.yaml index 277c3445e..b916bee4c 100644 --- a/dart-api-examples/non-streaming-asr/pubspec.yaml +++ b/dart-api-examples/non-streaming-asr/pubspec.yaml @@ -10,7 +10,7 @@ environment: # Add regular dependencies here. 
dependencies: - sherpa_onnx: ^1.10.15 + sherpa_onnx: ^1.10.16 path: ^1.9.0 args: ^2.5.0 diff --git a/dart-api-examples/streaming-asr/pubspec.yaml b/dart-api-examples/streaming-asr/pubspec.yaml index eca97389f..6722c1804 100644 --- a/dart-api-examples/streaming-asr/pubspec.yaml +++ b/dart-api-examples/streaming-asr/pubspec.yaml @@ -11,7 +11,7 @@ environment: # Add regular dependencies here. dependencies: - sherpa_onnx: ^1.10.15 + sherpa_onnx: ^1.10.16 path: ^1.9.0 args: ^2.5.0 diff --git a/dart-api-examples/tts/pubspec.yaml b/dart-api-examples/tts/pubspec.yaml index b7bb3c285..3383c983a 100644 --- a/dart-api-examples/tts/pubspec.yaml +++ b/dart-api-examples/tts/pubspec.yaml @@ -8,7 +8,7 @@ environment: # Add regular dependencies here. dependencies: - sherpa_onnx: ^1.10.15 + sherpa_onnx: ^1.10.16 path: ^1.9.0 args: ^2.5.0 diff --git a/dart-api-examples/vad/pubspec.yaml b/dart-api-examples/vad/pubspec.yaml index 1249bf3f2..81c2a8588 100644 --- a/dart-api-examples/vad/pubspec.yaml +++ b/dart-api-examples/vad/pubspec.yaml @@ -9,7 +9,7 @@ environment: sdk: ^3.4.0 dependencies: - sherpa_onnx: ^1.10.15 + sherpa_onnx: ^1.10.16 path: ^1.9.0 args: ^2.5.0 diff --git a/flutter-examples/streaming_asr/pubspec.yaml b/flutter-examples/streaming_asr/pubspec.yaml index 43593e97d..20188a053 100644 --- a/flutter-examples/streaming_asr/pubspec.yaml +++ b/flutter-examples/streaming_asr/pubspec.yaml @@ -5,7 +5,7 @@ description: > publish_to: 'none' -version: 1.10.14 +version: 1.10.16 topics: - speech-recognition @@ -30,7 +30,7 @@ dependencies: record: ^5.1.0 url_launcher: ^6.2.6 - sherpa_onnx: ^1.10.15 + sherpa_onnx: ^1.10.16 # sherpa_onnx: # path: ../../flutter/sherpa_onnx diff --git a/flutter-examples/tts/pubspec.yaml b/flutter-examples/tts/pubspec.yaml index d19918ae5..776c96063 100644 --- a/flutter-examples/tts/pubspec.yaml +++ b/flutter-examples/tts/pubspec.yaml @@ -5,7 +5,7 @@ description: > publish_to: 'none' # Remove this line if you wish to publish to pub.dev -version: 1.0.0 
+version: 1.10.16 environment: sdk: '>=3.4.0 <4.0.0' @@ -17,7 +17,7 @@ dependencies: cupertino_icons: ^1.0.6 path_provider: ^2.1.3 path: ^1.9.0 - sherpa_onnx: ^1.10.15 + sherpa_onnx: ^1.10.16 url_launcher: ^6.2.6 audioplayers: ^5.0.0 diff --git a/flutter/sherpa_onnx/pubspec.yaml b/flutter/sherpa_onnx/pubspec.yaml index 543635fba..d4833805e 100644 --- a/flutter/sherpa_onnx/pubspec.yaml +++ b/flutter/sherpa_onnx/pubspec.yaml @@ -17,7 +17,7 @@ topics: - voice-activity-detection # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec -version: 1.10.15 +version: 1.10.16 homepage: https://github.com/k2-fsa/sherpa-onnx @@ -30,19 +30,19 @@ dependencies: flutter: sdk: flutter - sherpa_onnx_android: ^1.10.15 + sherpa_onnx_android: ^1.10.16 # path: ../sherpa_onnx_android - sherpa_onnx_macos: ^1.10.15 + sherpa_onnx_macos: ^1.10.16 # path: ../sherpa_onnx_macos - sherpa_onnx_linux: ^1.10.15 + sherpa_onnx_linux: ^1.10.16 # path: ../sherpa_onnx_linux # - sherpa_onnx_windows: ^1.10.15 + sherpa_onnx_windows: ^1.10.16 # path: ../sherpa_onnx_windows - sherpa_onnx_ios: ^1.10.15 + sherpa_onnx_ios: ^1.10.16 # sherpa_onnx_ios: # path: ../sherpa_onnx_ios diff --git a/flutter/sherpa_onnx_ios/ios/sherpa_onnx_ios.podspec b/flutter/sherpa_onnx_ios/ios/sherpa_onnx_ios.podspec index 4ef367f57..df087dcec 100644 --- a/flutter/sherpa_onnx_ios/ios/sherpa_onnx_ios.podspec +++ b/flutter/sherpa_onnx_ios/ios/sherpa_onnx_ios.podspec @@ -7,7 +7,7 @@ # https://groups.google.com/g/dart-ffi/c/nUATMBy7r0c Pod::Spec.new do |s| s.name = 'sherpa_onnx_ios' - s.version = '1.10.15' + s.version = '1.10.16' s.summary = 'A new Flutter FFI plugin project.' s.description = <<-DESC A new Flutter FFI plugin project. 
diff --git a/flutter/sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec b/flutter/sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec index 70069a655..0b7e60c3a 100644 --- a/flutter/sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec +++ b/flutter/sherpa_onnx_macos/macos/sherpa_onnx_macos.podspec @@ -4,7 +4,7 @@ # Pod::Spec.new do |s| s.name = 'sherpa_onnx_macos' - s.version = '1.10.15' + s.version = '1.10.16' s.summary = 'sherpa-onnx Flutter FFI plugin project.' s.description = <<-DESC sherpa-onnx Flutter FFI plugin project. diff --git a/nodejs-addon-examples/package.json b/nodejs-addon-examples/package.json index 9c449c083..dd7cfed09 100644 --- a/nodejs-addon-examples/package.json +++ b/nodejs-addon-examples/package.json @@ -1,5 +1,5 @@ { "dependencies": { - "sherpa-onnx-node": "^1.10.15" + "sherpa-onnx-node": "^1.10.16" } } diff --git a/scripts/apk/build-apk-tts-engine.sh.in b/scripts/apk/build-apk-tts-engine.sh.in index 902f6f477..c611c061b 100644 --- a/scripts/apk/build-apk-tts-engine.sh.in +++ b/scripts/apk/build-apk-tts-engine.sh.in @@ -78,6 +78,10 @@ sed -i.bak s/"lang = null"/"lang = \"$lang_iso_639_3\""/ ./TtsEngine.kt git diff popd +if [[ $model_dir == vits-melo-tts-zh_en ]]; then + lang=zh_en +fi + for arch in arm64-v8a armeabi-v7a x86_64 x86; do log "------------------------------------------------------------" log "build tts apk for $arch" diff --git a/scripts/apk/build-apk-tts.sh.in b/scripts/apk/build-apk-tts.sh.in index 73139790f..2e62ad636 100644 --- a/scripts/apk/build-apk-tts.sh.in +++ b/scripts/apk/build-apk-tts.sh.in @@ -76,6 +76,10 @@ sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt git diff popd +if [[ $model_dir == vits-melo-tts-zh_en ]]; then + lang=zh_en +fi + for arch in arm64-v8a armeabi-v7a x86_64 x86; do log "------------------------------------------------------------" log "build tts apk for $arch" diff --git a/scripts/apk/generate-tts-apk-script.py b/scripts/apk/generate-tts-apk-script.py index 
e7d9911bc..48745c312 100755 --- a/scripts/apk/generate-tts-apk-script.py +++ b/scripts/apk/generate-tts-apk-script.py @@ -312,6 +312,11 @@ def get_vits_models() -> List[TtsModel]: model_name="vits-zh-hf-fanchen-wnj.onnx", lang="zh", ), + TtsModel( + model_dir="vits-melo-tts-zh_en", + model_name="model.onnx", + lang="zh", + ), TtsModel( model_dir="vits-zh-hf-fanchen-C", model_name="vits-zh-hf-fanchen-C.onnx", @@ -339,18 +344,21 @@ def get_vits_models() -> List[TtsModel]: ), ] - rule_fsts = ["phone.fst", "date.fst", "number.fst", "new_heteronym.fst"] + rule_fsts = ["phone.fst", "date.fst", "number.fst"] for m in chinese_models: s = [f"{m.model_dir}/{r}" for r in rule_fsts] - if "vits-zh-hf" in m.model_dir or "sherpa-onnx-vits-zh-ll" == m.model_dir: + if ( + "vits-zh-hf" in m.model_dir + or "sherpa-onnx-vits-zh-ll" == m.model_dir + or "melo-tts" in m.model_dir + ): s = s[:-1] m.dict_dir = m.model_dir + "/dict" + else: + m.rule_fars = f"{m.model_dir}/rule.far" m.rule_fsts = ",".join(s) - if "vits-zh-hf" not in m.model_dir and "zh-ll" not in m.model_dir: - m.rule_fars = f"{m.model_dir}/rule.far" - all_models = chinese_models + [ TtsModel( model_dir="vits-cantonese-hf-xiaomaiiwn", diff --git a/scripts/dart/sherpa-onnx-pubspec.yaml b/scripts/dart/sherpa-onnx-pubspec.yaml index 0633680bc..a0aeb0e5c 100644 --- a/scripts/dart/sherpa-onnx-pubspec.yaml +++ b/scripts/dart/sherpa-onnx-pubspec.yaml @@ -17,7 +17,7 @@ topics: - voice-activity-detection # remember to change the version in ../sherpa_onnx_macos/macos/sherpa_onnx.podspec -version: 1.10.15 +version: 1.10.16 homepage: https://github.com/k2-fsa/sherpa-onnx diff --git a/scripts/flutter/generate-tts.py b/scripts/flutter/generate-tts.py index 96b646bdf..4b5f8d5d0 100755 --- a/scripts/flutter/generate-tts.py +++ b/scripts/flutter/generate-tts.py @@ -6,9 +6,6 @@ import jinja2 -# pip install iso639-lang -from iso639 import Lang - def get_args(): parser = argparse.ArgumentParser() @@ -37,13 +34,6 @@ class TtsModel: data_dir: 
Optional[str] = None dict_dir: Optional[str] = None is_char: bool = False - lang_iso_639_3: str = "" - - -def convert_lang_to_iso_639_3(models: List[TtsModel]): - for m in models: - if m.lang_iso_639_3 == "": - m.lang_iso_639_3 = Lang(m.lang).pt3 def get_coqui_models() -> List[TtsModel]: @@ -312,6 +302,11 @@ def get_vits_models() -> List[TtsModel]: model_name="vits-zh-hf-fanchen-wnj.onnx", lang="zh", ), + TtsModel( + model_dir="vits-melo-tts-zh_en", + model_name="model.onnx", + lang="zh_en", + ), TtsModel( model_dir="vits-zh-hf-fanchen-C", model_name="vits-zh-hf-fanchen-C.onnx", @@ -332,26 +327,33 @@ def get_vits_models() -> List[TtsModel]: model_name="vits-zh-hf-fanchen-unity.onnx", lang="zh", ), + TtsModel( + model_dir="sherpa-onnx-vits-zh-ll", + model_name="model.onnx", + lang="zh", + ), ] - rule_fsts = ["phone.fst", "date.fst", "number.fst", "new_heteronym.fst"] + rule_fsts = ["phone.fst", "date.fst", "number.fst"] for m in chinese_models: s = [f"{m.model_dir}/{r}" for r in rule_fsts] - if "vits-zh-hf" in m.model_dir: + if ( + "vits-zh-hf" in m.model_dir + or "sherpa-onnx-vits-zh-ll" == m.model_dir + or "melo-tts" in m.model_dir + ): s = s[:-1] m.dict_dir = m.model_dir + "/dict" + else: + m.rule_fars = f"{m.model_dir}/rule.far" m.rule_fsts = ",".join(s) - if "vits-zh-hf" not in m.model_dir: - m.rule_fars = f"{m.model_dir}/rule.far" - all_models = chinese_models + [ TtsModel( model_dir="vits-cantonese-hf-xiaomaiiwn", model_name="vits-cantonese-hf-xiaomaiiwn.onnx", lang="cantonese", - lang_iso_639_3="yue", rule_fsts="vits-cantonese-hf-xiaomaiiwn/rule.fst", ), # English (US) @@ -374,7 +376,6 @@ def main(): all_model_list += get_piper_models() all_model_list += get_mimic3_models() all_model_list += get_coqui_models() - convert_lang_to_iso_639_3(all_model_list) num_models = len(all_model_list) diff --git a/scripts/melo-tts/README.md b/scripts/melo-tts/README.md new file mode 100644 index 000000000..802af0608 --- /dev/null +++ b/scripts/melo-tts/README.md @@ -0,0 
+1,6 @@ +# Introduction + +Models in this directory are converted from +https://github.com/myshell-ai/MeloTTS + +Note there is only a single female speaker in the model. diff --git a/scripts/melo-tts/export-onnx.py b/scripts/melo-tts/export-onnx.py index 81a261c58..ff25d8c8b 100755 --- a/scripts/melo-tts/export-onnx.py +++ b/scripts/melo-tts/export-onnx.py @@ -8,7 +8,6 @@ from melo.text.chinese import pinyin_to_symbol_map from melo.text.english import eng_dict, refine_syllables from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict -from melo.text.symbols import language_tone_start_map for k, v in pinyin_to_symbol_map.items(): if isinstance(v, list): @@ -82,6 +81,7 @@ def generate_tokens(symbol_list): def generate_lexicon(): word_dict = pinyin_dict.pinyin_dict phrases = phrases_dict.phrases_dict + eng_dict["kaldi"] = [["K", "AH0"], ["L", "D", "IH0"]] with open("lexicon.txt", "w", encoding="utf-8") as f: for word in eng_dict: phones, tones = refine_syllables(eng_dict[word]) @@ -237,9 +237,11 @@ def main(): meta_data = { "model_type": "melo-vits", "comment": "melo", + "version": 2, "language": "Chinese + English", "add_blank": int(model.hps.data.add_blank), "n_speakers": 1, + "jieba": 1, "sample_rate": model.hps.data.sampling_rate, "bert_dim": 1024, "ja_bert_dim": 768, diff --git a/scripts/melo-tts/run.sh b/scripts/melo-tts/run.sh index 3af6ba013..eea3de897 100755 --- a/scripts/melo-tts/run.sh +++ b/scripts/melo-tts/run.sh @@ -12,7 +12,7 @@ function install() { cd MeloTTS pip install -r ./requirements.txt - pip install soundfile onnx onnxruntime + pip install soundfile onnx==1.15.0 onnxruntime==1.16.3 python3 -m unidic download popd diff --git a/scripts/melo-tts/test.py b/scripts/melo-tts/test.py index 4d97437ae..c5b808044 100755 --- a/scripts/melo-tts/test.py +++ b/scripts/melo-tts/test.py @@ -135,28 +135,11 @@ def __call__(self, x, tones): def main(): lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt") - text = 
"永远相信,美好的事情即将发生。" + text = "这是一个使用 next generation kaldi 的 text to speech 中英文例子. Thank you! 你觉得如何呢? are you ok? Fantastic! How about you?" s = jieba.cut(text, HMM=True) phones, tones = lexicon.convert(s) - en_text = "how are you ?".split() - - phones_en, tones_en = lexicon.convert(en_text) - phones += [0] - tones += [0] - - phones += phones_en - tones += tones_en - - text = "多音字测试, 银行,行不行?长沙长大" - s = jieba.cut(text, HMM=True) - - phones2, tones2 = lexicon.convert(s) - - phones += phones2 - tones += tones2 - model = OnnxModel("./model.onnx") if model.add_blank: diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index cda5832e2..c63fad900 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -422,10 +422,10 @@ sherpa_onnx::OfflineRecognizerConfig convertConfig( void SherpaOnnxOfflineRecognizerSetConfig( const SherpaOnnxOfflineRecognizer *recognizer, - const SherpaOnnxOfflineRecognizerConfig *config){ + const SherpaOnnxOfflineRecognizerConfig *config) { sherpa_onnx::OfflineRecognizerConfig recognizer_config = convertConfig(config); - recognizer->impl->SetConfig(recognizer_config); + recognizer->impl->SetConfig(recognizer_config); } void DestroyOfflineRecognizer(SherpaOnnxOfflineRecognizer *recognizer) { @@ -478,7 +478,7 @@ const SherpaOnnxOfflineRecognizerResult *GetOfflineStreamResult( pText[text.size()] = 0; r->text = pText; - //lang + // lang const auto &lang = result.lang; char *c_lang = new char[lang.size() + 1]; std::copy(lang.begin(), lang.end(), c_lang); @@ -1317,7 +1317,7 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches( } delete[] r->matches; delete r; -}; +} int32_t SherpaOnnxSpeakerEmbeddingManagerVerify( const SherpaOnnxSpeakerEmbeddingManager *p, const char *name, diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index e6d8ae272..13ea4f5c4 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -496,7 +496,7 @@ SHERPA_ONNX_API void DecodeMultipleOfflineStreams( 
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { const char *text; - // Pointer to continuous memory which holds timestamps + // Pointer to continuous memory which holds timestamps // // It is NULL if the model does not support timestamps float *timestamps; @@ -525,9 +525,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineRecognizerResult { */ const char *json; - //return recognized language + // return recognized language const char *lang; - } SherpaOnnxOfflineRecognizerResult; /// Get the result of the offline stream. diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index b6bda8ba9..89d3d278e 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -142,7 +142,9 @@ if(SHERPA_ONNX_ENABLE_TTS) list(APPEND sources jieba-lexicon.cc lexicon.cc + melo-tts-lexicon.cc offline-tts-character-frontend.cc + offline-tts-frontend.cc offline-tts-impl.cc offline-tts-model-config.cc offline-tts-vits-model-config.cc diff --git a/sherpa-onnx/csrc/cppjieba-test.cc b/sherpa-onnx/csrc/cppjieba-test.cc index 77a856e2e..515a90ad7 100644 --- a/sherpa-onnx/csrc/cppjieba-test.cc +++ b/sherpa-onnx/csrc/cppjieba-test.cc @@ -33,7 +33,7 @@ TEST(CppJieBa, Case1) { std::vector words; std::vector jiebawords; - std::string s = "他来到了网易杭研大厦"; + std::string s = "他来到了网易杭研大厦。How are you?"; std::cout << s << std::endl; std::cout << "[demo] Cut With HMM" << std::endl; jieba.Cut(s, words, true); diff --git a/sherpa-onnx/csrc/jieba-lexicon.cc b/sherpa-onnx/csrc/jieba-lexicon.cc index 1bf64cd50..82b45cdfe 100644 --- a/sherpa-onnx/csrc/jieba-lexicon.cc +++ b/sherpa-onnx/csrc/jieba-lexicon.cc @@ -17,6 +17,7 @@ namespace sherpa_onnx { // implemented in ./lexicon.cc std::unordered_map ReadTokens(std::istream &is); + std::vector ConvertTokensToIds( const std::unordered_map &token2id, const std::vector &tokens); @@ -53,8 +54,7 @@ class JiebaLexicon::Impl { } } - std::vector> ConvertTextToTokenIds( - const std::string &text) const { + 
std::vector ConvertTextToTokenIds(const std::string &text) const { // see // https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244 std::regex punct_re{":|、|;"}; @@ -87,7 +87,7 @@ class JiebaLexicon::Impl { SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str()); } - std::vector> ans; + std::vector ans; std::vector this_sentence; int32_t blank = token2id_.at(" "); @@ -217,7 +217,7 @@ JiebaLexicon::JiebaLexicon(const std::string &lexicon, : impl_(std::make_unique(lexicon, tokens, dict_dir, meta_data, debug)) {} -std::vector> JiebaLexicon::ConvertTextToTokenIds( +std::vector JiebaLexicon::ConvertTextToTokenIds( const std::string &text, const std::string & /*unused_voice = ""*/) const { return impl_->ConvertTextToTokenIds(text); } diff --git a/sherpa-onnx/csrc/jieba-lexicon.h b/sherpa-onnx/csrc/jieba-lexicon.h index 867fa01aa..d02e0ee5d 100644 --- a/sherpa-onnx/csrc/jieba-lexicon.h +++ b/sherpa-onnx/csrc/jieba-lexicon.h @@ -10,11 +10,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - #include "sherpa-onnx/csrc/offline-tts-frontend.h" #include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h" @@ -27,13 +22,7 @@ class JiebaLexicon : public OfflineTtsFrontend { const std::string &dict_dir, const OfflineTtsVitsModelMetaData &meta_data, bool debug); -#if __ANDROID_API__ >= 9 - JiebaLexicon(AAssetManager *mgr, const std::string &lexicon, - const std::string &tokens, const std::string &dict_dir, - const OfflineTtsVitsModelMetaData &meta_data); -#endif - - std::vector> ConvertTextToTokenIds( + std::vector ConvertTextToTokenIds( const std::string &text, const std::string &unused_voice = "") const override; diff --git a/sherpa-onnx/csrc/lexicon.cc b/sherpa-onnx/csrc/lexicon.cc index 91307a216..c635f7ffd 100644 --- a/sherpa-onnx/csrc/lexicon.cc +++ b/sherpa-onnx/csrc/lexicon.cc @@ -172,7 +172,7 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string 
&lexicon, } #endif -std::vector> Lexicon::ConvertTextToTokenIds( +std::vector Lexicon::ConvertTextToTokenIds( const std::string &text, const std::string & /*voice*/ /*= ""*/) const { switch (language_) { case Language::kChinese: @@ -187,7 +187,7 @@ std::vector> Lexicon::ConvertTextToTokenIds( return {}; } -std::vector> Lexicon::ConvertTextToTokenIdsChinese( +std::vector Lexicon::ConvertTextToTokenIdsChinese( const std::string &_text) const { std::string text(_text); ToLowerCase(&text); @@ -209,7 +209,7 @@ std::vector> Lexicon::ConvertTextToTokenIdsChinese( fprintf(stderr, "\n"); } - std::vector> ans; + std::vector ans; std::vector this_sentence; int32_t blank = -1; @@ -288,7 +288,7 @@ std::vector> Lexicon::ConvertTextToTokenIdsChinese( return ans; } -std::vector> Lexicon::ConvertTextToTokenIdsNotChinese( +std::vector Lexicon::ConvertTextToTokenIdsNotChinese( const std::string &_text) const { std::string text(_text); ToLowerCase(&text); @@ -311,7 +311,7 @@ std::vector> Lexicon::ConvertTextToTokenIdsNotChinese( int32_t blank = token2id_.at(" "); - std::vector> ans; + std::vector ans; std::vector this_sentence; for (const auto &w : words) { diff --git a/sherpa-onnx/csrc/lexicon.h b/sherpa-onnx/csrc/lexicon.h index 81a510956..2c71ab7e8 100644 --- a/sherpa-onnx/csrc/lexicon.h +++ b/sherpa-onnx/csrc/lexicon.h @@ -36,14 +36,14 @@ class Lexicon : public OfflineTtsFrontend { const std::string &language, bool debug = false); #endif - std::vector> ConvertTextToTokenIds( + std::vector ConvertTextToTokenIds( const std::string &text, const std::string &voice = "") const override; private: - std::vector> ConvertTextToTokenIdsNotChinese( + std::vector ConvertTextToTokenIdsNotChinese( const std::string &text) const; - std::vector> ConvertTextToTokenIdsChinese( + std::vector ConvertTextToTokenIdsChinese( const std::string &text) const; void InitLanguage(const std::string &lang); diff --git a/sherpa-onnx/csrc/melo-tts-lexicon.cc b/sherpa-onnx/csrc/melo-tts-lexicon.cc new file mode 
100644 index 000000000..b213c4516 --- /dev/null +++ b/sherpa-onnx/csrc/melo-tts-lexicon.cc @@ -0,0 +1,266 @@ +// sherpa-onnx/csrc/melo-tts-lexicon.cc +// +// Copyright (c) 2022-2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/melo-tts-lexicon.h" + +#include +#include // NOLINT +#include + +#include "cppjieba/Jieba.hpp" +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +// implemented in ./lexicon.cc +std::unordered_map ReadTokens(std::istream &is); + +std::vector ConvertTokensToIds( + const std::unordered_map &token2id, + const std::vector &tokens); + +class MeloTtsLexicon::Impl { + public: + Impl(const std::string &lexicon, const std::string &tokens, + const std::string &dict_dir, + const OfflineTtsVitsModelMetaData &meta_data, bool debug) + : meta_data_(meta_data), debug_(debug) { + std::string dict = dict_dir + "/jieba.dict.utf8"; + std::string hmm = dict_dir + "/hmm_model.utf8"; + std::string user_dict = dict_dir + "/user.dict.utf8"; + std::string idf = dict_dir + "/idf.utf8"; + std::string stop_word = dict_dir + "/stop_words.utf8"; + + AssertFileExists(dict); + AssertFileExists(hmm); + AssertFileExists(user_dict); + AssertFileExists(idf); + AssertFileExists(stop_word); + + jieba_ = + std::make_unique(dict, hmm, user_dict, idf, stop_word); + + { + std::ifstream is(tokens); + InitTokens(is); + } + + { + std::ifstream is(lexicon); + InitLexicon(is); + } + } + + std::vector ConvertTextToTokenIds(const std::string &_text) const { + std::string text = ToLowerCase(_text); + // see + // https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244 + std::regex punct_re{":|、|;"}; + std::string s = std::regex_replace(text, punct_re, ","); + + std::regex punct_re2("。"); + s = std::regex_replace(s, punct_re2, "."); + + std::regex punct_re3("?"); + s = std::regex_replace(s, punct_re3, "?"); + + std::regex punct_re4("!"); + s = std::regex_replace(s, 
punct_re4, "!"); + + std::vector words; + bool is_hmm = true; + jieba_->Cut(s, words, is_hmm); + + if (debug_) { + SHERPA_ONNX_LOGE("input text: %s", text.c_str()); + SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str()); + + std::ostringstream os; + std::string sep = ""; + for (const auto &w : words) { + os << sep << w; + sep = "_"; + } + + SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str()); + } + + std::vector ans; + TokenIDs this_sentence; + + int32_t blank = token2id_.at("_"); + for (const auto &w : words) { + auto ids = ConvertWordToIds(w); + if (ids.tokens.empty()) { + SHERPA_ONNX_LOGE("Ignore OOV '%s'", w.c_str()); + continue; + } + + this_sentence.tokens.insert(this_sentence.tokens.end(), + ids.tokens.begin(), ids.tokens.end()); + this_sentence.tones.insert(this_sentence.tones.end(), ids.tones.begin(), + ids.tones.end()); + + if (w == "." || w == "!" || w == "?" || w == ",") { + ans.push_back(std::move(this_sentence)); + this_sentence = {}; + } + } // for (const auto &w : words) + + if (!this_sentence.tokens.empty()) { + ans.push_back(std::move(this_sentence)); + } + + return ans; + } + + private: + TokenIDs ConvertWordToIds(const std::string &w) const { + if (word2ids_.count(w)) { + return word2ids_.at(w); + } + + if (token2id_.count(w)) { + return {{token2id_.at(w)}, {0}}; + } + + TokenIDs ans; + + std::vector words = SplitUtf8(w); + for (const auto &word : words) { + if (word2ids_.count(word)) { + auto ids = ConvertWordToIds(word); + ans.tokens.insert(ans.tokens.end(), ids.tokens.begin(), + ids.tokens.end()); + ans.tones.insert(ans.tones.end(), ids.tones.begin(), ids.tones.end()); + } + } + + return ans; + } + + void InitTokens(std::istream &is) { + token2id_ = ReadTokens(is); + token2id_[" "] = token2id_["_"]; + + std::vector> puncts = { + {",", ","}, {".", "。"}, {"!", "!"}, {"?", "?"}}; + + for (const auto &p : puncts) { + if (token2id_.count(p.first) && !token2id_.count(p.second)) { + token2id_[p.second] = 
token2id_[p.first]; + } + + if (!token2id_.count(p.first) && token2id_.count(p.second)) { + token2id_[p.first] = token2id_[p.second]; + } + } + + if (!token2id_.count("、") && token2id_.count(",")) { + token2id_["、"] = token2id_[","]; + } + } + + void InitLexicon(std::istream &is) { + std::string word; + std::vector token_list; + + std::vector phone_list; + std::vector tone_list; + + std::string line; + std::string phone; + int32_t line_num = 0; + + while (std::getline(is, line)) { + ++line_num; + + std::istringstream iss(line); + + token_list.clear(); + phone_list.clear(); + tone_list.clear(); + + iss >> word; + ToLowerCase(&word); + + if (word2ids_.count(word)) { + SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.", + word.c_str(), line_num, line.c_str()); + continue; + } + + while (iss >> phone) { + token_list.push_back(std::move(phone)); + } + + if ((token_list.size() & 1) != 0) { + SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str()); + exit(-1); + } + + int32_t num_phones = token_list.size() / 2; + phone_list.reserve(num_phones); + tone_list.reserve(num_phones); + + for (int32_t i = 0; i != num_phones; ++i) { + phone_list.push_back(std::move(token_list[i])); + tone_list.push_back(std::stoi(token_list[i + num_phones], nullptr)); + if (tone_list.back() < 0 || tone_list.back() > 50) { + SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str()); + exit(-1); + } + } + + std::vector ids = ConvertTokensToIds(token2id_, phone_list); + if (ids.empty()) { + continue; + } + + if (ids.size() != num_phones) { + SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str()); + exit(-1); + } + + std::vector ids64{ids.begin(), ids.end()}; + + word2ids_.insert( + {std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}}); + } + + word2ids_["呣"] = word2ids_["母"]; + word2ids_["嗯"] = word2ids_["恩"]; + } + + private: + // lexicon.txt is saved in word2ids_ + std::unordered_map word2ids_; + + // tokens.txt is saved in token2id_ + 
std::unordered_map token2id_; + + OfflineTtsVitsModelMetaData meta_data_; + + std::unique_ptr jieba_; + bool debug_ = false; +}; + +MeloTtsLexicon::~MeloTtsLexicon() = default; + +MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon, + const std::string &tokens, + const std::string &dict_dir, + const OfflineTtsVitsModelMetaData &meta_data, + bool debug) + : impl_(std::make_unique(lexicon, tokens, dict_dir, meta_data, + debug)) {} + +std::vector MeloTtsLexicon::ConvertTextToTokenIds( + const std::string &text, const std::string & /*unused_voice = ""*/) const { + return impl_->ConvertTextToTokenIds(text); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/melo-tts-lexicon.h b/sherpa-onnx/csrc/melo-tts-lexicon.h new file mode 100644 index 000000000..261f3412e --- /dev/null +++ b/sherpa-onnx/csrc/melo-tts-lexicon.h @@ -0,0 +1,36 @@ +// sherpa-onnx/csrc/melo-tts-lexicon.h +// +// Copyright (c) 2022-2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_ +#define SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_ + +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/offline-tts-frontend.h" +#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h" + +namespace sherpa_onnx { + +class MeloTtsLexicon : public OfflineTtsFrontend { + public: + ~MeloTtsLexicon() override; + MeloTtsLexicon(const std::string &lexicon, const std::string &tokens, + const std::string &dict_dir, + const OfflineTtsVitsModelMetaData &meta_data, bool debug); + + std::vector ConvertTextToTokenIds( + const std::string &text, + const std::string &unused_voice = "") const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_MELO_TTS_LEXICON_H_ diff --git a/sherpa-onnx/csrc/offline-tts-character-frontend.cc b/sherpa-onnx/csrc/offline-tts-character-frontend.cc index 857200e9c..c92240581 100644 --- a/sherpa-onnx/csrc/offline-tts-character-frontend.cc +++ 
b/sherpa-onnx/csrc/offline-tts-character-frontend.cc @@ -94,8 +94,7 @@ OfflineTtsCharacterFrontend::OfflineTtsCharacterFrontend( #endif -std::vector> -OfflineTtsCharacterFrontend::ConvertTextToTokenIds( +std::vector OfflineTtsCharacterFrontend::ConvertTextToTokenIds( const std::string &_text, const std::string & /*voice = ""*/) const { // see // https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/utils/text/tokenizer.py#L87 @@ -112,7 +111,7 @@ OfflineTtsCharacterFrontend::ConvertTextToTokenIds( std::wstring_convert, char32_t> conv; std::u32string s = conv.from_bytes(text); - std::vector> ans; + std::vector ans; std::vector this_sentence; if (add_blank) { diff --git a/sherpa-onnx/csrc/offline-tts-character-frontend.h b/sherpa-onnx/csrc/offline-tts-character-frontend.h index d56ea3125..ffd2bb5f4 100644 --- a/sherpa-onnx/csrc/offline-tts-character-frontend.h +++ b/sherpa-onnx/csrc/offline-tts-character-frontend.h @@ -41,7 +41,7 @@ class OfflineTtsCharacterFrontend : public OfflineTtsFrontend { * If a frontend does not support splitting the text into * sentences, the resulting vector contains only one subvector. 
*/ - std::vector> ConvertTextToTokenIds( + std::vector ConvertTextToTokenIds( const std::string &text, const std::string &voice = "") const override; private: diff --git a/sherpa-onnx/csrc/offline-tts-frontend.cc b/sherpa-onnx/csrc/offline-tts-frontend.cc new file mode 100644 index 000000000..f083c75e3 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-frontend.cc @@ -0,0 +1,34 @@ +// sherpa-onnx/csrc/offline-tts-frontend.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-tts-frontend.h" + +#include +#include + +namespace sherpa_onnx { + +std::string TokenIDs::ToString() const { + std::ostringstream os; + os << "TokenIDs("; + os << "tokens=["; + std::string sep; + for (auto i : tokens) { + os << sep << i; + sep = ", "; + } + os << "], "; + + os << "tones=["; + sep = {}; + for (auto i : tones) { + os << sep << i; + sep = ", "; + } + os << "]"; + os << ")"; + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-frontend.h b/sherpa-onnx/csrc/offline-tts-frontend.h index 9f116f125..6ea4c5dfa 100644 --- a/sherpa-onnx/csrc/offline-tts-frontend.h +++ b/sherpa-onnx/csrc/offline-tts-frontend.h @@ -8,8 +8,28 @@ #include #include +#include "sherpa-onnx/csrc/macros.h" + namespace sherpa_onnx { +struct TokenIDs { + TokenIDs() = default; + + /*implicit*/ TokenIDs(const std::vector &tokens) // NOLINT + : tokens{tokens} {} + + TokenIDs(const std::vector &tokens, + const std::vector &tones) + : tokens{tokens}, tones{tones} {} + + std::string ToString() const; + + std::vector tokens; + + // Used only in MeloTTS + std::vector tones; +}; + class OfflineTtsFrontend { public: virtual ~OfflineTtsFrontend() = default; @@ -26,7 +46,7 @@ class OfflineTtsFrontend { * If a frontend does not support splitting the text into sentences, * the resulting vector contains only one subvector. 
*/ - virtual std::vector> ConvertTextToTokenIds( + virtual std::vector ConvertTextToTokenIds( const std::string &text, const std::string &voice = "") const = 0; }; diff --git a/sherpa-onnx/csrc/offline-tts-vits-impl.h b/sherpa-onnx/csrc/offline-tts-vits-impl.h index c2f6e602c..fb43a88ab 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-impl.h +++ b/sherpa-onnx/csrc/offline-tts-vits-impl.h @@ -22,6 +22,7 @@ #include "sherpa-onnx/csrc/jieba-lexicon.h" #include "sherpa-onnx/csrc/lexicon.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/melo-tts-lexicon.h" #include "sherpa-onnx/csrc/offline-tts-character-frontend.h" #include "sherpa-onnx/csrc/offline-tts-frontend.h" #include "sherpa-onnx/csrc/offline-tts-impl.h" @@ -174,26 +175,47 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { } } - std::vector> x = + std::vector token_ids = frontend_->ConvertTextToTokenIds(text, meta_data.voice); - if (x.empty() || (x.size() == 1 && x[0].empty())) { + if (token_ids.empty() || + (token_ids.size() == 1 && token_ids[0].tokens.empty())) { SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str()); return {}; } + std::vector> x; + std::vector> tones; + + x.reserve(token_ids.size()); + + for (auto &i : token_ids) { + x.push_back(std::move(i.tokens)); + } + + if (!token_ids[0].tones.empty()) { + tones.reserve(token_ids.size()); + for (auto &i : token_ids) { + tones.push_back(std::move(i.tones)); + } + } + // TODO(fangjun): add blank inside the frontend, not here if (meta_data.add_blank && config_.model.vits.data_dir.empty() && meta_data.frontend != "characters") { for (auto &k : x) { k = AddBlank(k); } + + for (auto &k : tones) { + k = AddBlank(k); + } } int32_t x_size = static_cast(x.size()); if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) { - auto ans = Process(x, sid, speed); + auto ans = Process(x, tones, sid, speed); if (callback) { callback(ans.samples.data(), ans.samples.size(), 1.0); } @@ -202,9 +224,12 @@ class 
OfflineTtsVitsImpl : public OfflineTtsImpl { // the input text is too long, we process sentences within it in batches // to avoid OOM. Batch size is config_.max_num_sentences - std::vector> batch; + std::vector> batch_x; + std::vector> batch_tones; + int32_t batch_size = config_.max_num_sentences; - batch.reserve(config_.max_num_sentences); + batch_x.reserve(config_.max_num_sentences); + batch_tones.reserve(config_.max_num_sentences); int32_t num_batches = x_size / batch_size; if (config_.model.debug) { @@ -221,12 +246,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { int32_t k = 0; for (int32_t b = 0; b != num_batches && should_continue; ++b) { - batch.clear(); + batch_x.clear(); + batch_tones.clear(); for (int32_t i = 0; i != batch_size; ++i, ++k) { - batch.push_back(std::move(x[k])); + batch_x.push_back(std::move(x[k])); + + if (!tones.empty()) { + batch_tones.push_back(std::move(tones[k])); + } } - auto audio = Process(batch, sid, speed); + auto audio = Process(batch_x, batch_tones, sid, speed); ans.sample_rate = audio.sample_rate; ans.samples.insert(ans.samples.end(), audio.samples.begin(), audio.samples.end()); @@ -239,14 +269,19 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { } } - batch.clear(); + batch_x.clear(); + batch_tones.clear(); while (k < static_cast(x.size()) && should_continue) { - batch.push_back(std::move(x[k])); + batch_x.push_back(std::move(x[k])); + if (!tones.empty()) { + batch_tones.push_back(std::move(tones[k])); + } + ++k; } - if (!batch.empty()) { - auto audio = Process(batch, sid, speed); + if (!batch_x.empty()) { + auto audio = Process(batch_x, batch_tones, sid, speed); ans.sample_rate = audio.sample_rate; ans.samples.insert(ans.samples.end(), audio.samples.begin(), audio.samples.end()); @@ -308,6 +343,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { if (meta_data.frontend == "characters") { frontend_ = std::make_unique( config_.model.vits.tokens, meta_data); + } else if (meta_data.jieba && 
!config_.model.vits.dict_dir.empty() && + meta_data.is_melo_tts) { + frontend_ = std::make_unique( + config_.model.vits.lexicon, config_.model.vits.tokens, + config_.model.vits.dict_dir, model_->GetMetaData(), + config_.model.debug); } else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) { frontend_ = std::make_unique( config_.model.vits.lexicon, config_.model.vits.tokens, @@ -344,6 +385,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { } GeneratedAudio Process(const std::vector> &tokens, + const std::vector> &tones, int32_t sid, float speed) const { int32_t num_tokens = 0; for (const auto &k : tokens) { @@ -356,6 +398,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { x.insert(x.end(), k.begin(), k.end()); } + std::vector tone_list; + if (!tones.empty()) { + tone_list.reserve(num_tokens); + for (const auto &k : tones) { + tone_list.insert(tone_list.end(), k.begin(), k.end()); + } + } + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); @@ -363,7 +413,20 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { Ort::Value x_tensor = Ort::Value::CreateTensor( memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); - Ort::Value audio = model_->Run(std::move(x_tensor), sid, speed); + Ort::Value tones_tensor{nullptr}; + if (!tones.empty()) { + tones_tensor = Ort::Value::CreateTensor(memory_info, tone_list.data(), + tone_list.size(), x_shape.data(), + x_shape.size()); + } + + Ort::Value audio{nullptr}; + if (tones.empty()) { + audio = model_->Run(std::move(x_tensor), sid, speed); + } else { + audio = + model_->Run(std::move(x_tensor), std::move(tones_tensor), sid, speed); + } std::vector audio_shape = audio.GetTensorTypeAndShapeInfo().GetShape(); diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-metadata.h b/sherpa-onnx/csrc/offline-tts-vits-model-metadata.h index 621e0e555..5ce00d745 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model-metadata.h +++ 
b/sherpa-onnx/csrc/offline-tts-vits-model-metadata.h @@ -21,6 +21,7 @@ struct OfflineTtsVitsModelMetaData { bool is_piper = false; bool is_coqui = false; bool is_icefall = false; + bool is_melo_tts = false; // for Chinese TTS models from // https://github.com/Plachtaa/VITS-fast-fine-tuning @@ -33,6 +34,10 @@ struct OfflineTtsVitsModelMetaData { int32_t use_eos_bos = 0; int32_t pad_id = 0; + // for melo tts + int32_t speaker_id = 0; + int32_t version = 0; + std::string punctuations; std::string language; std::string voice; diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.cc b/sherpa-onnx/csrc/offline-tts-vits-model.cc index d73a453c2..3a2bdfdb4 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model.cc @@ -45,6 +45,64 @@ class OfflineTtsVitsModel::Impl { return RunVits(std::move(x), sid, speed); } + Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid, float speed) { + // For MeloTTS, we hardcode sid to the one contained in the meta data + sid = meta_data_.speaker_id; + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::vector x_shape = x.GetTensorTypeAndShapeInfo().GetShape(); + if (x_shape[0] != 1) { + SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d", + static_cast(x_shape[0])); + exit(-1); + } + + int64_t len = x_shape[1]; + int64_t len_shape = 1; + + Ort::Value x_length = + Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1); + + int64_t scale_shape = 1; + float noise_scale = config_.vits.noise_scale; + float length_scale = config_.vits.length_scale; + float noise_scale_w = config_.vits.noise_scale_w; + + if (speed != 1 && speed > 0) { + length_scale = 1. 
/ speed; + } + + Ort::Value noise_scale_tensor = + Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1); + + Ort::Value length_scale_tensor = Ort::Value::CreateTensor( + memory_info, &length_scale, 1, &scale_shape, 1); + + Ort::Value noise_scale_w_tensor = Ort::Value::CreateTensor( + memory_info, &noise_scale_w, 1, &scale_shape, 1); + + Ort::Value sid_tensor = + Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1); + + std::vector inputs; + inputs.reserve(7); + inputs.push_back(std::move(x)); + inputs.push_back(std::move(x_length)); + inputs.push_back(std::move(tones)); + inputs.push_back(std::move(sid_tensor)); + inputs.push_back(std::move(noise_scale_tensor)); + inputs.push_back(std::move(length_scale_tensor)); + inputs.push_back(std::move(noise_scale_w_tensor)); + + auto out = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + + return std::move(out[0]); + } + const OfflineTtsVitsModelMetaData &GetMetaData() const { return meta_data_; } private: @@ -83,6 +141,10 @@ class OfflineTtsVitsModel::Impl { SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.add_blank, "add_blank", 0); + + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.speaker_id, "speaker_id", + 0); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 0); SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers"); SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.punctuations, "punctuation", ""); @@ -115,6 +177,22 @@ class OfflineTtsVitsModel::Impl { if (comment.find("icefall") != std::string::npos) { meta_data_.is_icefall = true; } + + if (comment.find("melo") != std::string::npos) { + meta_data_.is_melo_tts = true; + int32_t expected_version = 2; + if (meta_data_.version < expected_version) { + SHERPA_ONNX_LOGE( + "Please download the latest MeloTTS model and retry. Current " + "version: %d. 
Expected version: %d", + meta_data_.version, expected_version); + exit(-1); + } + + // NOTE(fangjun): + // version 0 is the first version + // version 2: add jieba=1 to the metadata + } } Ort::Value RunVitsPiperOrCoqui(Ort::Value x, int64_t sid, float speed) { @@ -269,6 +347,12 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/, return impl_->Run(std::move(x), sid, speed); } +Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, Ort::Value tones, + int64_t sid /*= 0*/, + float speed /*= 1.0*/) { + return impl_->Run(std::move(x), std::move(tones), sid, speed); +} + const OfflineTtsVitsModelMetaData &OfflineTtsVitsModel::GetMetaData() const { return impl_->GetMetaData(); } diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.h b/sherpa-onnx/csrc/offline-tts-vits-model.h index 7d51efa2c..543963c9d 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.h +++ b/sherpa-onnx/csrc/offline-tts-vits-model.h @@ -40,6 +40,10 @@ class OfflineTtsVitsModel { */ Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0); + // This is for MeloTTS + Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid = 0, + float speed = 1.0); + const OfflineTtsVitsModelMetaData &GetMetaData() const; private: diff --git a/sherpa-onnx/csrc/offline-whisper-decoder.h b/sherpa-onnx/csrc/offline-whisper-decoder.h index 3babb3824..9cb5088e6 100644 --- a/sherpa-onnx/csrc/offline-whisper-decoder.h +++ b/sherpa-onnx/csrc/offline-whisper-decoder.h @@ -5,8 +5,8 @@ #ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ #define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ -#include #include +#include #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/offline-whisper-model-config.h" @@ -36,7 +36,6 @@ class OfflineWhisperDecoder { Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0; - }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc 
b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc index 96bb9d971..6ff165a22 100644 --- a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc @@ -12,7 +12,8 @@ namespace sherpa_onnx { -void OfflineWhisperGreedySearchDecoder::SetConfig(const OfflineWhisperModelConfig &config) { +void OfflineWhisperGreedySearchDecoder::SetConfig( + const OfflineWhisperModelConfig &config) { config_ = config; } @@ -135,9 +136,9 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, const auto &id2lang = model_->GetID2Lang(); if (id2lang.count(initial_tokens[1])) { - ans[0].lang = id2lang.at(initial_tokens[1]); + ans[0].lang = id2lang.at(initial_tokens[1]); } else { - ans[0].lang = ""; + ans[0].lang = ""; } ans[0].tokens = std::move(predicted_tokens); diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 5d9f3745e..ec0475b20 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -153,15 +153,21 @@ Ort::Value View(Ort::Value *v) { } } +template void Print1D(Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); - const float *d = v->GetTensorData(); + const T *d = v->GetTensorData(); + std::ostringstream os; for (int32_t i = 0; i != static_cast(shape[0]); ++i) { - fprintf(stderr, "%.3f ", d[i]); + os << *d << " "; } - fprintf(stderr, "\n"); + os << "\n"; + fprintf(stderr, "%s\n", os.str().c_str()); } +template void Print1D(Ort::Value *v); +template void Print1D(Ort::Value *v); + template void Print2D(Ort::Value *v) { std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index b179b378d..da0abab82 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -69,6 +69,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); Ort::Value View(Ort::Value *v); // Print a 1-D tensor to stderr +template void 
Print1D(Ort::Value *v); // Print a 2-D tensor to stderr diff --git a/sherpa-onnx/csrc/piper-phonemize-lexicon.cc b/sherpa-onnx/csrc/piper-phonemize-lexicon.cc index aa7b9a2c5..8ba101cc3 100644 --- a/sherpa-onnx/csrc/piper-phonemize-lexicon.cc +++ b/sherpa-onnx/csrc/piper-phonemize-lexicon.cc @@ -214,7 +214,7 @@ PiperPhonemizeLexicon::PiperPhonemizeLexicon( } #endif -std::vector> PiperPhonemizeLexicon::ConvertTextToTokenIds( +std::vector PiperPhonemizeLexicon::ConvertTextToTokenIds( const std::string &text, const std::string &voice /*= ""*/) const { piper::eSpeakPhonemeConfig config; @@ -232,7 +232,7 @@ std::vector> PiperPhonemizeLexicon::ConvertTextToTokenIds( piper::phonemize_eSpeak(text, config, phonemes); } - std::vector> ans; + std::vector ans; std::vector phoneme_ids; diff --git a/sherpa-onnx/csrc/piper-phonemize-lexicon.h b/sherpa-onnx/csrc/piper-phonemize-lexicon.h index 842d80e0c..34922de29 100644 --- a/sherpa-onnx/csrc/piper-phonemize-lexicon.h +++ b/sherpa-onnx/csrc/piper-phonemize-lexicon.h @@ -30,7 +30,7 @@ class PiperPhonemizeLexicon : public OfflineTtsFrontend { const OfflineTtsVitsModelMetaData &meta_data); #endif - std::vector> ConvertTextToTokenIds( + std::vector ConvertTextToTokenIds( const std::string &text, const std::string &voice = "") const override; private: diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 093465063..fb5932c47 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -31,8 +31,8 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) { api.ReleaseStatus(status); } -static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, - const std::string &provider_str, +static Ort::SessionOptions GetSessionOptionsImpl( + int32_t num_threads, const std::string &provider_str, const ProviderConfig *provider_config = nullptr) { Provider p = StringToProvider(provider_str); @@ -67,8 +67,9 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, } case 
Provider::kTRT: { if (provider_config == nullptr) { - SHERPA_ONNX_LOGE("Tensorrt support for Online models ony," - "Must be extended for offline and others"); + SHERPA_ONNX_LOGE( + "Tensorrt support for Online models ony," + "Must be extended for offline and others"); exit(1); } auto trt_config = provider_config->trt_config; @@ -84,29 +85,27 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, std::to_string(trt_config.trt_max_partition_iterations); auto trt_min_subgraph_size = std::to_string(trt_config.trt_min_subgraph_size); - auto trt_fp16_enable = - std::to_string(trt_config.trt_fp16_enable); + auto trt_fp16_enable = std::to_string(trt_config.trt_fp16_enable); auto trt_detailed_build_log = std::to_string(trt_config.trt_detailed_build_log); auto trt_engine_cache_enable = std::to_string(trt_config.trt_engine_cache_enable); auto trt_timing_cache_enable = std::to_string(trt_config.trt_timing_cache_enable); - auto trt_dump_subgraphs = - std::to_string(trt_config.trt_dump_subgraphs); + auto trt_dump_subgraphs = std::to_string(trt_config.trt_dump_subgraphs); std::vector trt_options = { - {"device_id", device_id.c_str()}, - {"trt_max_workspace_size", trt_max_workspace_size.c_str()}, - {"trt_max_partition_iterations", trt_max_partition_iterations.c_str()}, - {"trt_min_subgraph_size", trt_min_subgraph_size.c_str()}, - {"trt_fp16_enable", trt_fp16_enable.c_str()}, - {"trt_detailed_build_log", trt_detailed_build_log.c_str()}, - {"trt_engine_cache_enable", trt_engine_cache_enable.c_str()}, - {"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()}, - {"trt_timing_cache_enable", trt_timing_cache_enable.c_str()}, - {"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()}, - {"trt_dump_subgraphs", trt_dump_subgraphs.c_str()} - }; + {"device_id", device_id.c_str()}, + {"trt_max_workspace_size", trt_max_workspace_size.c_str()}, + {"trt_max_partition_iterations", + trt_max_partition_iterations.c_str()}, + {"trt_min_subgraph_size", 
trt_min_subgraph_size.c_str()}, + {"trt_fp16_enable", trt_fp16_enable.c_str()}, + {"trt_detailed_build_log", trt_detailed_build_log.c_str()}, + {"trt_engine_cache_enable", trt_engine_cache_enable.c_str()}, + {"trt_engine_cache_path", trt_config.trt_engine_cache_path.c_str()}, + {"trt_timing_cache_enable", trt_timing_cache_enable.c_str()}, + {"trt_timing_cache_path", trt_config.trt_timing_cache_path.c_str()}, + {"trt_dump_subgraphs", trt_dump_subgraphs.c_str()}}; // ToDo : Trt configs // "trt_int8_enable" // "trt_int8_use_native_calibration_table" @@ -151,9 +150,8 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, if (provider_config != nullptr) { options.device_id = provider_config->device; - options.cudnn_conv_algo_search = - OrtCudnnConvAlgoSearch(provider_config->cuda_config - .cudnn_conv_algo_search); + options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch( + provider_config->cuda_config.cudnn_conv_algo_search); } else { options.device_id = 0; // Default OrtCudnnConvAlgoSearchExhaustive is extremely slow @@ -219,22 +217,24 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads, Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config) { return GetSessionOptionsImpl(config.num_threads, - config.provider_config.provider, &config.provider_config); + config.provider_config.provider, + &config.provider_config); } Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, - const std::string &model_type) { + const std::string &model_type) { /* Transducer models : Only encoder will run with tensorrt, decoder and joiner will run with cuda */ - if(config.provider_config.provider == "trt" && + if (config.provider_config.provider == "trt" && (model_type == "decoder" || model_type == "joiner")) { - return GetSessionOptionsImpl(config.num_threads, - "cuda", &config.provider_config); + return GetSessionOptionsImpl(config.num_threads, "cuda", + &config.provider_config); } return 
GetSessionOptionsImpl(config.num_threads, - config.provider_config.provider, &config.provider_config); + config.provider_config.provider, + &config.provider_config); } Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) { diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index 691a2ff3c..77da79a78 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -5,6 +5,8 @@ #ifndef SHERPA_ONNX_CSRC_SESSION_H_ #define SHERPA_ONNX_CSRC_SESSION_H_ +#include + #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/audio-tagging-model-config.h" #include "sherpa-onnx/csrc/offline-lm-config.h" @@ -25,7 +27,7 @@ namespace sherpa_onnx { Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, - const std::string &model_type); + const std::string &model_type); Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); diff --git a/sherpa-onnx/csrc/speaker-embedding-manager.cc b/sherpa-onnx/csrc/speaker-embedding-manager.cc index 701fa6e18..44f413700 100644 --- a/sherpa-onnx/csrc/speaker-embedding-manager.cc +++ b/sherpa-onnx/csrc/speaker-embedding-manager.cc @@ -6,6 +6,7 @@ #include #include +#include #include "Eigen/Dense" #include "sherpa-onnx/csrc/macros.h" diff --git a/sherpa-onnx/csrc/utfcpp-test.cc b/sherpa-onnx/csrc/utfcpp-test.cc index dc9eecc2f..fcc3ae74d 100644 --- a/sherpa-onnx/csrc/utfcpp-test.cc +++ b/sherpa-onnx/csrc/utfcpp-test.cc @@ -11,7 +11,7 @@ namespace sherpa_onnx { TEST(UTF8, Case1) { - std::string hello = "你好, 早上好!世界. hello!。Hallo"; + std::string hello = "你好, 早上好!世界. hello!。Hallo! how are you?"; std::vector ss = SplitUtf8(hello); for (const auto &s : ss) { std::cout << s << "\n";