-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathprepare_dataset.py
121 lines (88 loc) · 2.82 KB
/
prepare_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import argparse
import os
import pathlib
import re
from hanziconv import HanziConv
from sklearn.model_selection import train_test_split
def _val_ratio(value):
    """Parse the ``--val`` option as a float strictly between 0 and 1.

    Raising ``argparse.ArgumentTypeError`` makes argparse print a usage
    error for out-of-range values (0, 1, negatives, NaN, inf) instead of
    letting them reach ``train_test_split`` later in the run.
    """
    ratio = float(value)
    # `not 0 < ratio < 1` is also True for NaN, so NaN is rejected too.
    if not 0.0 < ratio < 1.0:
        raise argparse.ArgumentTypeError(
            "validation ratio must be between 0 and 1 (exclusive), got {}".format(
                value
            )
        )
    return ratio


def parse_args(argv=None):
    """Parse command-line options for lyrics dataset preparation.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before; an
            explicit list allows programmatic use and testing.

    Returns:
        argparse.Namespace with ``data`` (raw lyrics directory, str),
        ``val`` (validation ratio, float in (0, 1)) and ``output``
        (output directory, str).
    """
    parser = argparse.ArgumentParser(
        description="Lyrics dataset preparation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "data",
        type=str,
        help="Path to directory of raw lyrics data",
    )
    parser.add_argument(
        "-v",
        "--val",
        type=_val_ratio,
        default=0.2,
        help="Ratio of validation dataset",
    )
    # NOTE(review): "lyrcis_dataset" looks like a typo of "lyrics_dataset";
    # it is kept unchanged so existing runs write to the same default folder.
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="lyrcis_dataset",
        help="Path to output data",
    )
    return parser.parse_args(argv)
def mkdir_p(folder_path):
    """Create ``folder_path`` and any missing parents, like ``mkdir -p``.

    An already-existing directory is not an error.
    """
    target = pathlib.Path(folder_path)
    target.mkdir(exist_ok=True, parents=True)
def read_text(file_path, encoding="utf-8"):
    """Return the entire contents of a text file as a string.

    Args:
        file_path: Path of the file to read.
        encoding: Text encoding used to decode the file. Defaults to
            UTF-8 so that Chinese lyrics decode the same way on every
            platform, rather than depending on the locale's default
            encoding (e.g. cp1252/cp936 on Windows).

    Text mode is kept, so universal-newline handling still turns
    ``\\r\\n`` into ``\\n``.
    """
    with open(file_path, "r", encoding=encoding) as f:
        return f.read()
def write_text(text, file_path, encoding="utf-8"):
    """Write ``text`` to ``file_path``, replacing any existing content.

    Args:
        text: String to write.
        file_path: Destination path; the parent directory must exist.
        encoding: Text encoding for the output. Defaults to UTF-8 so the
            generated dataset has the same bytes on every platform and
            does not fail on Chinese characters under a non-UTF-8 locale.
    """
    with open(file_path, "w", encoding=encoding) as f:
        f.write(text)
def transform_lyric(raw_lyric):
    """Normalise one raw lyric string.

    The text is converted with ``HanziConv.toTraditional``. Then every
    space becomes an ASCII comma and every newline becomes a full-width
    full stop, so each line break reads as a sentence end.
    """
    # NOTE(review): the comma is ASCII "," while the full stop is the
    # full-width "。" -- confirm the mix is intentional.
    traditional = HanziConv.toTraditional(raw_lyric)
    with_commas = traditional.replace(" ", ",")
    return with_commas.replace("\n", "。")
def read_lyrics(data_root):
    """Load and normalise every lyric file under ``data_root``.

    Expected layout is ``data_root/<singer>/<song>.txt``. Only direct
    sub-directories of ``data_root`` are scanned, and only regular files
    ending in ``.txt`` inside them are read. Files placed directly in
    ``data_root`` and deeper nesting are ignored.

    Args:
        data_root: Directory containing one sub-directory per singer.

    Returns:
        List of lyrics, each read with ``read_text`` and passed through
        ``transform_lyric``. They are ordered by singer directory name,
        then by song file name.
    """
    lyrics = []
    # os.listdir returns entries in arbitrary, filesystem-dependent order.
    # Sorting makes the returned order -- and therefore the seeded
    # train/val split done on this list -- reproducible across machines.
    singer_names = sorted(
        name
        for name in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, name))
    )
    for singer_name in singer_names:
        singer_dir = os.path.join(data_root, singer_name)
        song_names = sorted(
            name
            for name in os.listdir(singer_dir)
            if name.endswith(".txt")
            and os.path.isfile(os.path.join(singer_dir, name))
        )
        for song_name in song_names:
            raw_lyric = read_text(os.path.join(singer_dir, song_name))
            lyrics.append(transform_lyric(raw_lyric))
    return lyrics
def build_data(lyrics, output_path, bos_token="<BOS>", eos_token="<EOS>"):
    """Write ``lyrics`` to ``output_path`` as one token-wrapped line each.

    Each lyric is converted with ``str()`` and stripped. Every remaining
    whitespace character (including newlines) is then replaced
    individually by a single space, so each lyric stays on one line; runs
    are not collapsed. Finally it is wrapped as
    ``"<bos_token> <lyric> <eos_token>"``. The resulting lines are joined
    with ``"\\n"`` (no trailing newline) and written with ``write_text``.
    """
    def _wrap(item):
        flattened = re.sub(r"\s", " ", str(item).strip())
        return "{} {} {}".format(bos_token, flattened, eos_token)

    wrapped_lines = [_wrap(item) for item in lyrics]
    write_text("\n".join(wrapped_lines), output_path)
def generate_dataset(lyrics, val_ratio, output_dir, random_state=1):
    """Split ``lyrics`` into train/val sets and write both to ``output_dir``.

    Args:
        lyrics: Sequence of lyric strings.
        val_ratio: Fraction of lyrics for validation. It is passed to
            ``train_test_split`` as ``train_size=1.0 - val_ratio``.
        output_dir: Existing directory that receives ``train.txt`` and
            ``val.txt``.
        random_state: Seed forwarded to ``train_test_split``.
    """
    splits = train_test_split(
        lyrics,
        train_size=(1.0 - val_ratio),
        random_state=random_state,
    )
    # train_test_split returns (train, val); write them in that order.
    for subset, file_name in zip(splits, ("train.txt", "val.txt")):
        build_data(subset, os.path.join(output_dir, file_name))
def main(args):
    """Run the full pipeline for parsed CLI ``args``.

    Creates ``args.output``, reads lyrics from ``args.data``, writes the
    train/val split using ratio ``args.val``, then prints the output
    location.
    """
    output_dir = args.output
    mkdir_p(output_dir)
    generate_dataset(read_lyrics(args.data), args.val, output_dir)
    print("Generated dataset saved in {}".format(output_dir))


if __name__ == "__main__":
    main(parse_args())