-
Notifications
You must be signed in to change notification settings - Fork 645
/
prepare_for_training.py
112 lines (80 loc) · 3.66 KB
/
prepare_for_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright 2023, YOUDAO
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import os
import shutil
import argparse
def main(args):
    """Lay out an experiment directory for fine-tuning.

    Creates ``args.exp_dir`` (if needed) and fills three sub-directories in
    order: ``info`` (vocabulary and speaker lists), ``config`` (rendered
    config file) and ``ckpt`` (pretrained checkpoints). The order matters:
    the checkpoint step reads the speaker list written by the info step.

    Args:
        args: Parsed CLI namespace with ``data_dir`` and ``exp_dir`` string
            attributes.
    """
    experiment_root = args.exp_dir
    dataset_root = args.data_dir
    os.makedirs(experiment_root, exist_ok=True)

    info_dir, config_dir, ckpt_dir = (
        os.path.join(experiment_root, sub_dir)
        for sub_dir in ("info", "config", "ckpt")
    )

    prepare_info(dataset_root, info_dir)
    prepare_config(dataset_root, info_dir, experiment_root, config_dir)
    prepare_ckpt(dataset_root, info_dir, ckpt_dir)
# Directory containing this script; assumed to be the repository root, since the
# vocabulary files (data/youdao/text), the config template (config/template.py)
# and the pretrained checkpoints (outputs/...) are all resolved relative to it.
# Uses the module attribute `__file__` (not the string "__file__", which would
# silently resolve to the current working directory instead).
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
def prepare_info(data_dir, info_dir):
    """Populate ``info_dir`` with the vocabulary files and the speaker list.

    Copies the fixed vocabulary files (emotion, energy, pitch, speed,
    tokenlist) from ``<ROOT_DIR>/data/youdao/text`` and writes
    ``<info_dir>/speaker``, whose line order defines the speaker ids:

    1. every line of the pretrained list ``speaker2``, in its original order;
    2. the speakers found in ``<data_dir>/train/datalist.jsonl``, sorted, minus
       any name already present in the pretrained list (a warning is printed
       for each such duplicate).

    ``prepare_ckpt`` keeps the pretrained checkpoint's speaker-embedding rows
    in place and only appends new rows, so pretrained speakers must keep
    their original line positions here. (Assumes ``speaker2`` lists the
    pretrained speakers in embedding-row order — TODO confirm.)

    Args:
        data_dir: Dataset root containing ``train/datalist.jsonl``; every
            JSON line must have a ``"speaker"`` key (presumably a string).
        info_dir: Output directory; created if missing.
    """
    import jsonlines

    print('prepare_info: %s' % info_dir)
    os.makedirs(info_dir, exist_ok=True)
    for name in ["emotion", "energy", "pitch", "speed", "tokenlist"]:
        shutil.copy(f"{ROOT_DIR}/data/youdao/text/{name}", f"{info_dir}/{name}")

    # Distinct speakers used by the training data (counts are not needed).
    data_speakers = set()
    with jsonlines.open(f"{data_dir}/train/datalist.jsonl") as reader:
        for obj in reader:
            data_speakers.add(obj["speaker"])

    with open(f"{ROOT_DIR}/data/youdao/text/speaker2", encoding="utf-8") as f, \
            open(f"{info_dir}/speaker", "w", encoding="utf-8") as fout:
        pretrained_speakers = set()
        for line in f:
            speaker = line.strip()
            # Never drop a pretrained speaker: removing one would shift every
            # later speaker's id and misalign it with the checkpoint embeddings.
            fout.write(speaker + "\n")
            pretrained_speakers.add(speaker)
        for speaker in sorted(data_speakers):
            if speaker in pretrained_speakers:
                # Reuse the existing id (and pretrained embedding) instead of
                # writing the name a second time.
                print('warning: duplicate of speaker [%s] in [%s]' % (speaker, data_dir))
                continue
            fout.write(speaker + "\n")
