Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add c++ runtime for Matcha-TTS #1627

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,15 @@ list(APPEND sources

if(SHERPA_ONNX_ENABLE_TTS)
list(APPEND sources
hifigan-vocoder.cc
jieba-lexicon.cc
lexicon.cc
melo-tts-lexicon.cc
offline-tts-character-frontend.cc
offline-tts-frontend.cc
offline-tts-impl.cc
offline-tts-matcha-model-config.cc
offline-tts-matcha-model.cc
offline-tts-model-config.cc
offline-tts-vits-model-config.cc
offline-tts-vits-model.cc
Expand Down
107 changes: 107 additions & 0 deletions sherpa-onnx/csrc/hifigan-vocoder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// sherpa-onnx/csrc/hifigan-vocoder.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/hifigan-vocoder.h"

#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"

namespace sherpa_onnx {

class HifiganVocoder::Impl {
public:
explicit Impl(int32_t num_threads, const std::string &provider,
const std::string &model)
: env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(num_threads, provider)),
allocator_{} {
auto buf = ReadFile(model);
Init(buf.data(), buf.size());
}

template <typename Manager>
explicit Impl(Manager *mgr, int32_t num_threads, const std::string &provider,
const std::string &model)
: env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(num_threads, provider)),
allocator_{} {
auto buf = ReadFile(mgr, model);
Init(buf.data(), buf.size());
}

Ort::Value Run(Ort::Value mel) const {
auto out = sess_->Run({}, input_names_ptr_.data(), &mel, 1,
output_names_ptr_.data(), output_names_ptr_.size());

return std::move(out[0]);
}

private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);

GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
}

private:
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;

std::unique_ptr<Ort::Session> sess_;

std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;

std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
};

HifiganVocoder::HifiganVocoder(int32_t num_threads, const std::string &provider,
const std::string &model)
: impl_(std::make_unique<Impl>(num_threads, provider, model)) {}

template <typename Manager>
HifiganVocoder::HifiganVocoder(Manager *mgr, int32_t num_threads,
const std::string &provider,
const std::string &model)
: impl_(std::make_unique<Impl>(mgr, num_threads, provider, model)) {}

HifiganVocoder::~HifiganVocoder() = default;

Ort::Value HifiganVocoder::Run(Ort::Value mel) const {
return impl_->Run(std::move(mel));
}

#if __ANDROID_API__ >= 9
template HifiganVocoder::HifiganVocoder(AAssetManager *mgr, int32_t num_threads,
const std::string &provider,
const std::string &model);
#endif

#if __OHOS__
template HifiganVocoder::HifiganVocoder(NativeResourceManager *mgr,
int32_t num_threads,
const std::string &provider,
const std::string &model);
#endif

} // namespace sherpa_onnx
38 changes: 38 additions & 0 deletions sherpa-onnx/csrc/hifigan-vocoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// sherpa-onnx/csrc/hifigan-vocoder.h
//
// Copyright (c) 2024 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_HIFIGAN_VOCODER_H_
#define SHERPA_ONNX_CSRC_HIFIGAN_VOCODER_H_

#include <memory>
#include <string>

#include "onnxruntime_cxx_api.h" // NOLINT

