-
Notifications
You must be signed in to change notification settings - Fork 1
/
prepare_dataset.py
146 lines (130 loc) · 3.48 KB
/
prepare_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import argparse
import typing as tp
from pathlib import Path
import musdb
import torch
from omegaconf import OmegaConf, DictConfig
from tqdm import tqdm
from data import SAD
# Command-line interface. Every option carries a default, so the script can be
# started without any arguments.
parser = argparse.ArgumentParser()

# where the dataset is read from
parser.add_argument(
    "-i", "--input-dir",
    type=str,
    default="D://dataset//musdb18hq",
    help="Path to directory with musdb18 dataset",
)

# where the resulting .txt files are written
parser.add_argument(
    "-o", "--output-dir",
    type=str,
    default="D://Project//band-split-rope-transformer//files",
    help="Path to directory where output .txt file is saved",
)

# which part of the dataset to process
parser.add_argument(
    "--subset",
    type=str,
    default="test",
    help="Train/test subset of dataset to process",
)
parser.add_argument(
    "--split",
    type=str,
    default="train",
    help="Train/valid split of train dataset. Used if subset=train",
)

# configuration of the Source Activity Detection step
parser.add_argument(
    "--sad-cfg-path",
    type=str,
    default="./conf/sad/default.yaml",
    help="Path to Source Activity Detection config file",
)

# one or more target sources; each one gets its own output file
parser.add_argument(
    "-t", "--targets",
    nargs="+",
    default=["vocals"],
    help="Target source. SAD will save salient fragments of vocal audio.",
)

# NOTE: parsing happens at import time; the result is consumed by the
# ``__main__`` guard at the bottom of the file.
args = parser.parse_args()
def prepare_save_line(
        track_name: str,
        start_indices: torch.Tensor,
        window_size: int
) -> tp.Iterable[str]:
    """
    Lazily produces one text line per start index.

    Each line has the form ``TRACK_NAME<TAB>START_INDEX<TAB>END_INDEX<NEWLINE>``,
    where END_INDEX is START_INDEX + window_size.
    """
    yield from (
        f"{track_name}\t{begin}\t{begin + window_size}\n"
        for begin in start_indices
    )
def run_program(
        file_path: Path,
        target: str,
        db: musdb.DB,
        sad: SAD,
) -> None:
    """
    Writes the salient fragment boundaries of every track in db to file_path.

    For each track the audio of the requested target is passed to the
    detector, and every returned start index becomes one line
    ``TRACK_NAME<TAB>START<TAB>END`` in the output file. The file is opened in
    write mode, so any previous content is replaced.
    """
    with open(file_path, 'w') as out_file:
        for track in tqdm(db):
            # transposed target audio as a float32 tensor
            stem = track.targets[target].audio.T
            waveform = torch.tensor(stem, dtype=torch.float32)
            # start indices of the fragments the detector reports as salient
            starts = sad.calculate_salient_indices(waveform)
            # one line per fragment, streamed straight into the file
            out_file.writelines(
                prepare_save_line(track.name, starts, sad.window_size)
            )
def main(
        db_dir: str,
        save_dir: str,
        subset: str,
        split: tp.Optional[str],
        targets: tp.List[str],
        sad_cfg_path: str
) -> None:
    """
    Runs Source Activity Detection over a MUSDB subset and saves the indices.

    For every entry of ``targets`` one file named ``<target>_<suffix>.txt`` is
    written into ``save_dir``; the suffix is ``train``, ``valid`` or ``test``
    depending on ``subset`` and ``split``.

    Args:
        db_dir: root directory of the dataset (passed to ``musdb.DB``).
        save_dir: directory for the output files; created, including missing
            parent directories, if it does not exist yet.
        subset: dataset subset to process.
        split: split of the train subset; ignored when ``subset == 'test'``.
        targets: names of the target sources to process.
        sad_cfg_path: path to the config file loaded with ``OmegaConf.load``
            and expanded into the ``SAD`` constructor.
    """
    # the split is only meaningful for the train subset
    split = None if subset == 'test' else split

    # initialize MUSDB parser
    db = musdb.DB(
        root=db_dir,
        subsets=subset,
        split=split,
        download=False,
        is_wav=True,
    )

    # initialize Source Activity Detector
    sad_cfg = OmegaConf.load(sad_cfg_path)
    sad = SAD(**sad_cfg)

    # make sure the output directory exists; parents=True so that a nested
    # path whose parent directories are missing does not raise
    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # the file-name suffix depends only on subset/split, not on the target
    if subset == split == 'train':
        suffix = 'train'
    elif subset == 'train' and split == 'valid':
        suffix = 'valid'
    else:
        # NOTE(review): besides subset='test', any other subset/split
        # combination also ends up with the 'test' suffix - confirm intended.
        suffix = 'test'

    for target in targets:
        # segment data and save indices to .txt file
        run_program(out_dir / f"{target}_{suffix}.txt", target, db, sad)
if __name__ == '__main__':
    # hand the parsed command-line options to main(), spelled out by keyword
    main(
        db_dir=args.input_dir,
        save_dir=args.output_dir,
        subset=args.subset,
        split=args.split,
        targets=args.targets,
        sad_cfg_path=args.sad_cfg_path,
    )