-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
220 lines (176 loc) · 8.57 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import os

import numpy as np

# NOTE: the order of the imports below is preserved on purpose: star imports
# can shadow each other, so moving them could rebind names such as `write`.
from Player import *
from setups import *
from Model import mix
from Normalizer import *
from plots import *
from scipy.io.wavfile import write
from mir_eval.separation import bss_eval_sources
from datetime import datetime
from typing import Tuple
def main():
    """Run every configured simulation end to end.

    For each simulation returned by ``setups()`` and each of its data sets:
    load and normalize the source signals, mix them with ``mix``, save the
    filtered/mixed signals as plots and WAV files, run every separation
    algorithm (online chunk by chunk, or batch), save the unmixed signals and
    compute SDR/SIR/SAR via ``evaluate``. Box plots are drawn after each
    simulation and the collected metrics are printed at the end.

    Per-data-set metrics are stored in ``alg['metrics'][data_set_name]``;
    earlier data sets are no longer overwritten by later ones.

    :raises ValueError: if a simulation has an unknown ``run_type``.
    """
    sims, _ = setups()
    # One inner list per processed data set, one value per algorithm run.
    SDRR = []
    SARR = []
    SIRR = []
    dir_sims, _ = create_folders()
    # Iterate over simulations
    for sim in sims:
        print('\033[35mSimulation \'{}\' started\033[0m'.format(sim['name']))
        # Create folders to save data
        (_, dir_sim_filtered, dir_sim_mixed, dir_sim_unmixed,
         _, dir_sim_box_plots) = sim_create_folders(sim, dir_sims)
        for data_set in sim['data_sets']:
            print('\t\033[35mDataSet \'{}\' \033[0m'.format(data_set['name']))
            # 1. Load at most sim['sources'] source signals; entries are either
            #    WAV file paths or already-loaded signals.
            S = []
            for si, wav in enumerate(data_set['data']):
                if si >= sim['sources']:
                    break
                if isinstance(wav, str):
                    S.append(load_wav(wav, data_set['fs']))
                else:
                    S.append(wav)
            # 2. Normalize & format source signals
            S = form_source_matrix(S)
            S = normalize_rowwise(np.array(S))
            data_set['audio_duration'] = round(S.shape[1] / data_set['fs'], 1)
            # 3. Perform environment simulation (mix signals)
            print('\t\t\033[35mMixing signals...\033[0m')
            filtered, mixed, sim = mix(S, sim, data_set)
            # 4. Normalize filtered & mixed arrays (filtered: one array per
            #    source, each row is one microphone's signal)
            mixed = normalize(mixed)
            for f in filtered:
                for i in range(f.shape[0]):
                    f[i] = normalize(f[i])
            sim['filtered'] = filtered
            sim['mixed'] = mixed
            # 4.1. Save filtered & mixed plots
            pr = "{}_".format(data_set['name'])
            plot_original(S, dir_sim_mixed, pr, S.shape[1])
            plot_filtered(filtered, dir_sim_filtered, pr, S.shape[1])
            plot_mixed(mixed, dir_sim_mixed, pr, S.shape[1])
            # 4.2. Save filtered & mixed to wav
            for file_name, f in zip(data_set['file_names'], filtered):
                for mi, m in enumerate(f):
                    write("{}/{}_{}_mic_{}.wav".format(dir_sim_filtered, data_set['name'], file_name, mi),
                          data_set['fs'], np.float32(m))
            for mi, m in enumerate(mixed):
                write("{}/{}_mic_{}.wav".format(dir_sim_mixed, data_set['name'], mi),
                      data_set['fs'], np.float32(m))
            # 4.3. Create list of chunks (online version only)
            if sim['run_type'] == 'online':
                mixed_queue = rework_conv(mixed, sim)
            else:
                mixed_queue = []
            SDR_temp = []
            SIR_temp = []
            SAR_temp = []
            # 5. Run algorithms
            print('\t\t\033[35mSeparating {}...\033[0m'.format(
                '(chunk_size={})'.format(sim['chunk_size']) if 'chunk_size' in sim else ''))
            for alg in sim['algs']:
                # Make room object available for algorithms (needed for beamforming)
                if 'mix_additional_outputs' in sim:
                    if 'room_object' in sim['mix_additional_outputs']:
                        alg['options']['room_object'] = sim['mix_additional_outputs']['room_object']
                # Make number of sources required available to algorithms
                alg['options']['nSources'] = sim['sources']
                if alg['name'].startswith('ILRMA') and data_set['name'] == 'Gen Signals':
                    print('Warning: artificially generated signals are not used with ILRMA, skipping...')
                    continue
                print("\t\t\tSeparation by {} ...".format(alg['name']))
                # 5.1 Run given algorithm (online or batch)
                if sim['run_type'] == 'online':
                    # NOTE(review): unlike batch mode, alg['state'] is not reset
                    # after an online run, so the next data set starts from the
                    # final state of this one — confirm this is intended.
                    unmixed = []
                    for chunk in mixed_queue:
                        unmixed_chunk, alg['state'] = alg['func'](chunk, alg['state'], alg.get('options'))
                        unmixed.append(unmixed_chunk)
                    # combine all reconstructed chunks
                    unmixed = np.concatenate(unmixed, axis=1)
                elif sim['run_type'] == 'batch':
                    unmixed, alg['state'] = alg['func'](sim['mixed'], alg['state'], alg.get('options'))
                    alg['state'] = {}
                else:
                    raise ValueError('unknown run_type={}'.format(sim['run_type']))
                unmixed = normalize_rowwise(unmixed)
                alg['unmixed'] = unmixed
                # 5.2 Save data to wav files
                dir_alg = alg_create_folders(alg, dir_sim_unmixed)
                for file_name, s in zip(data_set['file_names'], unmixed):
                    write("{}/{}_{}.wav".format(dir_alg, data_set['name'], file_name),
                          data_set['fs'], np.float32(s))
                # 5.3 Compute metrics; keep one entry per data set instead of
                #     replacing the metrics of previously processed data sets.
                metrics = evaluate(S, sim['filtered'], unmixed)
                if not isinstance(alg.get('metrics'), dict):
                    alg['metrics'] = {}
                alg['metrics'][data_set['name']] = metrics
                SDR_temp.append(metrics['SDR'])
                SIR_temp.append(metrics['SIR'])
                SAR_temp.append(metrics['SAR'])
            # delete temporary "mixed" array from dict
            del sim['mixed']
            SDRR.append(SDR_temp)
            SARR.append(SAR_temp)
            SIRR.append(SIR_temp)
        # Create box plots for this sim.
        # NOTE(review): SDRR/SARR/SIRR accumulate over all sims processed so
        # far, not just this one — confirm plot_boxes expects cumulative data.
        plot_boxes(SDRR, SARR, SIRR, sim['name'], dir_sim_box_plots)
        print('\033[35mSimulation \'{}\' finished\033[0m'.format(sim['name']))
    print('\033[35mSaving stuff...\033[0m')
    # Collect all metrics into a new dictionary and display them in the console
    rew_sims = rework_dict(sims)
    print_results(rew_sims)
    print('\033[35mAll done.\033[0m')
    print(SDRR)
