1
1
import torch
2
2
import numpy as np
3
- from . utils import apply_reduction
3
+ from typing import List , Any
4
4
5
+ from .utils import apply_reduction
5
6
from .perceptual import SumAndDifference , FIRFilter
6
7
7
8
@@ -69,6 +70,7 @@ class STFTLoss(torch.nn.Module):
69
70
['mel', 'chroma']
70
71
Default: None
71
72
n_bins (int, optional): Number of scaling frequency bins. Default: None.
73
+ perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
72
74
scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
73
75
eps (float, optional): Small epsilon value for stablity. Default: 1e-8
74
76
output (str, optional): Format of the loss returned.
@@ -92,25 +94,26 @@ class STFTLoss(torch.nn.Module):
92
94
93
95
def __init__ (
94
96
self ,
95
- fft_size = 1024 ,
96
- hop_size = 256 ,
97
- win_length = 1024 ,
98
- window = "hann_window" ,
99
- w_sc = 1.0 ,
100
- w_log_mag = 1.0 ,
101
- w_lin_mag = 0.0 ,
102
- w_phs = 0.0 ,
103
- sample_rate = None ,
104
- scale = None ,
105
- n_bins = None ,
106
- scale_invariance = False ,
107
- eps = 1e-8 ,
108
- output = "loss" ,
109
- reduction = "mean" ,
110
- mag_distance = "L1" ,
111
- device = None ,
97
+ fft_size : int = 1024 ,
98
+ hop_size : int = 256 ,
99
+ win_length : int = 1024 ,
100
+ window : str = "hann_window" ,
101
+ w_sc : float = 1.0 ,
102
+ w_log_mag : float = 1.0 ,
103
+ w_lin_mag : float = 0.0 ,
104
+ w_phs : float = 0.0 ,
105
+ sample_rate : float = None ,
106
+ scale : str = None ,
107
+ n_bins : int = None ,
108
+ perceptual_weighting : bool = False ,
109
+ scale_invariance : bool = False ,
110
+ eps : float = 1e-8 ,
111
+ output : str = "loss" ,
112
+ reduction : str = "mean" ,
113
+ mag_distance : str = "L1" ,
114
+ device : Any = None ,
112
115
):
113
- super (STFTLoss , self ).__init__ ()
116
+ super ().__init__ ()
114
117
self .fft_size = fft_size
115
118
self .hop_size = hop_size
116
119
self .win_length = win_length
@@ -122,23 +125,28 @@ def __init__(
122
125
self .sample_rate = sample_rate
123
126
self .scale = scale
124
127
self .n_bins = n_bins
128
+ self .perceptual_weighting = perceptual_weighting
125
129
self .scale_invariance = scale_invariance
126
130
self .eps = eps
127
131
self .output = output
128
132
self .reduction = reduction
133
+ self .mag_distance = mag_distance
129
134
self .device = device
130
135
131
136
self .spectralconv = SpectralConvergenceLoss ()
132
137
self .logstft = STFTMagnitudeLoss (
133
- log = True , reduction = reduction , distance = mag_distance
138
+ log = True ,
139
+ reduction = reduction ,
140
+ distance = mag_distance ,
134
141
)
135
142
self .linstft = STFTMagnitudeLoss (
136
- log = False , reduction = reduction , distance = mag_distance
143
+ log = False ,
144
+ reduction = reduction ,
145
+ distance = mag_distance ,
137
146
)
138
147
139
148
# setup mel filterbank
140
149
if scale is not None :
141
-
142
150
try :
143
151
import librosa .filters
144
152
except Exception as e :
@@ -149,19 +157,32 @@ def __init__(
149
157
assert sample_rate != None # Must set sample rate to use mel scale
150
158
assert n_bins <= fft_size # Must be more FFT bins than Mel bins
151
159
fb = librosa .filters .mel (sr = sample_rate , n_fft = fft_size , n_mels = n_bins )
152
- self . fb = torch .tensor (fb ).unsqueeze (0 )
160
+ fb = torch .tensor (fb ).unsqueeze (0 )
153
161
154
162
elif self .scale == "chroma" :
155
163
assert sample_rate != None # Must set sample rate to use chroma scale
156
164
assert n_bins <= fft_size # Must be more FFT bins than chroma bins
157
- fb = librosa .filters .chroma (sr = sample_rate , n_fft = fft_size , n_chroma = n_bins )
158
- self .fb = torch .tensor (fb ).unsqueeze (0 )
165
+ fb = librosa .filters .chroma (
166
+ sr = sample_rate , n_fft = fft_size , n_chroma = n_bins
167
+ )
168
+
159
169
else :
160
- raise ValueError (f"Invalid scale: { self .scale } . Must be 'mel' or 'chroma'." )
170
+ raise ValueError (
171
+ f"Invalid scale: { self .scale } . Must be 'mel' or 'chroma'."
172
+ )
173
+
174
+ self .register_buffer ("fb" , fb )
161
175
162
176
if scale is not None and device is not None :
163
177
self .fb = self .fb .to (self .device ) # move filterbank to device
164
178
179
+ if self .perceptual_weighting :
180
+ if sample_rate is None :
181
+ raise ValueError (
182
+ f"`sample_rate` must be supplied when `perceptual_weighting = True`."
183
+ )
184
+ self .prefilter = FIRFilter (filter_type = "aw" , fs = sample_rate )
185
+
165
186
def stft (self , x ):
166
187
"""Perform STFT.
167
188
Args:
@@ -180,25 +201,41 @@ def stft(self, x):
180
201
return_complex = True ,
181
202
)
182
203
x_mag = torch .sqrt (
183
- torch .clamp ((x_stft .real ** 2 ) + (x_stft .imag ** 2 ), min = self .eps )
204
+ torch .clamp ((x_stft .real ** 2 ) + (x_stft .imag ** 2 ), min = self .eps )
184
205
)
185
206
x_phs = torch .angle (x_stft )
186
207
return x_mag , x_phs
187
208
188
- def forward (self , x , y ):
209
+ def forward (self , input : torch .Tensor , target : torch .Tensor ):
210
+ bs , chs , seq_len = input .size ()
211
+
212
+ if self .perceptual_weighting : # apply optional A-weighting via FIR filter
213
+ # since FIRFilter only support mono audio we will move channels to batch dim
214
+ input = input .view (bs * chs , 1 , - 1 )
215
+ target = target .view (bs * chs , 1 , - 1 )
216
+
217
+ # now apply the filter to both
218
+ self .prefilter .to (input .device )
219
+ input , target = self .prefilter (input , target )
220
+
221
+ # now move the channels back
222
+ input = input .view (bs , chs , - 1 )
223
+ target = target .view (bs , chs , - 1 )
224
+
189
225
# compute the magnitude and phase spectra of input and target
190
- self .window = self .window .to (x .device )
191
- x_mag , x_phs = self .stft (x .view (- 1 , x .size (- 1 )))
192
- y_mag , y_phs = self .stft (y .view (- 1 , y .size (- 1 )))
226
+ self .window = self .window .to (input .device )
227
+ x_mag , x_phs = self .stft (input .view (- 1 , input .size (- 1 )))
228
+ y_mag , y_phs = self .stft (target .view (- 1 , target .size (- 1 )))
193
229
194
230
# apply relevant transforms
195
231
if self .scale is not None :
232
+ self .fb = self .fb .to (input .device )
196
233
x_mag = torch .matmul (self .fb , x_mag )
197
234
y_mag = torch .matmul (self .fb , y_mag )
198
235
199
236
# normalize scales
200
237
if self .scale_invariance :
201
- alpha = (x_mag * y_mag ).sum ([- 2 , - 1 ]) / ((y_mag ** 2 ).sum ([- 2 , - 1 ]))
238
+ alpha = (x_mag * y_mag ).sum ([- 2 , - 1 ]) / ((y_mag ** 2 ).sum ([- 2 , - 1 ]))
202
239
y_mag = y_mag * alpha .unsqueeze (- 1 )
203
240
204
241
# compute loss terms
@@ -315,21 +352,22 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
315
352
316
353
def __init__ (
317
354
self ,
318
- fft_sizes = [1024 , 2048 , 512 ],
319
- hop_sizes = [120 , 240 , 50 ],
320
- win_lengths = [600 , 1200 , 240 ],
321
- window = "hann_window" ,
322
- w_sc = 1.0 ,
323
- w_log_mag = 1.0 ,
324
- w_lin_mag = 0.0 ,
325
- w_phs = 0.0 ,
326
- sample_rate = None ,
327
- scale = None ,
328
- n_bins = None ,
329
- scale_invariance = False ,
355
+ fft_sizes : List [int ] = [1024 , 2048 , 512 ],
356
+ hop_sizes : List [int ] = [120 , 240 , 50 ],
357
+ win_lengths : List [int ] = [600 , 1200 , 240 ],
358
+ window : str = "hann_window" ,
359
+ w_sc : float = 1.0 ,
360
+ w_log_mag : float = 1.0 ,
361
+ w_lin_mag : float = 0.0 ,
362
+ w_phs : float = 0.0 ,
363
+ sample_rate : float = None ,
364
+ scale : str = None ,
365
+ n_bins : int = None ,
366
+ perceptual_weighting : bool = False ,
367
+ scale_invariance : bool = False ,
330
368
** kwargs ,
331
369
):
332
- super (MultiResolutionSTFTLoss , self ).__init__ ()
370
+ super ().__init__ ()
333
371
assert len (fft_sizes ) == len (hop_sizes ) == len (win_lengths ) # must define all
334
372
self .fft_sizes = fft_sizes
335
373
self .hop_sizes = hop_sizes
@@ -350,6 +388,7 @@ def __init__(
350
388
sample_rate ,
351
389
scale ,
352
390
n_bins ,
391
+ perceptual_weighting ,
353
392
scale_invariance ,
354
393
** kwargs ,
355
394
)
@@ -417,7 +456,7 @@ def __init__(
417
456
randomize_rate = 1 ,
418
457
** kwargs ,
419
458
):
420
- super (RandomResolutionSTFTLoss , self ).__init__ ()
459
+ super ().__init__ ()
421
460
self .resolutions = resolutions
422
461
self .min_fft_size = min_fft_size
423
462
self .max_fft_size = max_fft_size
@@ -497,45 +536,66 @@ class SumAndDifferenceSTFTLoss(torch.nn.Module):
497
536
See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291)
498
537
499
538
Args:
500
- fft_sizes (list, optional ): List of FFT sizes.
501
- hop_sizes (list, optional ): List of hop sizes.
502
- win_lengths (list, optional ): List of window lengths.
539
+ fft_sizes (List[int] ): List of FFT sizes.
540
+ hop_sizes (List[int] ): List of hop sizes.
541
+ win_lengths (List[int] ): List of window lengths.
503
542
window (str, optional): Window function type.
504
543
w_sum (float, optional): Weight of the sum loss component. Default: 1.0
505
544
w_diff (float, optional): Weight of the difference loss component. Default: 1.0
545
+ perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
546
+ mel_stft (bool, optional): Use Multi-resoltuion mel spectrograms. Default: False
547
+ n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128
548
+ sample_rate (float, optional): Audio sample rate. Default: None
506
549
output (str, optional): Format of the loss returned.
507
550
'loss' : Return only the raw, aggregate loss term.
508
551
'full' : Return the raw loss, plus intermediate loss terms.
509
552
Default: 'loss'
510
-
511
- Returns:
512
- loss:
513
- Aggreate loss term. Only returned if output='loss'.
514
- loss, sum_loss, diff_loss:
515
- Aggregate and intermediate loss terms. Only returned if output='full'.
516
553
"""
517
554
518
555
def __init__ (
519
556
self ,
520
- fft_sizes = [1024 , 2048 , 512 ],
521
- hop_sizes = [120 , 240 , 50 ],
522
- win_lengths = [600 , 1200 , 240 ],
523
- window = "hann_window" ,
524
- w_sum = 1.0 ,
525
- w_diff = 1.0 ,
526
- output = "loss" ,
557
+ fft_sizes : List [int ],
558
+ hop_sizes : List [int ],
559
+ win_lengths : List [int ],
560
+ window : str = "hann_window" ,
561
+ w_sum : float = 1.0 ,
562
+ w_diff : float = 1.0 ,
563
+ output : str = "loss" ,
564
+ ** kwargs ,
527
565
):
528
- super (SumAndDifferenceSTFTLoss , self ).__init__ ()
566
+ super ().__init__ ()
529
567
self .sd = SumAndDifference ()
530
- self .w_sum = 1.0
531
- self .w_diff = 1.0
568
+ self .w_sum = w_sum
569
+ self .w_diff = w_diff
532
570
self .output = output
533
- self .mrstft = MultiResolutionSTFTLoss (fft_sizes , hop_sizes , win_lengths , window )
571
+ self .mrstft = MultiResolutionSTFTLoss (
572
+ fft_sizes ,
573
+ hop_sizes ,
574
+ win_lengths ,
575
+ window ,
576
+ ** kwargs ,
577
+ )
534
578
535
- def forward (self , input , target ):
579
+ def forward (self , input : torch .Tensor , target : torch .Tensor ):
580
+ """This loss function assumes batched input of stereo audio in the time domain.
581
+
582
+ Args:
583
+ input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len).
584
+ target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len).
585
+
586
+ Returns:
587
+ loss (torch.Tensor): Aggreate loss term. Only returned if output='loss'.
588
+ loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor):
589
+ Aggregate and intermediate loss terms. Only returned if output='full'.
590
+ """
591
+ assert input .shape == target .shape # must have same shape
592
+ bs , chs , seq_len = input .size ()
593
+
594
+ # compute sum and difference signals for both
536
595
input_sum , input_diff = self .sd (input )
537
596
target_sum , target_diff = self .sd (target )
538
597
598
+ # compute error in STFT domain
539
599
sum_loss = self .mrstft (input_sum , target_sum )
540
600
diff_loss = self .mrstft (input_diff , target_diff )
541
601
loss = ((self .w_sum * sum_loss ) + (self .w_diff * diff_loss )) / 2
0 commit comments