forked from RetroCirce/Zero_Shot_Audio_Source_Separation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_processor.py
179 lines (165 loc) · 7.04 KB
/
data_processor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# Ke Chen
# Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
# The dataset classes
import numpy as np
import torch
import logging
import os
import sys
import h5py
import csv
import time
import random
import json
from datetime import datetime
from utils import int16_to_float32
from torch.utils.data import Dataset, Sampler
# output the dict["index"].key form to save the memory in multi-GPU training
def reverse_dict(data_path, sed_path, output_dir):
    """Re-pack packed AudioSet eval arrays into per-index HDF5 datasets.

    Reads the monolithic "waveform" array from ``data_path`` and the
    "sed_vector" array from ``sed_path``, then writes each entry out as its
    own dataset keyed by its string index ("0", "1", ...) so that workers in
    multi-GPU training can load single items without mapping the whole array.

    Args:
        data_path: path to the source h5 file holding an int16 "waveform" array.
        sed_path: path to the source h5 file holding a float "sed_vector" array.
        output_dir: directory receiving the two re-packed h5 files.

    Raises:
        AssertionError: if the waveform and sed arrays differ in length.

    Fix over the original: the two input files were opened without ever being
    closed (resource leak); they are now managed by ``with`` blocks like the
    outputs already were.
    """
    # fixed output filenames inside output_dir
    waveform_dir = os.path.join(output_dir, "audioset_eval_waveform_balanced.h5")
    sed_dir = os.path.join(output_dir, "audioset_eval_sed_balanced.h5")
    logging.info("Write Data...............")
    with h5py.File(data_path, "r") as h_data, h5py.File(sed_path, "r") as h_sed:
        audio_num = len(h_data["waveform"])
        assert len(h_data["waveform"]) == len(h_sed["sed_vector"]), "waveform and sed should be in the same length"
        # one dataset per clip, converted int16 -> float32 on the way out
        with h5py.File(waveform_dir, 'w') as hw:
            for i in range(audio_num):
                hw.create_dataset(str(i), data=int16_to_float32(h_data['waveform'][i]), dtype=np.float32)
        logging.info("Write Data Succeed...............")
        logging.info("Write Sed...............")
        with h5py.File(sed_dir, 'w') as hw:
            for i in range(audio_num):
                hw.create_dataset(str(i), data=h_sed['sed_vector'][i], dtype=np.float32)
        logging.info("Write Sed Succeed...............")
# A dataset for handling musdb
class MusdbDataset(Dataset):
    """Wrap a pre-loaded list of musdb tracks as a torch Dataset.

    Each element of ``tracks`` is one item of shape
    [mixture + n_sources, n_samples]; indexing simply hands it back.
    """

    def __init__(self, tracks):
        self.tracks = tracks
        # length is frozen at construction time
        self.dataset_len = len(self.tracks)

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, index):
        """Return track number *index*: [mixture + n_sources, n_samples]."""
        return self.tracks[index]
class InferDataset(Dataset):
    """Expose an in-memory list of tracks to a DataLoader at inference time.

    Behaviorally identical to MusdbDataset: a plain indexable view over
    ``tracks`` whose items are [mixture + n_sources, n_samples].
    """

    def __init__(self, tracks):
        self.tracks = tracks
        self.dataset_len = len(tracks)

    def __getitem__(self, index):
        """Return the clip stored at *index*.

        Args:
            index: the index number
        Return:
            track: [mixture + n_sources, n_samples]
        """
        track = self.tracks[index]
        return track

    def __len__(self):
        return self.dataset_len