namespace sherpa_onnx {

class HifiganVocoder {
public:
~HifiganVocoder();

HifiganVocoder(int32_t num_threads, const std::string &provider,
const std::string &model);

template <typename Manager>
HifiganVocoder(Manager *mgr, int32_t num_threads, const std::string &provider,
const std::string &model);

/** @param mel A float32 tensor of shape (batch_size, feat_dim, num_frames).
* @return Return a float32 tensor of shape (batch_size, num_samples).
*/
Ort::Value Run(Ort::Value mel) const;

private:
class Impl;
std::unique_ptr<Impl> impl_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_HIFIGAN_VOCODER_H_
16 changes: 5 additions & 11 deletions sherpa-onnx/csrc/jieba-lexicon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ namespace sherpa_onnx {
class JiebaLexicon::Impl {
public:
Impl(const std::string &lexicon, const std::string &tokens,
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data, bool debug)
: meta_data_(meta_data), debug_(debug) {
const std::string &dict_dir, bool debug)
: debug_(debug) {
std::string dict = dict_dir + "/jieba.dict.utf8";
std::string hmm = dict_dir + "/hmm_model.utf8";
std::string user_dict = dict_dir + "/user.dict.utf8";
Expand Down Expand Up @@ -93,7 +92,7 @@ class JiebaLexicon::Impl {
}

this_sentence.insert(this_sentence.end(), ids.begin(), ids.end());
this_sentence.push_back(blank);
// this_sentence.push_back(blank);

if (w == "。" || w == "!" || w == "?" || w == ",") {
ans.emplace_back(std::move(this_sentence));
Expand Down Expand Up @@ -195,8 +194,6 @@ class JiebaLexicon::Impl {
// tokens.txt is saved in token2id_
std::unordered_map<std::string, int32_t> token2id_;

OfflineTtsVitsModelMetaData meta_data_;

std::unique_ptr<cppjieba::Jieba> jieba_;
bool debug_ = false;
};
Expand All @@ -205,11 +202,8 @@ JiebaLexicon::~JiebaLexicon() = default;

JiebaLexicon::JiebaLexicon(const std::string &lexicon,
const std::string &tokens,
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data,
bool debug)
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
debug)) {}
const std::string &dict_dir, bool debug)
: impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, debug)) {}

std::vector<TokenIDs> JiebaLexicon::ConvertTextToTokenIds(
const std::string &text, const std::string & /*unused_voice = ""*/) const {
Expand Down
4 changes: 1 addition & 3 deletions sherpa-onnx/csrc/jieba-lexicon.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,14 @@
#include <vector>

#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model-metadata.h"

namespace sherpa_onnx {

class JiebaLexicon : public OfflineTtsFrontend {
public:
~JiebaLexicon() override;
JiebaLexicon(const std::string &lexicon, const std::string &tokens,
const std::string &dict_dir,
const OfflineTtsVitsModelMetaData &meta_data, bool debug);
const std::string &dict_dir, bool debug);

std::vector<TokenIDs> ConvertTextToTokenIds(
const std::string &text,
Expand Down
27 changes: 23 additions & 4 deletions sherpa-onnx/csrc/offline-tts-impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "sherpa-onnx/csrc/offline-tts-impl.h"

#include <memory>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
Expand All @@ -15,21 +16,39 @@
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/offline-tts-matcha-impl.h"
#include "sherpa-onnx/csrc/offline-tts-vits-impl.h"

namespace sherpa_onnx {

std::vector<int64_t> OfflineTtsImpl::AddBlank(const std::vector<int64_t> &x,
int32_t blank_id /*= 0*/) const {
// we assume the blank ID is 0
std::vector<int64_t> buffer(x.size() * 2 + 1, blank_id);
int32_t i = 1;
for (auto k : x) {
buffer[i] = k;
i += 2;
}
return buffer;
}

std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
const OfflineTtsConfig &config) {
// TODO(fangjun): Support other types
return std::make_unique<OfflineTtsVitsImpl>(config);
if (!config.model.vits.model.empty()) {
return std::make_unique<OfflineTtsVitsImpl>(config);
}
return std::make_unique<OfflineTtsMatchaImpl>(config);
}

template <typename Manager>
std::unique_ptr<OfflineTtsImpl> OfflineTtsImpl::Create(
Manager *mgr, const OfflineTtsConfig &config) {
// TODO(fangjun): Support other types
return std::make_unique<OfflineTtsVitsImpl>(mgr, config);
if (!config.model.vits.model.empty()) {
return std::make_unique<OfflineTtsVitsImpl>(mgr, config);
}

return std::make_unique<OfflineTtsMatchaImpl>(mgr, config);
}

#if __ANDROID_API__ >= 9
Expand Down
4 changes: 4 additions & 0 deletions sherpa-onnx/csrc/offline-tts-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include <memory>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/offline-tts.h"

Expand All @@ -32,6 +33,9 @@ class OfflineTtsImpl {
// Number of supported speakers.
// If it supports only a single speaker, then it return 0 or 1.
virtual int32_t NumSpeakers() const = 0;

std::vector<int64_t> AddBlank(const std::vector<int64_t> &x,
int32_t blank_id = 0) const;
};

} // namespace sherpa_onnx
Expand Down
Loading
Loading