def prepare_config(data_dir, info_dir, exp_dir, config_dir):
    """Render ``<ROOT_DIR>/config/template.py`` into ``<config_dir>/config.py``.

    Each template line has its ``<DATA_DIR>``, ``<INFO_DIR>`` and
    ``<EXP_DIR>`` placeholders replaced by ``data_dir``, ``info_dir`` and
    ``exp_dir`` respectively, in that order. ``config_dir`` is created if
    missing.
    """
    print('prepare_config: %s' % config_dir)
    os.makedirs(config_dir, exist_ok=True)

    # Substitutions are applied in this exact order to every line.
    substitutions = (
        ('<DATA_DIR>', data_dir),
        ('<INFO_DIR>', info_dir),
        ('<EXP_DIR>', exp_dir),
    )
    template_path = f"{ROOT_DIR}/config/template.py"
    output_path = f"{config_dir}/config.py"
    with open(template_path) as template, open(output_path, "w") as rendered:
        for text in template:
            for placeholder, value in substitutions:
                text = text.replace(placeholder, value)
            rendered.write(text)
def prepare_ckpt(data_dir, info_dir, ckpt_dir):
    """Create the generator/discriminator checkpoints used to start fine-tuning.

    Loads the open-source generator checkpoint (step 00140000) on CPU. It then
    extends its speaker-embedding table ``am.spk_tokenizer.weight`` with one
    randomly initialised row (standard normal, same dtype as the table) for
    each line of ``<info_dir>/speaker`` beyond the rows already present. The
    result is saved as ``<ckpt_dir>/pretrained_generator``. The
    discriminator checkpoint is copied unchanged to
    ``<ckpt_dir>/pretrained_discriminator``.

    Args:
        data_dir: Unused; kept for interface compatibility.
        info_dir: Directory holding the ``speaker`` list.
        ckpt_dir: Output directory; created if missing.

    Raises:
        ValueError: If the speaker list has fewer than 2014 entries, or fewer
            entries than the checkpoint's speaker-embedding table has rows.
    """
    # Minimum speaker-list length required by the original script; presumably
    # the number of speakers in the released checkpoint — TODO confirm.
    min_speakers = 2014

    print('prepare_ckpt: %s' % ckpt_dir)
    os.makedirs(ckpt_dir, exist_ok=True)
    with open(f"{info_dir}/speaker", encoding="utf-8") as f:
        speaker_list = [line.strip() for line in f]
    # Explicit check (not `assert`) so validation survives `python -O`.
    if len(speaker_list) < min_speakers:
        raise ValueError(
            f"speaker list {info_dir}/speaker has {len(speaker_list)} entries, "
            f"expected at least {min_speakers}"
        )

    gen_ckpt_path = f"{ROOT_DIR}/outputs/prompt_tts_open_source_joint/ckpt/g_00140000"
    disc_ckpt_path = f"{ROOT_DIR}/outputs/prompt_tts_open_source_joint/ckpt/do_00140000"

    # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
    gen_ckpt = torch.load(gen_ckpt_path, map_location="cpu")
    speaker_embeddings = gen_ckpt["generator"]["am.spk_tokenizer.weight"].clone()
    num_pretrained = speaker_embeddings.size(0)
    num_new = len(speaker_list) - num_pretrained
    if num_new < 0:
        raise ValueError(
            f"speaker list has {len(speaker_list)} entries but the pretrained "
            f"checkpoint already has {num_pretrained} speaker embeddings"
        )
    # Match the checkpoint's dtype so the concatenation never mixes precisions.
    new_embedding = torch.randn(
        (num_new, speaker_embeddings.size(1)), dtype=speaker_embeddings.dtype
    )
    gen_ckpt["generator"]["am.spk_tokenizer.weight"] = torch.cat(
        [speaker_embeddings, new_embedding], dim=0
    )
    torch.save(gen_ckpt, f"{ckpt_dir}/pretrained_generator")
    shutil.copy(disc_ckpt_path, f"{ckpt_dir}/pretrained_discriminator")
if __name__ == "__main__":
    # CLI entry point: both directories are mandatory string options.
    parser = argparse.ArgumentParser()
    for flag in ('--data_dir', '--exp_dir'):
        parser.add_argument(flag, type=str, required=True)
    main(parser.parse_args())