# Polished LGSPDataset, the main dataset for processing the AudioSet files.
class LGSPDataset(Dataset):
    """Pair-sampling AudioSet dataset for query-based separation training.

    Every item is a PAIR of audio clips (plus their class ids) drawn from an
    HDF5 index file; pairing is rebuilt by generate_queue() each epoch.
    NOTE(review): indentation of this block was lost in extraction; the branch
    nesting below was reconstructed and should be confirmed against upstream.
    """
    def __init__(self, index_path, idc, config, factor = 3, eval_mode = False):
        # index_path: h5 index exposing "audio_name", "hdf5_path", "index_in_hdf5"
        # idc: per-class sample-index lists (class id -> list of indices) — assumed; confirm against caller
        # factor: multiplier on the raw index size to set the epoch length
        # eval_mode: draw from config.eval_list classes instead of training classes
        self.index_path = index_path
        self.fp = h5py.File(index_path, "r")  # kept open for the dataset's lifetime
        self.config = config
        self.idc = idc
        self.factor = factor
        self.classes_num = self.config.classes_num
        self.eval_mode = eval_mode
        # total number of PAIRS served per epoch
        self.total_size = int(len(self.fp["audio_name"]) * self.factor)
        self.generate_queue()
        logging.info("total dataset size: %d" %(self.total_size))
        logging.info("class num: %d" %(self.classes_num))
    def generate_queue(self):
        """(Re)build self.queue / self.class_queue as total_size pairs of
        [sample_index, sample_index] / [class_id, class_id]."""
        self.queue = []
        self.class_queue = []
        if self.config.debug:
            # shrink the epoch for quick debugging runs
            self.total_size = 1000
        if self.config.balanced_data:
            # Balanced sampling: each pass draws one random clip from every
            # eligible class, so classes appear with equal frequency.
            while len(self.queue) < self.total_size * 2:
                if self.eval_mode:
                    if len(self.config.eval_list) == 0:
                        class_set = [*range(self.classes_num)]
                    else:
                        class_set = self.config.eval_list[:]
                else:
                    class_set = [*range(self.classes_num)]
                    # training mode excludes the held-out eval classes
                    class_set = list(set(class_set) - set(self.config.eval_list))
                random.shuffle(class_set)
                # one uniformly random sample index per chosen class
                self.queue += [self.idc[d][random.randint(0, len(self.idc[d]) - 1)] for d in class_set]
                self.class_queue += class_set[:]
            # trim the overshoot from the last pass, then fold into pairs
            self.queue = self.queue[:self.total_size * 2]
            self.class_queue = self.class_queue[:self.total_size * 2]
            self.queue = [[self.queue[i],self.queue[i+1]] for i in range(0, self.total_size * 2, 2)]
            self.class_queue = [[self.class_queue[i],self.class_queue[i+1]] for i in range(0, self.total_size * 2, 2)]
            assert len(self.queue) == self.total_size, "generate data error!!"
        else:
            # Unbalanced sampling: classes drawn with replacement in one shot.
            if self.eval_mode:
                if len(self.config.eval_list) == 0:
                    class_set = [*range(self.classes_num)]
                else:
                    class_set = self.config.eval_list[:]
            else:
                class_set = [*range(self.classes_num)]
                # training mode excludes the held-out eval classes
                class_set = list(set(class_set) - set(self.config.eval_list))
            self.class_queue = random.choices(class_set, k = self.total_size * 2)
            self.queue = [self.idc[d][random.randint(0, len(self.idc[d]) - 1)] for d in self.class_queue]
            # fold the flat draws into consecutive pairs
            self.queue = [[self.queue[i],self.queue[i+1]] for i in range(0, self.total_size * 2, 2)]
            self.class_queue = [[self.class_queue[i],self.class_queue[i+1]] for i in range(0, self.total_size * 2, 2)]
            assert len(self.queue) == self.total_size, "generate data error!!"
        logging.info("queue regenerated:%s" %(self.queue[-5:]))
    def __getitem__(self, index):
        """Load waveform and target of an audio clip.
        Args:
            index: the index number
        Return: {
            "audio_name_1": str,
            "waveform_1": (clip_samples,),
            "class_id_1": int,
            "audio_name_2": str,
            "waveform_2": (clip_samples,),
            "class_id_2": int,
            ...
            "check_num": int
        }
        """
        # put the right index here!!!
        data_dict = {}
        # assemble both clips of the pair under "_1" / "_2" key suffixes
        for k in range(2):
            s_index = self.queue[index][k]
            target = self.class_queue[index][k]
            audio_name = self.fp["audio_name"][s_index].decode()
            # remap the absolute path baked into the index to the local dataset root
            hdf5_path = self.fp["hdf5_path"][s_index].decode().replace("/home/tiger/DB/knut/data/audioset", self.config.dataset_path)
            r_idx = self.fp["index_in_hdf5"][s_index]
            with h5py.File(hdf5_path, "r") as f:
                waveform = int16_to_float32(f["waveform"][r_idx])
            data_dict["audio_name_" + str(k+1)] = audio_name
            data_dict["waveform_" + str(k+1)] = waveform
            data_dict["class_id_" + str(k+1)] = target
        # fingerprint of the current queue tail, to verify epoch reshuffles across workers
        data_dict["check_num"] = str(self.queue[-5:])
        return data_dict
    def __len__(self):
        return self.total_size
# only for test
class TestDataset(Dataset):
    """Toy dataset of (base + 2i, base + 2i + 1) integer pairs.

    Used only to sanity-check DataLoader behavior; get_new_list() randomizes
    the base so reshuffling between epochs is observable.
    """

    def __init__(self, dataset_size):
        print("init")
        self.dataset_size = dataset_size
        self.base_num = 100
        self._rebuild()

    def _rebuild(self):
        # pair i is (base + 2i, base + 2i + 1)
        self.dicts = [(self.base_num + 2 * k, self.base_num + 2 * k + 1) for k in range(self.dataset_size)]

    def get_new_list(self):
        """Pick a new random base in [0, 10] and regenerate every pair."""
        self.base_num = random.randint(0,10)
        print("base num changed:", self.base_num)
        self._rebuild()

    def __getitem__(self, index):
        return self.dicts[index]

    def __len__(self):
        return self.dataset_size