forked from ws-choi/AMSS-Net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
task2_eval.py
213 lines (160 loc) · 8.09 KB
/
task2_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
from argparse import ArgumentParser
from pathlib import Path
import librosa
import museval
import numpy as np
import torch
import wandb
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.metrics.pairwise import paired_distances
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.amss import model_definition
from src.data.musdb_amss_dataset.amss_task2_datasets import task2_config
from src.data.musdb_amss_dataset.musdb_amss_definitions import musdb_amss_config
from src.data.musdb_wrapper import MusdbUnmixedTestSet, SingleTrackSet_for_Task2
from src.utils.eval_metric_marco import getMSE_MFCC_mc
from src.utils.functions import load_hparams_from_yaml
def get_unmixed_testset(args, musdb_root=None):
    """Build a ``MusdbUnmixedTestSet`` from the relevant entries of ``args``.

    Only the keys the dataset constructor understands are forwarded.
    An explicit ``musdb_root`` argument overrides the one found in ``args``.
    """
    wanted = ('musdb_root', 'n_fft', 'hop_length', 'num_frame')
    dataset_kwargs = {k: v for k, v in args.items() if k in wanted}
    if musdb_root is not None:
        dataset_kwargs['musdb_root'] = musdb_root
    return MusdbUnmixedTestSet(**dataset_kwargs)
def getMFCC(x, sr=44100, mels=128, mfcc=13, mean_norm=False):
    """Compute MFCC features for a mono waveform.

    Args:
        x: 1-D audio signal.
        sr: sample rate of ``x``.
        mels: number of mel bands used internally.
        mfcc: number of MFCC coefficients to return.
        mean_norm: if True, subtract the per-coefficient mean over time
            (cepstral mean normalisation). This flag was previously
            accepted but silently ignored.

    Returns:
        np.ndarray of shape (mfcc, num_frames).
    """
    # librosa >= 0.10 removed the positional waveform argument; pass ``y=``.
    feats = librosa.feature.mfcc(y=x, sr=sr, n_mfcc=mfcc, n_mels=mels,
                                 n_fft=4096, hop_length=2048)
    if mean_norm:
        feats = feats - np.mean(feats, axis=1, keepdims=True)
    return feats
def eval_amss_track(model, track, amss, batch_size, cuda, hop_length, num_frame, word_to_idx):
    """Apply ``amss`` to one unmixed track and estimate the result with ``model``.

    Returns a dict with the ground-truth manipulated audio (``'amss'``)
    and the model's estimate trimmed back to the track length (``'amss_hat'``).
    """
    track_length = track.shape[1]
    # 1: ground-truth edit — ``source`` is the model input, ``target`` the reference.
    _, source, target = amss.edit(track)

    # 2: estimate the manipulated track chunk-by-chunk.
    with torch.no_grad():
        dataset = SingleTrackSet_for_Task2(source, hop_length, num_frame, amss, word_to_idx)
        loader = DataLoader(dataset, batch_size, shuffle=False)
        trim = dataset.trim_length

        chunks = []
        for batch_audio, batch_desc, _, _ in loader:
            if cuda:
                batch_audio = batch_audio.cuda()
                batch_desc = batch_desc.cuda()
            # NOTE(review): len(d) on a padded batch yields the padded length
            # for every row — confirm this matches training-time behaviour.
            lengths = [len(d) for d in batch_desc]
            estimate = model.manipulate(batch_audio, batch_desc, token_lengths=lengths)
            if cuda:
                estimate = estimate.cpu()
            chunks.append(estimate.detach().numpy())

        # Drop the padded borders of each chunk, then flatten to (samples, 2).
        stitched = np.vstack(chunks)[:, trim:-trim]
        stitched = stitched.reshape(-1, 2)[:track_length]

    return {'amss': target, 'amss_hat': stitched}
