Skip to content

Commit 5daecee

Browse files
authored
Merge pull request #48 from csteinmetz1/psumdiff
Adding perceptual weighting to `SumAndDifferenceSTFTLoss`
2 parents c4f16e7 + b5bc058 commit 5daecee

File tree

5 files changed

+414
-68
lines changed

5 files changed

+414
-68
lines changed

auraloss/freq.py

+127-67
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import torch
22
import numpy as np
3-
from .utils import apply_reduction
3+
from typing import List, Any
44

5+
from .utils import apply_reduction
56
from .perceptual import SumAndDifference, FIRFilter
67

78

@@ -69,6 +70,7 @@ class STFTLoss(torch.nn.Module):
6970
['mel', 'chroma']
7071
Default: None
7172
n_bins (int, optional): Number of scaling frequency bins. Default: None.
73+
perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
7274
scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
7375
eps (float, optional): Small epsilon value for stablity. Default: 1e-8
7476
output (str, optional): Format of the loss returned.
@@ -92,25 +94,26 @@ class STFTLoss(torch.nn.Module):
9294

9395
def __init__(
9496
self,
95-
fft_size=1024,
96-
hop_size=256,
97-
win_length=1024,
98-
window="hann_window",
99-
w_sc=1.0,
100-
w_log_mag=1.0,
101-
w_lin_mag=0.0,
102-
w_phs=0.0,
103-
sample_rate=None,
104-
scale=None,
105-
n_bins=None,
106-
scale_invariance=False,
107-
eps=1e-8,
108-
output="loss",
109-
reduction="mean",
110-
mag_distance="L1",
111-
device=None,
97+
fft_size: int = 1024,
98+
hop_size: int = 256,
99+
win_length: int = 1024,
100+
window: str = "hann_window",
101+
w_sc: float = 1.0,
102+
w_log_mag: float = 1.0,
103+
w_lin_mag: float = 0.0,
104+
w_phs: float = 0.0,
105+
sample_rate: float = None,
106+
scale: str = None,
107+
n_bins: int = None,
108+
perceptual_weighting: bool = False,
109+
scale_invariance: bool = False,
110+
eps: float = 1e-8,
111+
output: str = "loss",
112+
reduction: str = "mean",
113+
mag_distance: str = "L1",
114+
device: Any = None,
112115
):
113-
super(STFTLoss, self).__init__()
116+
super().__init__()
114117
self.fft_size = fft_size
115118
self.hop_size = hop_size
116119
self.win_length = win_length
@@ -122,23 +125,28 @@ def __init__(
122125
self.sample_rate = sample_rate
123126
self.scale = scale
124127
self.n_bins = n_bins
128+
self.perceptual_weighting = perceptual_weighting
125129
self.scale_invariance = scale_invariance
126130
self.eps = eps
127131
self.output = output
128132
self.reduction = reduction
133+
self.mag_distance = mag_distance
129134
self.device = device
130135

131136
self.spectralconv = SpectralConvergenceLoss()
132137
self.logstft = STFTMagnitudeLoss(
133-
log=True, reduction=reduction, distance=mag_distance
138+
log=True,
139+
reduction=reduction,
140+
distance=mag_distance,
134141
)
135142
self.linstft = STFTMagnitudeLoss(
136-
log=False, reduction=reduction, distance=mag_distance
143+
log=False,
144+
reduction=reduction,
145+
distance=mag_distance,
137146
)
138147

139148
# setup mel filterbank
140149
if scale is not None:
141-
142150
try:
143151
import librosa.filters
144152
except Exception as e:
@@ -149,19 +157,32 @@ def __init__(
149157
assert sample_rate != None # Must set sample rate to use mel scale
150158
assert n_bins <= fft_size # Must be more FFT bins than Mel bins
151159
fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
152-
self.fb = torch.tensor(fb).unsqueeze(0)
160+
fb = torch.tensor(fb).unsqueeze(0)
153161

154162
elif self.scale == "chroma":
155163
assert sample_rate != None # Must set sample rate to use chroma scale
156164
assert n_bins <= fft_size # Must be more FFT bins than chroma bins
157-
fb = librosa.filters.chroma(sr=sample_rate, n_fft=fft_size, n_chroma=n_bins)
158-
self.fb = torch.tensor(fb).unsqueeze(0)
165+
fb = librosa.filters.chroma(
166+
sr=sample_rate, n_fft=fft_size, n_chroma=n_bins
167+
)
168+
159169
else:
160-
raise ValueError(f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'.")
170+
raise ValueError(
171+
f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'."
172+
)
173+
174+
self.register_buffer("fb", fb)
161175

162176
if scale is not None and device is not None:
163177
self.fb = self.fb.to(self.device) # move filterbank to device
164178

179+
if self.perceptual_weighting:
180+
if sample_rate is None:
181+
raise ValueError(
182+
f"`sample_rate` must be supplied when `perceptual_weighting = True`."
183+
)
184+
self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)
185+
165186
def stft(self, x):
166187
"""Perform STFT.
167188
Args:
@@ -180,25 +201,41 @@ def stft(self, x):
180201
return_complex=True,
181202
)
182203
x_mag = torch.sqrt(
183-
torch.clamp((x_stft.real ** 2) + (x_stft.imag ** 2), min=self.eps)
204+
torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps)
184205
)
185206
x_phs = torch.angle(x_stft)
186207
return x_mag, x_phs
187208

188-
def forward(self, x, y):
209+
def forward(self, input: torch.Tensor, target: torch.Tensor):
210+
bs, chs, seq_len = input.size()
211+
212+
if self.perceptual_weighting: # apply optional A-weighting via FIR filter
213+
# since FIRFilter only support mono audio we will move channels to batch dim
214+
input = input.view(bs * chs, 1, -1)
215+
target = target.view(bs * chs, 1, -1)
216+
217+
# now apply the filter to both
218+
self.prefilter.to(input.device)
219+
input, target = self.prefilter(input, target)
220+
221+
# now move the channels back
222+
input = input.view(bs, chs, -1)
223+
target = target.view(bs, chs, -1)
224+
189225
# compute the magnitude and phase spectra of input and target
190-
self.window = self.window.to(x.device)
191-
x_mag, x_phs = self.stft(x.view(-1, x.size(-1)))
192-
y_mag, y_phs = self.stft(y.view(-1, y.size(-1)))
226+
self.window = self.window.to(input.device)
227+
x_mag, x_phs = self.stft(input.view(-1, input.size(-1)))
228+
y_mag, y_phs = self.stft(target.view(-1, target.size(-1)))
193229

194230
# apply relevant transforms
195231
if self.scale is not None:
232+
self.fb = self.fb.to(input.device)
196233
x_mag = torch.matmul(self.fb, x_mag)
197234
y_mag = torch.matmul(self.fb, y_mag)
198235

199236
# normalize scales
200237
if self.scale_invariance:
201-
alpha = (x_mag * y_mag).sum([-2, -1]) / ((y_mag ** 2).sum([-2, -1]))
238+
alpha = (x_mag * y_mag).sum([-2, -1]) / ((y_mag**2).sum([-2, -1]))
202239
y_mag = y_mag * alpha.unsqueeze(-1)
203240

204241
# compute loss terms
@@ -315,21 +352,22 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
315352

316353
def __init__(
317354
self,
318-
fft_sizes=[1024, 2048, 512],
319-
hop_sizes=[120, 240, 50],
320-
win_lengths=[600, 1200, 240],
321-
window="hann_window",
322-
w_sc=1.0,
323-
w_log_mag=1.0,
324-
w_lin_mag=0.0,
325-
w_phs=0.0,
326-
sample_rate=None,
327-
scale=None,
328-
n_bins=None,
329-
scale_invariance=False,
355+
fft_sizes: List[int] = [1024, 2048, 512],
356+
hop_sizes: List[int] = [120, 240, 50],
357+
win_lengths: List[int] = [600, 1200, 240],
358+
window: str = "hann_window",
359+
w_sc: float = 1.0,
360+
w_log_mag: float = 1.0,
361+
w_lin_mag: float = 0.0,
362+
w_phs: float = 0.0,
363+
sample_rate: float = None,
364+
scale: str = None,
365+
n_bins: int = None,
366+
perceptual_weighting: bool = False,
367+
scale_invariance: bool = False,
330368
**kwargs,
331369
):
332-
super(MultiResolutionSTFTLoss, self).__init__()
370+
super().__init__()
333371
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all
334372
self.fft_sizes = fft_sizes
335373
self.hop_sizes = hop_sizes
@@ -350,6 +388,7 @@ def __init__(
350388
sample_rate,
351389
scale,
352390
n_bins,
391+
perceptual_weighting,
353392
scale_invariance,
354393
**kwargs,
355394
)
@@ -417,7 +456,7 @@ def __init__(
417456
randomize_rate=1,
418457
**kwargs,
419458
):
420-
super(RandomResolutionSTFTLoss, self).__init__()
459+
super().__init__()
421460
self.resolutions = resolutions
422461
self.min_fft_size = min_fft_size
423462
self.max_fft_size = max_fft_size
@@ -497,45 +536,66 @@ class SumAndDifferenceSTFTLoss(torch.nn.Module):
497536
See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291)
498537
499538
Args:
500-
fft_sizes (list, optional): List of FFT sizes.
501-
hop_sizes (list, optional): List of hop sizes.
502-
win_lengths (list, optional): List of window lengths.
539+
fft_sizes (List[int]): List of FFT sizes.
540+
hop_sizes (List[int]): List of hop sizes.
541+
win_lengths (List[int]): List of window lengths.
503542
window (str, optional): Window function type.
504543
w_sum (float, optional): Weight of the sum loss component. Default: 1.0
505544
w_diff (float, optional): Weight of the difference loss component. Default: 1.0
545+
perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
546+
mel_stft (bool, optional): Use Multi-resoltuion mel spectrograms. Default: False
547+
n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128
548+
sample_rate (float, optional): Audio sample rate. Default: None
506549
output (str, optional): Format of the loss returned.
507550
'loss' : Return only the raw, aggregate loss term.
508551
'full' : Return the raw loss, plus intermediate loss terms.
509552
Default: 'loss'
510-
511-
Returns:
512-
loss:
513-
Aggreate loss term. Only returned if output='loss'.
514-
loss, sum_loss, diff_loss:
515-
Aggregate and intermediate loss terms. Only returned if output='full'.
516553
"""
517554

518555
def __init__(
519556
self,
520-
fft_sizes=[1024, 2048, 512],
521-
hop_sizes=[120, 240, 50],
522-
win_lengths=[600, 1200, 240],
523-
window="hann_window",
524-
w_sum=1.0,
525-
w_diff=1.0,
526-
output="loss",
557+
fft_sizes: List[int],
558+
hop_sizes: List[int],
559+
win_lengths: List[int],
560+
window: str = "hann_window",
561+
w_sum: float = 1.0,
562+
w_diff: float = 1.0,
563+
output: str = "loss",
564+
**kwargs,
527565
):
528-
super(SumAndDifferenceSTFTLoss, self).__init__()
566+
super().__init__()
529567
self.sd = SumAndDifference()
530-
self.w_sum = 1.0
531-
self.w_diff = 1.0
568+
self.w_sum = w_sum
569+
self.w_diff = w_diff
532570
self.output = output
533-
self.mrstft = MultiResolutionSTFTLoss(fft_sizes, hop_sizes, win_lengths, window)
571+
self.mrstft = MultiResolutionSTFTLoss(
572+
fft_sizes,
573+
hop_sizes,
574+
win_lengths,
575+
window,
576+
**kwargs,
577+
)
534578

535-
def forward(self, input, target):
579+
def forward(self, input: torch.Tensor, target: torch.Tensor):
580+
"""This loss function assumes batched input of stereo audio in the time domain.
581+
582+
Args:
583+
input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len).
584+
target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len).
585+
586+
Returns:
587+
loss (torch.Tensor): Aggreate loss term. Only returned if output='loss'.
588+
loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor):
589+
Aggregate and intermediate loss terms. Only returned if output='full'.
590+
"""
591+
assert input.shape == target.shape # must have same shape
592+
bs, chs, seq_len = input.size()
593+
594+
# compute sum and difference signals for both
536595
input_sum, input_diff = self.sd(input)
537596
target_sum, target_diff = self.sd(target)
538597

598+
# compute error in STFT domain
539599
sum_loss = self.mrstft(input_sum, target_sum)
540600
diff_loss = self.mrstft(input_diff, target_diff)
541601
loss = ((self.w_sum * sum_loss) + (self.w_diff * diff_loss)) / 2

tests/__init__.py

Whitespace-only changes.

tests/manual_test_gpu.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch
2+
import auraloss
3+
4+
y_hat = torch.randn(2, 1, 131072)
5+
y = torch.randn(2, 1, 131072)
6+
7+
loss_fn = auraloss.freq.MelSTFTLoss(44100)
8+
loss_fn2 = auraloss.freq.MultiResolutionSTFTLoss()
9+
10+
# loss_fn.cuda()
11+
12+
y_hat = y_hat.cuda()
13+
y = y.cuda()
14+
15+
loss = loss_fn2(y_hat, y)
16+
loss = loss_fn(y_hat, y)

0 commit comments

Comments
 (0)