Skip to content

Commit

Permalink
Support English for MeloTTS models. (#1134)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Jul 15, 2024
1 parent fa07bbc commit 9548541
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/windows-x64-jni.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [windows-latest]
os: [windows-2019]

steps:
- uses: actions/checkout@v4
Expand Down
43 changes: 22 additions & 21 deletions scripts/melo-tts/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
from melo.api import TTS
from melo.text import language_id_map, language_tone_start_map
from melo.text.chinese import pinyin_to_symbol_map
from melo.text.english import eng_dict, refine_syllables
from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict
from melo.text.symbols import language_tone_start_map

for k, v in pinyin_to_symbol_map.items():
if isinstance(v, list):
break
pinyin_to_symbol_map[k] = v.split()


Expand Down Expand Up @@ -79,6 +83,16 @@ def generate_lexicon():
word_dict = pinyin_dict.pinyin_dict
phrases = phrases_dict.phrases_dict
with open("lexicon.txt", "w", encoding="utf-8") as f:
for word in eng_dict:
phones, tones = refine_syllables(eng_dict[word])
tones = [t + language_tone_start_map["EN"] for t in tones]
tones = [str(t) for t in tones]

phones = " ".join(phones)
tones = " ".join(tones)

f.write(f"{word.lower()} {phones} {tones}\n")

for key in word_dict:
if not (0x4E00 <= key <= 0x9FA5):
continue
Expand Down Expand Up @@ -125,15 +139,13 @@ class ModelWrapper(torch.nn.Module):
def __init__(self, model: "SynthesizerTrn"):
super().__init__()
self.model = model
self.lang_id = language_id_map[model.language]

def forward(
self,
x,
x_lengths,
tones,
lang_id,
bert,
ja_bert,
sid,
noise_scale,
length_scale,
Expand All @@ -147,7 +159,11 @@ def forward(
lang_id: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
sid: an integer
"""
return self.model.infer(
bert = torch.zeros(x.shape[0], 1024, x.shape[1], dtype=torch.float32)
ja_bert = torch.zeros(x.shape[0], 768, x.shape[1], dtype=torch.float32)
lang_id = torch.zeros_like(x)
lang_id[:, 1::2] = self.lang_id
return self.model.model.infer(
x=x,
x_lengths=x_lengths,
sid=sid,
Expand All @@ -169,27 +185,21 @@ def main():

generate_tokens(model.hps["symbols"])

torch_model = ModelWrapper(model.model)
torch_model = ModelWrapper(model)

opset_version = 13
x = torch.randint(low=0, high=10, size=(60,), dtype=torch.int64)
print(x.shape)
x_lengths = torch.tensor([x.size(0)], dtype=torch.int64)
sid = torch.tensor([1], dtype=torch.int64)
tones = torch.zeros_like(x)
lang_id = torch.ones_like(x)

noise_scale = torch.tensor([1.0], dtype=torch.float32)
length_scale = torch.tensor([1.0], dtype=torch.float32)
noise_scale_w = torch.tensor([1.0], dtype=torch.float32)

bert = torch.zeros(1024, x.shape[0], dtype=torch.float32)
ja_bert = torch.zeros(768, x.shape[0], dtype=torch.float32)

x = x.unsqueeze(0)
tones = tones.unsqueeze(0)
lang_id = lang_id.unsqueeze(0)
bert = bert.unsqueeze(0)
ja_bert = ja_bert.unsqueeze(0)

filename = "model.onnx"

Expand All @@ -199,9 +209,6 @@ def main():
x,
x_lengths,
tones,
lang_id,
bert,
ja_bert,
sid,
noise_scale,
length_scale,
Expand All @@ -213,9 +220,6 @@ def main():
"x",
"x_lengths",
"tones",
"lang_id",
"bert",
"ja_bert",
"sid",
"noise_scale",
"length_scale",
Expand All @@ -226,9 +230,6 @@ def main():
"x": {0: "N", 1: "L"},
"x_lengths": {0: "N"},
"tones": {0: "N", 1: "L"},
"lang_id": {0: "N", 1: "L"},
"bert": {0: "N", 2: "L"},
"ja_bert": {0: "N", 2: "L"},
"y": {0: "N", 1: "S", 2: "T"},
},
)
Expand Down
2 changes: 2 additions & 0 deletions scripts/melo-tts/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ echo "pwd: $PWD"

ls -lh

./show-info.py

head lexicon.txt
echo "---"
tail lexicon.txt
Expand Down
50 changes: 50 additions & 0 deletions scripts/melo-tts/show-info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)

import onnxruntime


def show(filename):
session_opts = onnxruntime.SessionOptions()
session_opts.log_severity_level = 3
sess = onnxruntime.InferenceSession(filename, session_opts)
for i in sess.get_inputs():
print(i)

print("-----")

for i in sess.get_outputs():
print(i)

meta = sess.get_modelmeta().custom_metadata_map
print("*****************************************")
print("meta\n", meta)


def main():
print("=========model==========")
show("./model.onnx")


if __name__ == "__main__":
main()

"""
=========model==========
NodeArg(name='x', type='tensor(int64)', shape=['N', 'L'])
NodeArg(name='x_lengths', type='tensor(int64)', shape=['N'])
NodeArg(name='tones', type='tensor(int64)', shape=['N', 'L'])
NodeArg(name='sid', type='tensor(int64)', shape=[1])
NodeArg(name='noise_scale', type='tensor(float)', shape=[1])
NodeArg(name='length_scale', type='tensor(float)', shape=[1])
NodeArg(name='noise_scale_w', type='tensor(float)', shape=[1])
-----
NodeArg(name='y', type='tensor(float)', shape=['N', 'S', 'T'])
*****************************************
meta
{'description': 'MeloTTS is a high-quality multi-lingual text-to-speech library by MyShell.ai',
'model_type': 'melo-vits', 'license': 'MIT license', 'sample_rate': '44100', 'add_blank': '1',
'n_speakers': '1', 'bert_dim': '1024', 'language': 'Chinese + English',
'ja_bert_dim': '768', 'speaker_id': '1', 'comment': 'melo', 'lang_id': '3',
'tone_start': '0', 'url': 'https://github.com/myshell-ai/MeloTTS'}
"""
41 changes: 24 additions & 17 deletions scripts/melo-tts/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(self, lexion_filename: str, tokens_filename: str):
tones = [int(t) for t in tones]

lexicon[word_or_phrase] = (phones, tones)
lexicon["呣"] = lexicon["母"]
lexicon["嗯"] = lexicon["恩"]
self.lexicon = lexicon

punctuation = ["!", "?", "…", ",", ".", "'", "-"]
Expand Down Expand Up @@ -98,20 +100,16 @@ def __init__(self, filename):
self.lang_id = int(meta["lang_id"])
self.sample_rate = int(meta["sample_rate"])

def __call__(self, x, tones, lang):
def __call__(self, x, tones):
"""
Args:
x: 1-D int64 torch tensor
tones: 1-D int64 torch tensor
lang: 1-D int64 torch tensor
"""
x = x.unsqueeze(0)
tones = tones.unsqueeze(0)
lang = lang.unsqueeze(0)

print(x.shape, tones.shape, lang.shape)
bert = torch.zeros(1, self.bert_dim, x.shape[-1])
ja_bert = torch.zeros(1, self.ja_bert_dim, x.shape[-1])
print(x.shape, tones.shape)
sid = torch.tensor([self.speaker_id], dtype=torch.int64)
noise_scale = torch.tensor([0.6], dtype=torch.float32)
length_scale = torch.tensor([1.0], dtype=torch.float32)
Expand All @@ -125,9 +123,6 @@ def __call__(self, x, tones, lang):
"x": x.numpy(),
"x_lengths": x_lengths.numpy(),
"tones": tones.numpy(),
"lang_id": lang.numpy(),
"bert": bert.numpy(),
"ja_bert": ja_bert.numpy(),
"sid": sid.numpy(),
"noise_scale": noise_scale.numpy(),
"noise_scale_w": noise_scale_w.numpy(),
Expand All @@ -140,34 +135,46 @@ def __call__(self, x, tones, lang):
def main():
lexicon = Lexicon(lexion_filename="./lexicon.txt", tokens_filename="./tokens.txt")

text = "永远相信,美好的事情即将发生。多音字测试, 银行,行不行?长沙长大"
text = "永远相信,美好的事情即将发生。"
s = jieba.cut(text, HMM=True)

phones, tones = lexicon.convert(s)

en_text = "how are you ?".split()

phones_en, tones_en = lexicon.convert(en_text)
phones += [0]
tones += [0]

phones += phones_en
tones += tones_en

text = "多音字测试, 银行,行不行?长沙长大"
s = jieba.cut(text, HMM=True)

phones2, tones2 = lexicon.convert(s)

phones += phones2
tones += tones2

model = OnnxModel("./model.onnx")
langs = [model.lang_id] * len(phones)

if model.add_blank:
new_phones = [0] * (2 * len(phones) + 1)
new_tones = [0] * (2 * len(tones) + 1)
new_langs = [0] * (2 * len(langs) + 1)

new_phones[1::2] = phones
new_tones[1::2] = tones
new_langs[1::2] = langs

phones = new_phones
tones = new_tones
langs = new_langs

phones = torch.tensor(phones, dtype=torch.int64)
tones = torch.tensor(tones, dtype=torch.int64)
langs = torch.tensor(langs, dtype=torch.int64)

print(phones.shape, tones.shape, langs.shape)
print(phones.shape, tones.shape)

y = model(x=phones, tones=tones, lang=langs)
y = model(x=phones, tones=tones)
sf.write("./test.wav", y, model.sample_rate)


Expand Down

0 comments on commit 9548541

Please sign in to comment.