def eval(ckpt_root, run_id, config_path, ckpt_path, musdb_root=None, batch_size=4, cuda=True, logger='wandb'):
    """Evaluate a trained AMSS model on the task-2 evaluation set.

    Args:
        ckpt_root: directory containing one sub-directory per run.
        run_id: run sub-directory holding the config and checkpoint.
        config_path: config file name, relative to the run directory.
        ckpt_path: checkpoint file name, relative to the run directory.
        musdb_root: optional override for the MUSDB root from the config.
        batch_size: evaluation batch size.
        cuda: move the model and batches to GPU when True.
        logger: ``'wandb'`` to log to Weights & Biases; anything else prints.
    """
    lr_mode = False  # per-channel (left/right) per-track metrics are not logged
    ckpt_root = Path(ckpt_root)
    run_id = ckpt_root.joinpath(run_id)
    config_path = run_id.joinpath(config_path)
    ckpt_path = run_id.joinpath(ckpt_path)

    # Rebuild the model from the stored hyper-parameters and checkpoint.
    config = load_hparams_from_yaml(config_path)
    args = {key: config[key]['value'] for key in config.keys() if isinstance(config[key], dict)}
    model = model_definition.get_class_by_name(args['model'])
    model = model(**args)
    model = model.load_from_checkpoint(ckpt_path)

    # Word dictionary and STFT-related config.
    word_to_idx = musdb_amss_config.word_to_idx
    hop_length = args['hop_length']
    num_frame = args['num_frame']

    test_unmixed = get_unmixed_testset(args, musdb_root)
    model = model.eval()
    model = model.cuda() if cuda else model

    # BUGFIX: ``musdb_root`` may be None (config value used instead), in which
    # case ``'dev' in musdb_root`` raised TypeError. Resolve the effective root.
    effective_root = musdb_root if musdb_root is not None else str(args.get('musdb_root', ''))
    project = 'task2_eval_dev' if 'dev' in effective_root else 'task2_eval'
    if logger == 'wandb':
        wandb_logger = wandb.init(job_type='eval', config=args, project=project, tags=[args['model']],
                                  name='{}_{}'.format(args['model'], ckpt_path))
    else:
        wandb_logger = None

    for amss in task2_config.evaluation_amss_set:
        desc = amss.gen_desc_default()
        # Skip AMSS operations this evaluation does not cover.
        if any(keyword in desc for keyword in ('separate', 'mute')):
            continue
        print(amss)
        a_prime_results = []

        # Evaluate the current AMSS operation on every test track.
        for track_idx in tqdm(range(test_unmixed.num_tracks)):
            track = test_unmixed[track_idx]
            result_dict = eval_amss_track(model, track, amss, batch_size, cuda, hop_length, num_frame, word_to_idx)
            a_prime_result = multi_channel_dist(result_dict['amss'], result_dict['amss_hat'])
            if logger == 'wandb':
                for key in a_prime_result.keys():
                    if ('left' in key or 'right' in key) and not lr_mode:
                        continue
                    wandb_logger.log({'a_prime/{}_{}'.format(desc, key): a_prime_result[key]})
                # Log a 2-second audio sample from the middle of the estimate
                # (dev runs only, to keep the production project lightweight).
                start = result_dict['amss_hat'].shape[0] // 2
                if 'dev' in project:
                    wandb_logger.log({'result_sample_{}_{}'.format(track_idx, amss): [
                        wandb.Audio(result_dict['amss_hat'][start:start + 44100 * 2],
                                    caption='{}_{}'.format(track_idx, amss), sample_rate=44100)]})
            else:
                print(a_prime_result)
            a_prime_results.append(
                np.array([a_prime_result['mae'],
                          a_prime_result['mae_left'],
                          a_prime_result['mae_right'],
                          a_prime_result['mfcc_rmse'],
                          a_prime_result['mfcc_rmse_left'],
                          a_prime_result['mfcc_rmse_right']])
            )

        # Aggregate mean scores over all tracks for this AMSS operation.
        scores = np.mean(np.stack(a_prime_results), axis=0)
        aggregated = {'agg_mid/a_prime_mae_{}'.format(desc): scores[0],
                      'agg_left/a_prime_mae_{}'.format(desc): scores[1],
                      'agg_right/a_prime_mae_{}'.format(desc): scores[2],
                      'agg_mid/a_prime_mfccrmse_{}'.format(desc): scores[3],
                      'agg_left/a_prime_mfccrmse_{}'.format(desc): scores[4],
                      'agg_right/a_prime_mfccrmse_{}'.format(desc): scores[5]}
        for key, value in aggregated.items():
            if logger == 'wandb':
                wandb_logger.log({key: value})
            else:
                print({key: value})

    if logger == 'wandb':
        wandb_logger.finish()
def multi_channel_dist(x, x_hat):
    """Distance metrics between a stereo target ``x`` and estimate ``x_hat``.

    Args:
        x: ground-truth audio, shape (num_samples, 2).
        x_hat: estimated audio, shape (num_samples, 2).

    Returns:
        dict with MAE on the raw waveforms and RMSE on MFCC features,
        overall and per channel (``*_left`` / ``*_right``).
    """
    left, right = x[:, 0], x[:, 1]
    left_hat, right_hat = x_hat[:, 0], x_hat[:, 1]
    metric_dict = {'mae': mean_absolute_error(x, x_hat),
                   'mae_left': mean_absolute_error(left, left_hat),
                   'mae_right': mean_absolute_error(right, right_hat)}
    # Compare in MFCC space as well.
    left, left_hat, right, right_hat = [getMFCC(wave) for wave in [left, left_hat, right, right_hat]]
    # BUGFIX: ``mean_squared_error(..., squared=False)`` was deprecated in
    # scikit-learn 1.4 and removed in 1.6; take the square root explicitly
    # (identical under the default ``multioutput='uniform_average'``).
    metric_dict['mfcc_rmse'] = np.sqrt(mean_squared_error(
        np.concatenate([left, right]), np.concatenate([left_hat, right_hat])))
    metric_dict['mfcc_rmse_left'] = np.sqrt(mean_squared_error(left, left_hat))
    metric_dict['mfcc_rmse_right'] = np.sqrt(mean_squared_error(right, right_hat))
    return metric_dict
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean. argparse's ``type=bool`` treats any
        non-empty string — including ``'False'`` — as True."""
        return str(value).strip().lower() in ('true', '1', 'yes', 'y')

    parser = ArgumentParser()
    parser.add_argument('--ckpt_root', type=str, default='etc/checkpoints/')
    parser.add_argument('--run_id', type=str)
    parser.add_argument('--config_path', type=str, default='config.yaml')
    parser.add_argument('--ckpt_path', type=str)
    parser.add_argument('--musdb_root', type=str, default='../repos/musdb18_wav')
    parser.add_argument('--batch_size', type=int, default=4)
    # BUGFIX: ``type=bool`` made ``--cuda False`` evaluate to True.
    parser.add_argument('--cuda', type=_str2bool, default=False)
    parser.add_argument('--logger', type=str, default=None)
    namespace = parser.parse_args()
    eval(**vars(namespace))