def evaluate(original: np.ndarray, filtered: np.ndarray, unmixed: np.ndarray) -> dict:
    """Score separated signals against the filtered source signals.

    The reference for each source is its filtered signal at microphone index
    0. References and estimates are both truncated to the shorter of their
    sample counts before calling ``bss_eval_sources``.

    :param original: source matrix; currently unused, kept for the call
        signature.
    :param filtered: per-source filtered signals, indexed as
        (source, mic, sample) — presumably one (mics, samples) array per
        source, as produced in ``main``.
    :param unmixed: separated signals, one row per estimated source.
    :return: dict with the mean 'SDR', 'SIR' and 'SAR' (each rounded to 2
        decimals) and the permutation 'P' returned by ``bss_eval_sources``.
    """
    # Move the mic axis last: (source, sample, mic).
    references = np.moveaxis(filtered, 1, 2)
    length = np.minimum(unmixed.shape[1], references.shape[1])
    mic0_refs = references[:, :length, 0]
    sdr, sir, sar, perm = bss_eval_sources(mic0_refs, unmixed[:, :length])
    scores = {}
    for key, values in (('SDR', sdr), ('SIR', sir), ('SAR', sar)):
        scores[key] = np.round(np.mean(values), 2)
    scores['P'] = perm
    return scores
