From 2c2926af7dac752e5359308a2d036b68e00a7c15 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 31 Dec 2024 12:44:14 +0800 Subject: [PATCH] Add C++ runtime for Matcha-TTS (#1627) --- .github/scripts/test-offline-tts.sh | 34 ++ .github/scripts/test-python.sh | 20 + .github/workflows/linux.yaml | 30 +- .github/workflows/macos.yaml | 18 +- python-api-examples/offline-tts-play.py | 93 ++++- python-api-examples/offline-tts.py | 92 ++++- sherpa-onnx/csrc/CMakeLists.txt | 3 + sherpa-onnx/csrc/hifigan-vocoder.cc | 107 +++++ sherpa-onnx/csrc/hifigan-vocoder.h | 38 ++ sherpa-onnx/csrc/jieba-lexicon.cc | 24 +- sherpa-onnx/csrc/jieba-lexicon.h | 4 +- sherpa-onnx/csrc/offline-tts-impl.cc | 27 +- sherpa-onnx/csrc/offline-tts-impl.h | 4 + sherpa-onnx/csrc/offline-tts-matcha-impl.h | 381 ++++++++++++++++++ .../csrc/offline-tts-matcha-model-config.cc | 143 +++++++ .../csrc/offline-tts-matcha-model-config.h | 56 +++ .../csrc/offline-tts-matcha-model-metadata.h | 28 ++ sherpa-onnx/csrc/offline-tts-matcha-model.cc | 198 +++++++++ sherpa-onnx/csrc/offline-tts-matcha-model.h | 39 ++ sherpa-onnx/csrc/offline-tts-model-config.cc | 8 +- sherpa-onnx/csrc/offline-tts-model-config.h | 4 + sherpa-onnx/csrc/offline-tts-vits-impl.h | 28 +- .../csrc/offline-tts-vits-model-config.cc | 23 +- sherpa-onnx/csrc/offline-tts-vits-model.cc | 4 +- sherpa-onnx/csrc/offline-tts-vits-model.h | 2 +- sherpa-onnx/csrc/session.cc | 5 + sherpa-onnx/csrc/session.h | 3 + sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc | 4 + sherpa-onnx/python/csrc/CMakeLists.txt | 1 + .../csrc/offline-tts-matcha-model-config.cc | 37 ++ .../csrc/offline-tts-matcha-model-config.h | 16 + .../python/csrc/offline-tts-model-config.cc | 8 +- sherpa-onnx/python/sherpa_onnx/__init__.py | 1 + 33 files changed, 1397 insertions(+), 86 deletions(-) create mode 100644 sherpa-onnx/csrc/hifigan-vocoder.cc create mode 100644 sherpa-onnx/csrc/hifigan-vocoder.h create mode 100644 sherpa-onnx/csrc/offline-tts-matcha-impl.h create mode 100644 sherpa-onnx/csrc/offline-tts-matcha-model-config.cc create mode 100644 sherpa-onnx/csrc/offline-tts-matcha-model-config.h create mode 100644 sherpa-onnx/csrc/offline-tts-matcha-model-metadata.h create mode 100644 sherpa-onnx/csrc/offline-tts-matcha-model.cc create mode 100644 sherpa-onnx/csrc/offline-tts-matcha-model.h create mode 100644 sherpa-onnx/python/csrc/offline-tts-matcha-model-config.cc create mode 100644 sherpa-onnx/python/csrc/offline-tts-matcha-model-config.h diff --git a/.github/scripts/test-offline-tts.sh b/.github/scripts/test-offline-tts.sh index d3d35df2c..1aa0340a0 100755 --- a/.github/scripts/test-offline-tts.sh +++ b/.github/scripts/test-offline-tts.sh @@ -18,6 +18,40 @@ which $EXE # test waves are saved in ./tts mkdir ./tts +log "------------------------------------------------------------" +log "matcha-icefall-zh-baker" +log "------------------------------------------------------------" +curl -O -SL https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 +tar xvf matcha-icefall-zh-baker.tar.bz2 +rm matcha-icefall-zh-baker.tar.bz2 + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + +$EXE \ + --matcha-acoustic-model=./matcha-icefall-zh-baker/model-steps-3.onnx \ + --matcha-vocoder=./hifigan_v2.onnx \ + --matcha-lexicon=./matcha-icefall-zh-baker/lexicon.txt \ + --matcha-tokens=./matcha-icefall-zh-baker/tokens.txt \ + --matcha-dict-dir=./matcha-icefall-zh-baker/dict \ + --num-threads=2 \ + --debug=1 \ + --output-filename=./tts/matcha-baker-zh-1.wav \ + '小米的使命是,始终坚持做"感动人心、价格厚道"的好产品,让全球每个人都能享受科技带来的美好生活' + +$EXE \ + --matcha-acoustic-model=./matcha-icefall-zh-baker/model-steps-3.onnx \ + --matcha-vocoder=./hifigan_v2.onnx \ + --matcha-lexicon=./matcha-icefall-zh-baker/lexicon.txt \ + --matcha-tokens=./matcha-icefall-zh-baker/tokens.txt \ + --matcha-dict-dir=./matcha-icefall-zh-baker/dict \ + --num-threads=2 \ + --debug=1 \ + --output-filename=./tts/matcha-baker-zh-2.wav \ + "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" + +rm hifigan_v2.onnx +rm -rf matcha-icefall-zh-baker + log "------------------------------------------------------------" log "vits-piper-en_US-amy-low" log "------------------------------------------------------------" diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index f93908d45..8bfe2c16f 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -269,6 +269,26 @@ mkdir ./tts log "vits-ljs test" +curl -O -SL https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 +tar xvf matcha-icefall-zh-baker.tar.bz2 +rm matcha-icefall-zh-baker.tar.bz2 + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + +python3 ./python-api-examples/offline-tts.py \ + --matcha-acoustic-model=./matcha-icefall-zh-baker/model-steps-3.onnx \ + --matcha-vocoder=./hifigan_v2.onnx \ + --matcha-lexicon=./matcha-icefall-zh-baker/lexicon.txt \ + --matcha-tokens=./matcha-icefall-zh-baker/tokens.txt \ + --tts-rule-fsts=./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst \ + --matcha-dict-dir=./matcha-icefall-zh-baker/dict \ + --output-filename=./tts/test-matcha.wav \ + "某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。" + +rm -rf matcha-icefall-zh-baker +rm hifigan_v2.onnx + + curl -LS -O https://huggingface.co/csukuangfj/vits-ljs/resolve/main/vits-ljs.onnx curl -LS -O https://huggingface.co/csukuangfj/vits-ljs/resolve/main/lexicon.txt curl -LS -O https://huggingface.co/csukuangfj/vits-ljs/resolve/main/tokens.txt diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index 98c88e589..ea64662b5 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -149,6 +149,23 @@ jobs: name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} path: install/* + - name: Test offline TTS + if: matrix.with_tts == 'ON' + shell: bash + run: | + du -h -d1 . + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-tts + + .github/scripts/test-offline-tts.sh + du -h -d1 . + + - uses: actions/upload-artifact@v4 + if: matrix.with_tts == 'ON' + with: + name: tts-generated-test-files-${{ matrix.build_type }}-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} + path: tts + - name: Test offline Moonshine if: matrix.build_type != 'Debug' shell: bash @@ -309,16 +326,7 @@ jobs: .github/scripts/test-offline-whisper.sh du -h -d1 . - - name: Test offline TTS - if: matrix.with_tts == 'ON' - shell: bash - run: | - du -h -d1 . - export PATH=$PWD/build/bin:$PATH - export EXE=sherpa-onnx-offline-tts - .github/scripts/test-offline-tts.sh - du -h -d1 . - name: Test online paraformer shell: bash @@ -367,8 +375,4 @@ jobs: overwrite: true file: sherpa-onnx-*.tar.bz2 - - uses: actions/upload-artifact@v4 - with: - name: tts-generated-test-files-${{ matrix.build_type }}-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} - path: tts diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index fd26d5b9f..e6f627e15 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -121,6 +121,15 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test offline TTS + if: matrix.with_tts == 'ON' + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-tts + + .github/scripts/test-offline-tts.sh + - name: Test offline Moonshine if: matrix.build_type != 'Debug' shell: bash @@ -226,15 +235,6 @@ jobs: .github/scripts/test-kws.sh - - name: Test offline TTS - if: matrix.with_tts == 'ON' - shell: bash - run: | - export PATH=$PWD/build/bin:$PATH - export EXE=sherpa-onnx-offline-tts - - .github/scripts/test-offline-tts.sh - - name: Test online paraformer shell: bash run: | diff --git a/python-api-examples/offline-tts-play.py b/python-api-examples/offline-tts-play.py index 8457fc45c..e8350ea47 100755 --- a/python-api-examples/offline-tts-play.py +++ b/python-api-examples/offline-tts-play.py @@ -11,7 +11,7 @@ Usage: -Example (1/3) +Example (1/4) wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 tar xf vits-piper-en_US-amy-low.tar.bz2 @@ -23,7 +23,7 @@ --output-filename=./generated.wav \ "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar." -Example (2/3) +Example (2/4) wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-zh-aishell3.tar.bz2 tar xvf vits-zh-aishell3.tar.bz2 @@ -37,7 +37,7 @@ --output-filename=./liubei-21.wav \ "勿以恶小而为之,勿以善小而不为。惟贤惟德,能服于人。122334" -Example (3/3) +Example (3/4) wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/sherpa-onnx-vits-zh-ll.tar.bz2 tar xvf sherpa-onnx-vits-zh-ll.tar.bz2 @@ -53,6 +53,24 @@ --output-filename=./test-2.wav \ "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。2024年5月11号,拨打110或者18920240511。123456块钱。" +Example (4/4) + +curl -O -SL https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 +tar xvf matcha-icefall-zh-baker.tar.bz2 +rm matcha-icefall-zh-baker.tar.bz2 + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + +python3 ./python-api-examples/offline-tts-play.py \ + --matcha-acoustic-model=./matcha-icefall-zh-baker/model-steps-3.onnx \ + --matcha-vocoder=./hifigan_v2.onnx \ + --matcha-lexicon=./matcha-icefall-zh-baker/lexicon.txt \ + --matcha-tokens=./matcha-icefall-zh-baker/tokens.txt \ + --tts-rule-fsts=./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst \ + --matcha-dict-dir=./matcha-icefall-zh-baker/dict \ + --output-filename=./test-matcha.wav \ + "某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。" + You can find more models at https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models @@ -84,14 +102,11 @@ sys.exit(-1) -def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - +def add_vits_args(parser): parser.add_argument( "--vits-model", type=str, + default="", help="Path to vits model.onnx", ) @@ -124,6 +139,60 @@ def get_args(): help="Path to the dict directory for models using jieba", ) + +def add_matcha_args(parser): + parser.add_argument( + "--matcha-acoustic-model", + type=str, + default="", + help="Path to model.onnx for matcha", + ) + + parser.add_argument( + "--matcha-vocoder", + type=str, + default="", + help="Path to vocoder for matcha", + ) + + parser.add_argument( + "--matcha-lexicon", + type=str, + default="", + help="Path to lexicon.txt for matcha", + ) + + parser.add_argument( + "--matcha-tokens", + type=str, + default="", + help="Path to tokens.txt for matcha", + ) + + parser.add_argument( + "--matcha-data-dir", + type=str, + default="", + help="""Path to the dict directory of espeak-ng. If it is specified, + --matcha-lexicon and --matcha-tokens are ignored""", + ) + + parser.add_argument( + "--matcha-dict-dir", + type=str, + default="", + help="Path to the dict directory for models using jieba", + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + add_vits_args(parser) + add_matcha_args(parser) + parser.add_argument( "--tts-rule-fsts", type=str, @@ -313,6 +382,14 @@ def main(): dict_dir=args.vits_dict_dir, tokens=args.vits_tokens, ), + matcha=sherpa_onnx.OfflineTtsMatchaModelConfig( + acoustic_model=args.matcha_acoustic_model, + vocoder=args.matcha_vocoder, + lexicon=args.matcha_lexicon, + tokens=args.matcha_tokens, + data_dir=args.matcha_data_dir, + dict_dir=args.matcha_dict_dir, + ), provider=args.provider, debug=args.debug, num_threads=args.num_threads, diff --git a/python-api-examples/offline-tts.py b/python-api-examples/offline-tts.py index 18ea638e8..aa1cce935 100755 --- a/python-api-examples/offline-tts.py +++ b/python-api-examples/offline-tts.py @@ -12,7 +12,7 @@ Usage: -Example (1/3) +Example (1/4) wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 tar xf vits-piper-en_US-amy-low.tar.bz2 @@ -24,7 +24,7 @@ --output-filename=./generated.wav \ "Today as always, men fall into two groups: slaves and free men. Whoever does not have two-thirds of his day for himself, is a slave, whatever he may be: a statesman, a businessman, an official, or a scholar." -Example (2/3) +Example (2/4) wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-icefall-zh-aishell3.tar.bz2 tar xvf vits-icefall-zh-aishell3.tar.bz2 @@ -38,7 +38,7 @@ --output-filename=./liubei-21.wav \ "勿以恶小而为之,勿以善小而不为。惟贤惟德,能服于人。122334" -Example (3/3) +Example (3/4) wget https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/sherpa-onnx-vits-zh-ll.tar.bz2 tar xvf sherpa-onnx-vits-zh-ll.tar.bz2 @@ -54,6 +54,23 @@ --output-filename=./test-2.wav \ "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。2024年5月11号,拨打110或者18920240511。123456块钱。" +Example (4/4) + +curl -O -SL https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/matcha-icefall-zh-baker.tar.bz2 +tar xvf matcha-icefall-zh-baker.tar.bz2 +rm matcha-icefall-zh-baker.tar.bz2 + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/vocoder-models/hifigan_v2.onnx + +python3 ./python-api-examples/offline-tts.py \ + --matcha-acoustic-model=./matcha-icefall-zh-baker/model-steps-3.onnx \ + --matcha-vocoder=./hifigan_v2.onnx \ + --matcha-lexicon=./matcha-icefall-zh-baker/lexicon.txt \ + --matcha-tokens=./matcha-icefall-zh-baker/tokens.txt \ + --tts-rule-fsts=./matcha-icefall-zh-baker/phone.fst,./matcha-icefall-zh-baker/date.fst,./matcha-icefall-zh-baker/number.fst \ + --matcha-dict-dir=./matcha-icefall-zh-baker/dict \ + --output-filename=./test-matcha.wav \ + "某某银行的副行长和一些行政领导表示,他们去过长江和长白山; 经济不断增长。2024年12月31号,拨打110或者18920240511。123456块钱。" You can find more models at https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models @@ -71,14 +88,11 @@ import soundfile as sf -def get_args(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - +def add_vits_args(parser): parser.add_argument( "--vits-model", type=str, + default="", help="Path to vits model.onnx", ) @@ -111,6 +125,60 @@ def get_args(): help="Path to the dict directory for models using jieba", ) + +def add_matcha_args(parser): + parser.add_argument( + "--matcha-acoustic-model", + type=str, + default="", + help="Path to model.onnx for matcha", + ) + + parser.add_argument( + "--matcha-vocoder", + type=str, + default="", + help="Path to vocoder for matcha", + ) + + parser.add_argument( + "--matcha-lexicon", + type=str, + default="", + help="Path to lexicon.txt for matcha", + ) + + parser.add_argument( + "--matcha-tokens", + type=str, + default="", + help="Path to tokens.txt for matcha", + ) + + parser.add_argument( + "--matcha-data-dir", + type=str, + default="", + help="""Path to the dict directory of espeak-ng. If it is specified, + --matcha-lexicon and --matcha-tokens are ignored""", + ) + + parser.add_argument( + "--matcha-dict-dir", + type=str, + default="", + help="Path to the dict directory for models using jieba", + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + add_vits_args(parser) + add_matcha_args(parser) + parser.add_argument( "--tts-rule-fsts", type=str, @@ -196,6 +264,14 @@ def main(): dict_dir=args.vits_dict_dir, tokens=args.vits_tokens, ), + matcha=sherpa_onnx.OfflineTtsMatchaModelConfig( + acoustic_model=args.matcha_acoustic_model, + vocoder=args.matcha_vocoder, + lexicon=args.matcha_lexicon, + tokens=args.matcha_tokens, + data_dir=args.matcha_data_dir, + dict_dir=args.matcha_dict_dir, + ), provider=args.provider, debug=args.debug, num_threads=args.num_threads, diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 3850c8eb3..f146b09e2 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -151,12 +151,15 @@ list(APPEND sources if(SHERPA_ONNX_ENABLE_TTS) list(APPEND sources + hifigan-vocoder.cc jieba-lexicon.cc lexicon.cc melo-tts-lexicon.cc offline-tts-character-frontend.cc offline-tts-frontend.cc offline-tts-impl.cc + offline-tts-matcha-model-config.cc + offline-tts-matcha-model.cc offline-tts-model-config.cc offline-tts-vits-model-config.cc offline-tts-vits-model.cc diff --git a/sherpa-onnx/csrc/hifigan-vocoder.cc b/sherpa-onnx/csrc/hifigan-vocoder.cc new file mode 100644 index 000000000..b2ff20788 --- /dev/null +++ b/sherpa-onnx/csrc/hifigan-vocoder.cc @@ -0,0 +1,107 @@ +// sherpa-onnx/csrc/hifigan-vocoder.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/hifigan-vocoder.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" + +namespace sherpa_onnx { + +class HifiganVocoder::Impl { + public: + explicit Impl(int32_t num_threads, const std::string &provider, + const std::string &model) + : env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(num_threads, provider)), + allocator_{} { + auto buf = ReadFile(model); + Init(buf.data(), buf.size()); + } + + template + explicit Impl(Manager *mgr, int32_t num_threads, const std::string &provider, + const std::string &model) + : env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(num_threads, provider)), + allocator_{} { + auto buf = ReadFile(mgr, model); + Init(buf.data(), buf.size()); + } + + Ort::Value Run(Ort::Value mel) const { + auto out = sess_->Run({}, input_names_ptr_.data(), &mel, 1, + output_names_ptr_.data(), output_names_ptr_.size()); + + return std::move(out[0]); + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + } + + private: + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; +}; + +HifiganVocoder::HifiganVocoder(int32_t num_threads, const std::string &provider, + const std::string &model) + : impl_(std::make_unique(num_threads, provider, model)) {} + +template +HifiganVocoder::HifiganVocoder(Manager *mgr, int32_t num_threads, + const std::string &provider, + const std::string &model) + : impl_(std::make_unique(mgr, num_threads, provider, model)) {} + +HifiganVocoder::~HifiganVocoder() = default; + +Ort::Value HifiganVocoder::Run(Ort::Value mel) const { + return impl_->Run(std::move(mel)); +} + +#if __ANDROID_API__ >= 9 +template HifiganVocoder::HifiganVocoder(AAssetManager *mgr, int32_t num_threads, + const std::string &provider, + const std::string &model); +#endif + +#if __OHOS__ +template HifiganVocoder::HifiganVocoder(NativeResourceManager *mgr, + int32_t num_threads, + const std::string &provider, + const std::string &model); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/hifigan-vocoder.h b/sherpa-onnx/csrc/hifigan-vocoder.h new file mode 100644 index 000000000..3d10a2428 --- /dev/null +++ b/sherpa-onnx/csrc/hifigan-vocoder.h @@ -0,0 +1,38 @@ +// sherpa-onnx/csrc/hifigan-vocoder.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_HIFIGAN_VOCODER_H_ +#define SHERPA_ONNX_CSRC_HIFIGAN_VOCODER_H_ + +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +class HifiganVocoder { + public: + ~HifiganVocoder(); + + HifiganVocoder(int32_t num_threads, const std::string &provider, + const std::string &model); + + template + HifiganVocoder(Manager *mgr, int32_t num_threads, const std::string &provider, + const std::string &model); + + /** @param mel A float32 tensor of shape (batch_size, feat_dim, num_frames). + * @return Return a float32 tensor of shape (batch_size, num_samples). + */ + Ort::Value Run(Ort::Value mel) const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_HIFIGAN_VOCODER_H_ diff --git a/sherpa-onnx/csrc/jieba-lexicon.cc b/sherpa-onnx/csrc/jieba-lexicon.cc index a53f057f0..9ea11f46f 100644 --- a/sherpa-onnx/csrc/jieba-lexicon.cc +++ b/sherpa-onnx/csrc/jieba-lexicon.cc @@ -19,9 +19,8 @@ namespace sherpa_onnx { class JiebaLexicon::Impl { public: Impl(const std::string &lexicon, const std::string &tokens, - const std::string &dict_dir, - const OfflineTtsVitsModelMetaData &meta_data, bool debug) - : meta_data_(meta_data), debug_(debug) { + const std::string &dict_dir, bool debug) + : debug_(debug) { std::string dict = dict_dir + "/jieba.dict.utf8"; std::string hmm = dict_dir + "/hmm_model.utf8"; std::string user_dict = dict_dir + "/user.dict.utf8"; @@ -84,7 +83,6 @@ class JiebaLexicon::Impl { std::vector ans; std::vector this_sentence; - int32_t blank = token2id_.at(" "); for (const auto &w : words) { auto ids = ConvertWordToIds(w); if (ids.empty()) { @@ -93,7 +91,6 @@ class JiebaLexicon::Impl { } this_sentence.insert(this_sentence.end(), ids.begin(), ids.end()); - this_sentence.push_back(blank); if (w == "。" || w == "!" || w == "?" || w == ",") { ans.emplace_back(std::move(this_sentence)); @@ -135,7 +132,9 @@ class JiebaLexicon::Impl { token2id_ = ReadTokens(is); std::vector> puncts = { - {",", ","}, {".", "。"}, {"!", "!"}, {"?", "?"}}; + {",", ","}, {".", "。"}, {"!", "!"}, {"?", "?"}, {":", ":"}, + {"\"", "“"}, {"\"", "”"}, {"'", "‘"}, {"'", "’"}, {";", ";"}, + }; for (const auto &p : puncts) { if (token2id_.count(p.first) && !token2id_.count(p.second)) { @@ -150,6 +149,10 @@ class JiebaLexicon::Impl { if (!token2id_.count("、") && token2id_.count(",")) { token2id_["、"] = token2id_[","]; } + + if (!token2id_.count(";") && token2id_.count(",")) { + token2id_[";"] = token2id_[","]; + } } void InitLexicon(std::istream &is) { @@ -195,8 +198,6 @@ class JiebaLexicon::Impl { // tokens.txt is saved in token2id_ std::unordered_map token2id_; - OfflineTtsVitsModelMetaData meta_data_; - std::unique_ptr jieba_; bool debug_ = false; }; @@ -205,11 +206,8 @@ JiebaLexicon::~JiebaLexicon() = default; JiebaLexicon::JiebaLexicon(const std::string &lexicon, const std::string &tokens, - const std::string &dict_dir, - const OfflineTtsVitsModelMetaData &meta_data, - bool debug) - : impl_(std::make_unique(lexicon, tokens, dict_dir, meta_data, - debug)) {} + const std::string &dict_dir, bool debug) + : impl_(std::make_unique(lexicon, tokens, dict_dir, debug)) {} std::vector JiebaLexicon::ConvertTextToTokenIds( const std::string &text, const std::string & /*unused_voice = ""*/) const { diff --git a/sherpa-onnx/csrc/jieba-lexicon.h b/sherpa-onnx/csrc/jieba-lexicon.h index d02e0ee5d..9de104357 100644 --- a/sherpa-onnx/csrc/jieba-lexicon.h +++ b/sherpa-onnx/csrc/jieba-lexicon.h @@ -11,7 +11,6 @@ #include #include "sherpa-onnx/csrc/offline-tts-frontend.h" -#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h" namespace sherpa_onnx { @@ -19,8 +18,7 @@ class JiebaLexicon : public OfflineTtsFrontend { public: ~JiebaLexicon() override; JiebaLexicon(const std::string &lexicon, const std::string &tokens, - const std::string &dict_dir, - const OfflineTtsVitsModelMetaData &meta_data, bool debug); + const std::string &dict_dir, bool debug); std::vector ConvertTextToTokenIds( const std::string &text, diff --git a/sherpa-onnx/csrc/offline-tts-impl.cc b/sherpa-onnx/csrc/offline-tts-impl.cc index 62b6eebba..92ccb7fdb 100644 --- a/sherpa-onnx/csrc/offline-tts-impl.cc +++ b/sherpa-onnx/csrc/offline-tts-impl.cc @@ -5,6 +5,7 @@ #include "sherpa-onnx/csrc/offline-tts-impl.h" #include +#include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" @@ -15,21 +16,39 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/offline-tts-matcha-impl.h" #include "sherpa-onnx/csrc/offline-tts-vits-impl.h" namespace sherpa_onnx { +std::vector OfflineTtsImpl::AddBlank(const std::vector &x, + int32_t blank_id /*= 0*/) const { + // we assume the blank ID is 0 + std::vector buffer(x.size() * 2 + 1, blank_id); + int32_t i = 1; + for (auto k : x) { + buffer[i] = k; + i += 2; + } + return buffer; +} + std::unique_ptr OfflineTtsImpl::Create( const OfflineTtsConfig &config) { - // TODO(fangjun): Support other types - return std::make_unique(config); + if (!config.model.vits.model.empty()) { + return std::make_unique(config); + } + return std::make_unique(config); } template std::unique_ptr OfflineTtsImpl::Create( Manager *mgr, const OfflineTtsConfig &config) { - // TODO(fangjun): Support other types - return std::make_unique(mgr, config); + if (!config.model.vits.model.empty()) { + return std::make_unique(mgr, config); + } + + return std::make_unique(mgr, config); } #if __ANDROID_API__ >= 9 diff --git a/sherpa-onnx/csrc/offline-tts-impl.h b/sherpa-onnx/csrc/offline-tts-impl.h index db8b7162d..061acc747 100644 --- a/sherpa-onnx/csrc/offline-tts-impl.h +++ b/sherpa-onnx/csrc/offline-tts-impl.h @@ -7,6 +7,7 @@ #include #include +#include #include "sherpa-onnx/csrc/offline-tts.h" @@ -32,6 +33,9 @@ class OfflineTtsImpl { // Number of supported speakers. // If it supports only a single speaker, then it return 0 or 1. virtual int32_t NumSpeakers() const = 0; + + std::vector AddBlank(const std::vector &x, + int32_t blank_id = 0) const; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-matcha-impl.h b/sherpa-onnx/csrc/offline-tts-matcha-impl.h new file mode 100644 index 000000000..62c29bb83 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-matcha-impl.h @@ -0,0 +1,381 @@ +// sherpa-onnx/csrc/offline-tts-matcha-impl.h +// +// Copyright (c) 2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_IMPL_H_ + +#include +#include +#include +#include +#include + +#include "fst/extensions/far/far.h" +#include "kaldifst/csrc/kaldi-fst-io.h" +#include "kaldifst/csrc/text-normalizer.h" +#include "sherpa-onnx/csrc/hifigan-vocoder.h" +#include "sherpa-onnx/csrc/jieba-lexicon.h" +#include "sherpa-onnx/csrc/lexicon.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/melo-tts-lexicon.h" +#include "sherpa-onnx/csrc/offline-tts-character-frontend.h" +#include "sherpa-onnx/csrc/offline-tts-frontend.h" +#include "sherpa-onnx/csrc/offline-tts-impl.h" +#include "sherpa-onnx/csrc/offline-tts-matcha-model.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/piper-phonemize-lexicon.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OfflineTtsMatchaImpl : public OfflineTtsImpl { + public: + explicit OfflineTtsMatchaImpl(const OfflineTtsConfig &config) + : config_(config), + model_(std::make_unique(config.model)), + vocoder_(std::make_unique( + config.model.num_threads, config.model.provider, + config.model.matcha.vocoder)) { + InitFrontend(); + + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + tn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule fst: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); +#endif + } + tn_list_.push_back(std::make_unique(f)); + } + } + + if (!config.rule_fars.empty()) { + if (config.model.debug) { + SHERPA_ONNX_LOGE("Loading FST archives"); + } + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + + tn_list_.reserve(files.size() + tn_list_.size()); + + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule far: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); +#endif + } + std::unique_ptr> reader( + fst::FarReader::Open(f)); + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + tn_list_.push_back( + std::make_unique(std::move(r))); + } + } + + if (config.model.debug) { + SHERPA_ONNX_LOGE("FST archives loaded!"); + } + } + } + + template + OfflineTtsMatchaImpl(Manager *mgr, const OfflineTtsConfig &config) + : config_(config), + model_(std::make_unique(mgr, config.model)), + vocoder_(std::make_unique( + mgr, config.model.num_threads, config.model.provider, + config.model.matcha.vocoder)) { + InitFrontend(mgr); + + if (!config.rule_fsts.empty()) { + std::vector files; + SplitStringToVector(config.rule_fsts, ",", false, &files); + tn_list_.reserve(files.size()); + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule fst: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule fst: %s", f.c_str()); +#endif + } + auto buf = ReadFile(mgr, f); + std::istrstream is(buf.data(), buf.size()); + tn_list_.push_back(std::make_unique(is)); + } + } + + if (!config.rule_fars.empty()) { + std::vector files; + SplitStringToVector(config.rule_fars, ",", false, &files); + tn_list_.reserve(files.size() + tn_list_.size()); + + for (const auto &f : files) { + if (config.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("rule far: %{public}s", f.c_str()); +#else + SHERPA_ONNX_LOGE("rule far: %s", f.c_str()); +#endif + } + + auto buf = ReadFile(mgr, f); + + std::unique_ptr s( + new std::istrstream(buf.data(), buf.size())); + + std::unique_ptr> reader( + fst::FarReader::Open(std::move(s))); + + for (; !reader->Done(); reader->Next()) { + std::unique_ptr r( + fst::CastOrConvertToConstFst(reader->GetFst()->Copy())); + + tn_list_.push_back( + std::make_unique(std::move(r))); + } // for (; !reader->Done(); reader->Next()) + } // for (const auto &f : files) + } // if (!config.rule_fars.empty()) + } + + int32_t SampleRate() const override { + return model_->GetMetaData().sample_rate; + } + + int32_t NumSpeakers() const override { + return model_->GetMetaData().num_speakers; + } + + GeneratedAudio Generate( + const std::string &_text, int64_t sid = 0, float speed = 1.0, + GeneratedAudioCallback callback = nullptr) const override { + const auto &meta_data = model_->GetMetaData(); + int32_t num_speakers = meta_data.num_speakers; + + if (num_speakers == 0 && sid != 0) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "This is a single-speaker model and supports only sid 0. Given sid: " + "%{public}d. sid is ignored", + static_cast(sid)); +#else + SHERPA_ONNX_LOGE( + "This is a single-speaker model and supports only sid 0. Given sid: " + "%d. sid is ignored", + static_cast(sid)); +#endif + } + + if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "This model contains only %{public}d speakers. sid should be in the " + "range [%{public}d, %{public}d]. Given: %{public}d. Use sid=0", + num_speakers, 0, num_speakers - 1, static_cast(sid)); +#else + SHERPA_ONNX_LOGE( + "This model contains only %d speakers. sid should be in the range " + "[%d, %d]. Given: %d. Use sid=0", + num_speakers, 0, num_speakers - 1, static_cast(sid)); +#endif + sid = 0; + } + + std::string text = _text; + if (config_.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Raw text: %{public}s", text.c_str()); +#else + SHERPA_ONNX_LOGE("Raw text: %s", text.c_str()); +#endif + } + + if (!tn_list_.empty()) { + for (const auto &tn : tn_list_) { + text = tn->Normalize(text); + if (config_.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("After normalizing: %{public}s", text.c_str()); +#else + SHERPA_ONNX_LOGE("After normalizing: %s", text.c_str()); +#endif + } + } + } + + std::vector token_ids = + frontend_->ConvertTextToTokenIds(text, "en-US"); + + if (token_ids.empty() || + (token_ids.size() == 1 && token_ids[0].tokens.empty())) { +#if __OHOS__ + SHERPA_ONNX_LOGE("Failed to convert '%{public}s' to token IDs", + text.c_str()); +#else + SHERPA_ONNX_LOGE("Failed to convert '%s' to token IDs", text.c_str()); +#endif + return {}; + } + + std::vector> x; + + x.reserve(token_ids.size()); + + for (auto &i : token_ids) { + x.push_back(std::move(i.tokens)); + } + + for (auto &k : x) { + k = AddBlank(k, meta_data.pad_id); + } + + int32_t x_size = static_cast(x.size()); + + if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) { + auto ans = Process(x, sid, speed); + if (callback) { + callback(ans.samples.data(), ans.samples.size(), 1.0); + } + return ans; + } + + // the input text is too long, we process sentences within it in batches + // to avoid OOM. Batch size is config_.max_num_sentences + std::vector> batch_x; + + int32_t batch_size = config_.max_num_sentences; + batch_x.reserve(config_.max_num_sentences); + int32_t num_batches = x_size / batch_size; + + if (config_.model.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "Text is too long. Split it into %{public}d batches. batch size: " + "%{public}d. Number of sentences: %{public}d", + num_batches, batch_size, x_size); +#else + SHERPA_ONNX_LOGE( + "Text is too long. Split it into %d batches. batch size: %d. Number " + "of sentences: %d", + num_batches, batch_size, x_size); +#endif + } + + GeneratedAudio ans; + + int32_t should_continue = 1; + + int32_t k = 0; + + for (int32_t b = 0; b != num_batches && should_continue; ++b) { + batch_x.clear(); + for (int32_t i = 0; i != batch_size; ++i, ++k) { + batch_x.push_back(std::move(x[k])); + } + + auto audio = Process(batch_x, sid, speed); + ans.sample_rate = audio.sample_rate; + ans.samples.insert(ans.samples.end(), audio.samples.begin(), + audio.samples.end()); + if (callback) { + should_continue = callback(audio.samples.data(), audio.samples.size(), + (b + 1) * 1.0 / num_batches); + // Caution(fangjun): audio is freed when the callback returns, so users + // should copy the data if they want to access the data after + // the callback returns to avoid segmentation fault. + } + } + + batch_x.clear(); + while (k < static_cast(x.size()) && should_continue) { + batch_x.push_back(std::move(x[k])); + + ++k; + } + + if (!batch_x.empty()) { + auto audio = Process(batch_x, sid, speed); + ans.sample_rate = audio.sample_rate; + ans.samples.insert(ans.samples.end(), audio.samples.begin(), + audio.samples.end()); + if (callback) { + callback(audio.samples.data(), audio.samples.size(), 1.0); + // Caution(fangjun): audio is freed when the callback returns, so users + // should copy the data if they want to access the data after + // the callback returns to avoid segmentation fault. + } + } + + return ans; + } + + private: + template + void InitFrontend(Manager *mgr) {} + + void InitFrontend() { + frontend_ = std::make_unique( + config_.model.matcha.lexicon, config_.model.matcha.tokens, + config_.model.matcha.dict_dir, config_.model.debug); + } + + GeneratedAudio Process(const std::vector> &tokens, + int32_t sid, float speed) const { + int32_t num_tokens = 0; + for (const auto &k : tokens) { + num_tokens += k.size(); + } + + std::vector x; + x.reserve(num_tokens); + for (const auto &k : tokens) { + x.insert(x.end(), k.begin(), k.end()); + } + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array x_shape = {1, static_cast(x.size())}; + Ort::Value x_tensor = Ort::Value::CreateTensor( + memory_info, x.data(), x.size(), x_shape.data(), x_shape.size()); + + Ort::Value mel = model_->Run(std::move(x_tensor), sid, speed); + Ort::Value audio = vocoder_->Run(std::move(mel)); + + std::vector audio_shape = + audio.GetTensorTypeAndShapeInfo().GetShape(); + + int64_t total = 1; + // The output shape may be (1, 1, total) or (1, total) or (total,) + for (auto i : audio_shape) { + total *= i; + } + + const float *p = audio.GetTensorData(); + + GeneratedAudio ans; + ans.sample_rate = model_->GetMetaData().sample_rate; + ans.samples = std::vector(p, p + total); + return ans; + } + + private: + OfflineTtsConfig config_; + std::unique_ptr model_; + std::unique_ptr vocoder_; + std::vector> tn_list_; + std::unique_ptr frontend_; +}; + +} // namespace sherpa_onnx +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-tts-matcha-model-config.cc b/sherpa-onnx/csrc/offline-tts-matcha-model-config.cc new file mode 100644 index 000000000..5c736b54d --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-matcha-model-config.cc @@ -0,0 +1,143 @@ +// sherpa-onnx/csrc/offline-tts-matcha-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-tts-matcha-model-config.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineTtsMatchaModelConfig::Register(ParseOptions *po) { + po->Register("matcha-acoustic-model", &acoustic_model, + "Path to matcha acoustic model"); + po->Register("matcha-vocoder", &vocoder, "Path to matcha vocoder"); + po->Register("matcha-lexicon", &lexicon, + "Path to lexicon.txt for Matcha models"); + po->Register("matcha-tokens", &tokens, + "Path to tokens.txt for Matcha models"); + po->Register("matcha-data-dir", &data_dir, + "Path to the directory containing dict for espeak-ng. If it is " + "given, --matcha-lexicon is ignored."); + po->Register("matcha-dict-dir", &dict_dir, + "Path to the directory containing dict for jieba. Used only for " + "Chinese TTS models using jieba"); + po->Register("matcha-noise-scale", &noise_scale, + "noise_scale for Matcha models"); + po->Register("matcha-length-scale", &length_scale, + "Speech speed. Larger->Slower; Smaller->faster."); +} + +bool OfflineTtsMatchaModelConfig::Validate() const { + if (acoustic_model.empty()) { + SHERPA_ONNX_LOGE("Please provide --matcha-acoustic-model"); + return false; + } + + if (!FileExists(acoustic_model)) { + SHERPA_ONNX_LOGE("--matcha-acoustic-model: '%s' does not exist", + acoustic_model.c_str()); + return false; + } + + if (vocoder.empty()) { + SHERPA_ONNX_LOGE("Please provide --matcha-vocoder"); + return false; + } + + if (!FileExists(vocoder)) { + SHERPA_ONNX_LOGE("--matcha-vocoder: '%s' does not exist", vocoder.c_str()); + return false; + } + + if (tokens.empty()) { + SHERPA_ONNX_LOGE("Please provide --matcha-tokens"); + return false; + } + + if (!FileExists(tokens)) { + SHERPA_ONNX_LOGE("--matcha-tokens: '%s' does not exist", tokens.c_str()); + return false; + } + + if (!data_dir.empty()) { + if (!FileExists(data_dir + "/phontab")) { + SHERPA_ONNX_LOGE( + "'%s/phontab' does not exist. Please check --matcha-data-dir", + data_dir.c_str()); + return false; + } + + if (!FileExists(data_dir + "/phonindex")) { + SHERPA_ONNX_LOGE( + "'%s/phonindex' does not exist. Please check --matcha-data-dir", + data_dir.c_str()); + return false; + } + + if (!FileExists(data_dir + "/phondata")) { + SHERPA_ONNX_LOGE( + "'%s/phondata' does not exist. Please check --matcha-data-dir", + data_dir.c_str()); + return false; + } + + if (!FileExists(data_dir + "/intonations")) { + SHERPA_ONNX_LOGE( + "'%s/intonations' does not exist. Please check --matcha-data-dir", + data_dir.c_str()); + return false; + } + } + + if (!dict_dir.empty()) { + std::vector required_files = { + "jieba.dict.utf8", "hmm_model.utf8", "user.dict.utf8", + "idf.utf8", "stop_words.utf8", + }; + + for (const auto &f : required_files) { + if (!FileExists(dict_dir + "/" + f)) { + SHERPA_ONNX_LOGE( + "'%s/%s' does not exist. Please check --matcha-dict-dir", + dict_dir.c_str(), f.c_str()); + return false; + } + } + + // we require that --matcha-lexicon is not empty + if (lexicon.empty()) { + SHERPA_ONNX_LOGE("Please provide --matcha-lexicon"); + return false; + } + + if (!FileExists(lexicon)) { + SHERPA_ONNX_LOGE("--matcha-lexicon: '%s' does not exist", + lexicon.c_str()); + return false; + } + } + + return true; +} + +std::string OfflineTtsMatchaModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineTtsMatchaModelConfig("; + os << "acoustic_model=\"" << acoustic_model << "\", "; + os << "vocoder=\"" << vocoder << "\", "; + os << "lexicon=\"" << lexicon << "\", "; + os << "tokens=\"" << tokens << "\", "; + os << "data_dir=\"" << data_dir << "\", "; + os << "dict_dir=\"" << dict_dir << "\", "; + os << "noise_scale=" << noise_scale << ", "; + os << "length_scale=" << length_scale << ")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-matcha-model-config.h b/sherpa-onnx/csrc/offline-tts-matcha-model-config.h new file mode 100644 index 000000000..f367a7e05 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-matcha-model-config.h @@ -0,0 +1,56 @@ +// sherpa-onnx/csrc/offline-tts-matcha-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineTtsMatchaModelConfig { + std::string acoustic_model; + std::string vocoder; + std::string lexicon; + std::string tokens; + + // If data_dir is given, lexicon is ignored + // data_dir is for piper-phonemizer, which uses espeak-ng + std::string data_dir; + + // Used for Chinese TTS models using jieba + std::string dict_dir; + + float noise_scale = 1; + float length_scale = 1; + + OfflineTtsMatchaModelConfig() = default; + + OfflineTtsMatchaModelConfig(const std::string &acoustic_model, + const std::string &vocoder, + const std::string &lexicon, + const std::string &tokens, + const std::string &data_dir, + const std::string &dict_dir, + float noise_scale = 1.0, float length_scale = 1) + : acoustic_model(acoustic_model), + vocoder(vocoder), + lexicon(lexicon), + tokens(tokens), + data_dir(data_dir), + dict_dir(dict_dir), + noise_scale(noise_scale), + length_scale(length_scale) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-tts-matcha-model-metadata.h b/sherpa-onnx/csrc/offline-tts-matcha-model-metadata.h new file mode 100644 index 000000000..3147985dd --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-matcha-model-metadata.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/offline-tts-matcha-model-metadata.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_METADATA_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_METADATA_H_ + +#include +#include + +namespace sherpa_onnx { + +// If you are not sure what each field means, please +// have a look of the Python file in the model directory that +// you have downloaded. +struct OfflineTtsMatchaModelMetaData { + int32_t sample_rate = 0; + int32_t num_speakers = 0; + int32_t version = 1; + int32_t jieba = 0; + int32_t espeak = 0; + int32_t use_eos_bos = 0; + int32_t pad_id = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_METADATA_H_ diff --git a/sherpa-onnx/csrc/offline-tts-matcha-model.cc b/sherpa-onnx/csrc/offline-tts-matcha-model.cc new file mode 100644 index 000000000..066dbd21a --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-matcha-model.cc @@ -0,0 +1,198 @@ +// sherpa-onnx/csrc/offline-tts-matcha-model.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-tts-matcha-model.h" + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" + +namespace sherpa_onnx { + +class OfflineTtsMatchaModel::Impl { + public: + explicit Impl(const OfflineTtsModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config.matcha.acoustic_model); + Init(buf.data(), buf.size()); + } + + template + Impl(Manager *mgr, const OfflineTtsModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config.matcha.acoustic_model); + Init(buf.data(), buf.size()); + } + + const OfflineTtsMatchaModelMetaData &GetMetaData() const { + return meta_data_; + } + + Ort::Value Run(Ort::Value x, int64_t sid, float speed) { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::vector x_shape = x.GetTensorTypeAndShapeInfo().GetShape(); + if (x_shape[0] != 1) { + SHERPA_ONNX_LOGE("Support only batch_size == 1. Given: %d", + static_cast(x_shape[0])); + exit(-1); + } + + int64_t len = x_shape[1]; + int64_t len_shape = 1; + + Ort::Value x_length = + Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1); + + int64_t scale_shape = 1; + float noise_scale = config_.matcha.noise_scale; + float length_scale = config_.matcha.length_scale; + + if (speed != 1 && speed > 0) { + length_scale = 1. / speed; + } + + Ort::Value noise_scale_tensor = + Ort::Value::CreateTensor(memory_info, &noise_scale, 1, &scale_shape, 1); + + Ort::Value length_scale_tensor = Ort::Value::CreateTensor( + memory_info, &length_scale, 1, &scale_shape, 1); + + Ort::Value sid_tensor = + Ort::Value::CreateTensor(memory_info, &sid, 1, &scale_shape, 1); + + std::vector inputs; + inputs.reserve(5); + inputs.push_back(std::move(x)); + inputs.push_back(std::move(x_length)); + inputs.push_back(std::move(noise_scale_tensor)); + inputs.push_back(std::move(length_scale_tensor)); + + if (input_names_.size() == 5 && input_names_.back() == "sid") { + inputs.push_back(std::move(sid_tensor)); + } + + auto out = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + + return std::move(out[0]); + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---matcha model---\n"; + PrintModelMetadata(os, meta_data); + + os << "----------input names----------\n"; + int32_t i = 0; + for (const auto &s : input_names_) { + os << i << " " << s << "\n"; + ++i; + } + os << "----------output names----------\n"; + i = 0; + for (const auto &s : output_names_) { + os << i << " " << s << "\n"; + ++i; + } + +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.version, "version", 1); + SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "n_speakers"); + SHERPA_ONNX_READ_META_DATA(meta_data_.jieba, "jieba"); + SHERPA_ONNX_READ_META_DATA(meta_data_.espeak, "has_espeak"); + SHERPA_ONNX_READ_META_DATA(meta_data_.use_eos_bos, "use_eos_bos"); + SHERPA_ONNX_READ_META_DATA(meta_data_.pad_id, "pad_id"); + } + + private: + OfflineTtsModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + OfflineTtsMatchaModelMetaData meta_data_; +}; + +OfflineTtsMatchaModel::OfflineTtsMatchaModel( + const OfflineTtsModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineTtsMatchaModel::OfflineTtsMatchaModel( + Manager *mgr, const OfflineTtsModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +OfflineTtsMatchaModel::~OfflineTtsMatchaModel() = default; + +const OfflineTtsMatchaModelMetaData &OfflineTtsMatchaModel::GetMetaData() + const { + return impl_->GetMetaData(); +} + +Ort::Value OfflineTtsMatchaModel::Run(Ort::Value x, int64_t sid /*= 0*/, + float speed /*= 1.0*/) const { + return impl_->Run(std::move(x), sid, speed); +} + +#if __ANDROID_API__ >= 9 +template OfflineTtsMatchaModel::OfflineTtsMatchaModel( + AAssetManager *mgr, const OfflineTtsModelConfig &config); +#endif + +#if __OHOS__ +template OfflineTtsMatchaModel::OfflineTtsMatchaModel( + NativeResourceManager *mgr, const OfflineTtsModelConfig &config); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-tts-matcha-model.h b/sherpa-onnx/csrc/offline-tts-matcha-model.h new file mode 100644 index 000000000..5b02ec9b3 --- /dev/null +++ b/sherpa-onnx/csrc/offline-tts-matcha-model.h @@ -0,0 +1,39 @@ +// sherpa-onnx/csrc/offline-tts-matcha-model.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_H_ + +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-tts-matcha-model-metadata.h" +#include "sherpa-onnx/csrc/offline-tts-model-config.h" + +namespace sherpa_onnx { + +class OfflineTtsMatchaModel { + public: + ~OfflineTtsMatchaModel(); + + explicit OfflineTtsMatchaModel(const OfflineTtsModelConfig &config); + + template + OfflineTtsMatchaModel(Manager *mgr, const OfflineTtsModelConfig &config); + + // Return a float32 tensor containing the mel + // of shape (batch_size, mel_dim, num_frames) + Ort::Value Run(Ort::Value x, int64_t sid = 0, float speed = 1.0) const; + + const OfflineTtsMatchaModelMetaData &GetMetaData() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_TTS_MATCHA_MODEL_H_ diff --git a/sherpa-onnx/csrc/offline-tts-model-config.cc b/sherpa-onnx/csrc/offline-tts-model-config.cc index f38c681a0..4af179a4b 100644 --- a/sherpa-onnx/csrc/offline-tts-model-config.cc +++ b/sherpa-onnx/csrc/offline-tts-model-config.cc @@ -10,6 +10,7 @@ namespace sherpa_onnx { void OfflineTtsModelConfig::Register(ParseOptions *po) { vits.Register(po); + matcha.Register(po); po->Register("num-threads", &num_threads, "Number of threads to run the neural network"); @@ -27,7 +28,11 @@ bool OfflineTtsModelConfig::Validate() const { return false; } - return vits.Validate(); + if (!vits.model.empty()) { + return vits.Validate(); + } + + return matcha.Validate(); } std::string OfflineTtsModelConfig::ToString() const { @@ -35,6 +40,7 @@ std::string OfflineTtsModelConfig::ToString() const { os << "OfflineTtsModelConfig("; os << "vits=" << vits.ToString() << ", "; + os << "matcha=" << matcha.ToString() << ", "; os << "num_threads=" << num_threads << ", "; os << "debug=" << (debug ? "True" : "False") << ", "; os << "provider=\"" << provider << "\")"; diff --git a/sherpa-onnx/csrc/offline-tts-model-config.h b/sherpa-onnx/csrc/offline-tts-model-config.h index bee50ba12..232686960 100644 --- a/sherpa-onnx/csrc/offline-tts-model-config.h +++ b/sherpa-onnx/csrc/offline-tts-model-config.h @@ -7,6 +7,7 @@ #include +#include "sherpa-onnx/csrc/offline-tts-matcha-model-config.h" #include "sherpa-onnx/csrc/offline-tts-vits-model-config.h" #include "sherpa-onnx/csrc/parse-options.h" @@ -14,6 +15,7 @@ namespace sherpa_onnx { struct OfflineTtsModelConfig { OfflineTtsVitsModelConfig vits; + OfflineTtsMatchaModelConfig matcha; int32_t num_threads = 1; bool debug = false; @@ -22,9 +24,11 @@ struct OfflineTtsModelConfig { OfflineTtsModelConfig() = default; OfflineTtsModelConfig(const OfflineTtsVitsModelConfig &vits, + const OfflineTtsMatchaModelConfig &matcha, int32_t num_threads, bool debug, const std::string &provider) : vits(vits), + matcha(matcha), num_threads(num_threads), debug(debug), provider(provider) {} diff --git a/sherpa-onnx/csrc/offline-tts-vits-impl.h b/sherpa-onnx/csrc/offline-tts-vits-impl.h index 560576357..1cc8d5f95 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-impl.h +++ b/sherpa-onnx/csrc/offline-tts-vits-impl.h @@ -156,17 +156,31 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { int32_t num_speakers = meta_data.num_speakers; if (num_speakers == 0 && sid != 0) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "This is a single-speaker model and supports only sid 0. Given sid: " + "%{public}d. sid is ignored", + static_cast(sid)); +#else SHERPA_ONNX_LOGE( "This is a single-speaker model and supports only sid 0. Given sid: " "%d. sid is ignored", static_cast(sid)); +#endif } if (num_speakers != 0 && (sid >= num_speakers || sid < 0)) { +#if __OHOS__ + SHERPA_ONNX_LOGE( + "This model contains only %{public}d speakers. sid should be in the " + "range [%{public}d, %{public}d]. Given: %{public}d. Use sid=0", + num_speakers, 0, num_speakers - 1, static_cast(sid)); +#else SHERPA_ONNX_LOGE( "This model contains only %d speakers. sid should be in the range " "[%d, %d]. Given: %d. Use sid=0", num_speakers, 0, num_speakers - 1, static_cast(sid)); +#endif sid = 0; } @@ -389,8 +403,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { } else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) { frontend_ = std::make_unique( config_.model.vits.lexicon, config_.model.vits.tokens, - config_.model.vits.dict_dir, model_->GetMetaData(), - config_.model.debug); + config_.model.vits.dict_dir, config_.model.debug); } else if ((meta_data.is_piper || meta_data.is_coqui || meta_data.is_icefall) && !config_.model.vits.data_dir.empty()) { @@ -410,17 +423,6 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl { } } - std::vector AddBlank(const std::vector &x) const { - // we assume the blank ID is 0 - std::vector buffer(x.size() * 2 + 1); - int32_t i = 1; - for (auto k : x) { - buffer[i] = k; - i += 2; - } - return buffer; - } - GeneratedAudio Process(const std::vector> &tokens, const std::vector> &tones, int32_t sid, float speed) const { diff --git a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc index 9eb5b64a3..17c63460b 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model-config.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model-config.cc @@ -51,25 +51,30 @@ bool OfflineTtsVitsModelConfig::Validate() const { if (!data_dir.empty()) { if (!FileExists(data_dir + "/phontab")) { - SHERPA_ONNX_LOGE("'%s/phontab' does not exist. Skipping test", - data_dir.c_str()); + SHERPA_ONNX_LOGE( + "'%s/phontab' does not exist. Please check --vits-data-dir", + data_dir.c_str()); return false; } if (!FileExists(data_dir + "/phonindex")) { - SHERPA_ONNX_LOGE("'%s/phonindex' does not exist. Skipping test", - data_dir.c_str()); + SHERPA_ONNX_LOGE( + "'%s/phonindex' does not exist. Please check --vits-data-dir", + data_dir.c_str()); return false; } if (!FileExists(data_dir + "/phondata")) { - SHERPA_ONNX_LOGE("'%s/phondata' does not exist. Skipping test", - data_dir.c_str()); + SHERPA_ONNX_LOGE( + "'%s/phondata' does not exist. Please check --vits-data-dir", + data_dir.c_str()); return false; } if (!FileExists(data_dir + "/intonations")) { - SHERPA_ONNX_LOGE("'%s/intonations' does not exist.", data_dir.c_str()); + SHERPA_ONNX_LOGE( + "'%s/intonations' does not exist. Please check --vits-data-dir", + data_dir.c_str()); return false; } } @@ -82,8 +87,8 @@ bool OfflineTtsVitsModelConfig::Validate() const { for (const auto &f : required_files) { if (!FileExists(dict_dir + "/" + f)) { - SHERPA_ONNX_LOGE("'%s/%s' does not exist.", dict_dir.c_str(), - f.c_str()); + SHERPA_ONNX_LOGE("'%s/%s' does not exist. Please check vits-dict-dir", + dict_dir.c_str(), f.c_str()); return false; } } diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.cc b/sherpa-onnx/csrc/offline-tts-vits-model.cc index eb605a7bd..3587a109d 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model.cc @@ -174,7 +174,7 @@ class OfflineTtsVitsModel::Impl { SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.bos_id, "bos_id", 0); SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.eos_id, "eos_id", 0); SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.use_eos_bos, - "use_eos_bos", 0); + "use_eos_bos", 1); SHERPA_ONNX_READ_META_DATA_WITH_DEFAULT(meta_data_.pad_id, "pad_id", 0); std::string comment; @@ -362,7 +362,7 @@ Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, int64_t sid /*=0*/, Ort::Value OfflineTtsVitsModel::Run(Ort::Value x, Ort::Value tones, int64_t sid /*= 0*/, - float speed /*= 1.0*/) { + float speed /*= 1.0*/) const { return impl_->Run(std::move(x), std::move(tones), sid, speed); } diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.h b/sherpa-onnx/csrc/offline-tts-vits-model.h index a880934ef..30e4205dc 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.h +++ b/sherpa-onnx/csrc/offline-tts-vits-model.h @@ -37,7 +37,7 @@ class OfflineTtsVitsModel { // This is for MeloTTS Ort::Value Run(Ort::Value x, Ort::Value tones, int64_t sid = 0, - float speed = 1.0); + float speed = 1.0) const; const OfflineTtsVitsModelMetaData &GetMetaData() const; diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index cd9eb516f..a33594f0b 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -273,4 +273,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) { return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); } +Ort::SessionOptions GetSessionOptions(int32_t num_threads, + const std::string &provider_str) { + return GetSessionOptionsImpl(num_threads, provider_str); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index e19db6c20..131023e88 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -26,6 +26,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config); Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config, const std::string &model_type); +Ort::SessionOptions GetSessionOptions(int32_t num_threads, + const std::string &provider_str); + template Ort::SessionOptions GetSessionOptions(const T &config) { return GetSessionOptionsImpl(config.num_threads, config.provider); diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc index 1ab8b68de..92feb8eac 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-tts.cc @@ -72,6 +72,10 @@ or details. exit(EXIT_FAILURE); } + if (config.model.debug) { + fprintf(stderr, "%s\n", config.model.ToString().c_str()); + } + if (!config.Validate()) { fprintf(stderr, "Errors in config!\n"); exit(EXIT_FAILURE); diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 21f77f29d..38d32de50 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -54,6 +54,7 @@ endif() if(SHERPA_ONNX_ENABLE_TTS) list(APPEND srcs + offline-tts-matcha-model-config.cc offline-tts-model-config.cc offline-tts-vits-model-config.cc offline-tts.cc diff --git a/sherpa-onnx/python/csrc/offline-tts-matcha-model-config.cc b/sherpa-onnx/python/csrc/offline-tts-matcha-model-config.cc new file mode 100644 index 000000000..2c932174e --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-tts-matcha-model-config.cc @@ -0,0 +1,37 @@ +// sherpa-onnx/python/csrc/offline-tts-matcha-model-config.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/offline-tts-matcha-model-config.h" + +#include + +#include "sherpa-onnx/csrc/offline-tts-matcha-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineTtsMatchaModelConfig(py::module *m) { + using PyClass = OfflineTtsMatchaModelConfig; + + py::class_(*m, "OfflineTtsMatchaModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("acoustic_model"), py::arg("vocoder"), py::arg("lexicon"), + py::arg("tokens"), py::arg("data_dir") = "", + py::arg("dict_dir") = "", py::arg("noise_scale") = 1.0, + py::arg("length_scale") = 1.0) + .def_readwrite("acoustic_model", &PyClass::acoustic_model) + .def_readwrite("vocoder", &PyClass::vocoder) + .def_readwrite("lexicon", &PyClass::lexicon) + .def_readwrite("tokens", &PyClass::tokens) + .def_readwrite("data_dir", &PyClass::data_dir) + .def_readwrite("dict_dir", &PyClass::dict_dir) + .def_readwrite("noise_scale", &PyClass::noise_scale) + .def_readwrite("length_scale", &PyClass::length_scale) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-tts-matcha-model-config.h b/sherpa-onnx/python/csrc/offline-tts-matcha-model-config.h new file mode 100644 index 000000000..09b0c5c98 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-tts-matcha-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-tts-matcha-model-config.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MATCHA_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MATCHA_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineTtsMatchaModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_TTS_MATCHA_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/offline-tts-model-config.cc b/sherpa-onnx/python/csrc/offline-tts-model-config.cc index e5e86d968..fd19e4f85 100644 --- a/sherpa-onnx/python/csrc/offline-tts-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-tts-model-config.cc @@ -7,22 +7,26 @@ #include #include "sherpa-onnx/csrc/offline-tts-model-config.h" +#include "sherpa-onnx/python/csrc/offline-tts-matcha-model-config.h" #include "sherpa-onnx/python/csrc/offline-tts-vits-model-config.h" namespace sherpa_onnx { void PybindOfflineTtsModelConfig(py::module *m) { PybindOfflineTtsVitsModelConfig(m); + PybindOfflineTtsMatchaModelConfig(m); using PyClass = OfflineTtsModelConfig; py::class_(*m, "OfflineTtsModelConfig") .def(py::init<>()) - .def(py::init(), - py::arg("vits"), py::arg("num_threads") = 1, + py::arg("vits"), py::arg("matcha"), py::arg("num_threads") = 1, py::arg("debug") = false, py::arg("provider") = "cpu") .def_readwrite("vits", &PyClass::vits) + .def_readwrite("matcha", &PyClass::matcha) .def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("debug", &PyClass::debug) .def_readwrite("provider", &PyClass::provider) diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index 2d5e456dc..330c8d2df 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -20,6 +20,7 @@ OfflineStream, OfflineTts, OfflineTtsConfig, + OfflineTtsMatchaModelConfig, OfflineTtsModelConfig, OfflineTtsVitsModelConfig, OfflineZipformerAudioTaggingModelConfig,