def create_folders(dir_sims: str = "Sims") -> Tuple[str, str]:
    """Create the top-level output folders and return their paths.

    ``os.makedirs(..., exist_ok=True)`` makes the call idempotent and avoids
    the check-then-create race of a separate ``isdir``/``mkdir`` pair.

    :param dir_sims: root folder for all simulation output (default "Sims").
    :return: ``(dir_sims, dir_plots)`` where ``dir_plots`` is
        ``"<dir_sims>/plots"``.
    """
    dir_plots = "{}/plots".format(dir_sims)
    # Creating the nested plots folder also creates dir_sims if it is missing.
    os.makedirs(dir_plots, exist_ok=True)
    return dir_sims, dir_plots
def sim_create_folders(sim: dict, dir_sims: str, use_timestamp: bool = True) -> Tuple[str, str, str, str, str, str]:
    """Create the output folder tree for one simulation.

    The simulation folder is ``"<dir_sims>/<name>_<YYYY_mm_dd_HH_MM_SS>"``.
    With ``use_timestamp=False`` it is ``"<dir_sims>/<name>"`` instead, which
    is handy during development. The subfolders ``unmixed``, ``filtered``,
    ``mixed``, ``plots`` and ``box_plots`` are created inside it. Existing
    folders are reused.

    :param sim: simulation description; only ``sim['name']`` is read.
    :param dir_sims: parent folder for all simulations.
    :param use_timestamp: append the current local date/time to the folder
        name (default True, matching the original behavior).
    :return: ``(dir_sim, dir_sim_filtered, dir_sim_mixed, dir_sim_unmixed,
        dir_sim_plots, dir_sim_box_plots)``.
    """
    if use_timestamp:
        dir_sim = "{}/{}_{}".format(dir_sims, sim['name'], datetime.now().strftime("%Y_%m_%d_%H_%M_%S"))
    else:
        dir_sim = "{}/{}".format(dir_sims, sim['name'])
    subdirs = {}
    for sub in ("unmixed", "filtered", "mixed", "plots", "box_plots"):
        path = "{}/{}".format(dir_sim, sub)
        # Also creates dir_sim itself on the first iteration.
        os.makedirs(path, exist_ok=True)
        subdirs[sub] = path
    return (dir_sim, subdirs["filtered"], subdirs["mixed"], subdirs["unmixed"],
            subdirs["plots"], subdirs["box_plots"])
def alg_create_folders(alg: dict, dir_sim: str) -> str:
    """Create (if needed) and return the output folder for one algorithm.

    :param alg: algorithm description; only ``alg['name']`` is read.
    :param dir_sim: parent folder (the simulation's ``unmixed`` folder in
        ``main``).
    :return: ``"<dir_sim>/<alg name>"``.
    """
    dir_alg = "{}/{}".format(dir_sim, alg['name'])
    # exist_ok makes this idempotent without a racy isdir()/mkdir() pair.
    os.makedirs(dir_alg, exist_ok=True)
    return dir_alg
if __name__ == "__main__":
    # Run the whole simulation pipeline only when executed as a script,
    # not when this module is imported.
    main()