Skip to content

Commit

Permalink
Merge pull request #68 from simonschwaer/sjs-smoothmss
Browse files Browse the repository at this point in the history
Small changes for flexibility and performance
  • Loading branch information
csteinmetz1 authored Feb 9, 2024
2 parents 1576b0c + d96315e commit 0853e9e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 9 deletions.
44 changes: 35 additions & 9 deletions auraloss/freq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from typing import List, Any

from .utils import apply_reduction
from .utils import apply_reduction, get_window
from .perceptual import SumAndDifference, FIRFilter


Expand All @@ -25,16 +25,29 @@ class STFTMagnitudeLoss(torch.nn.Module):
See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1)
Log-magnitudes are calculated with `log(log_fac*x + log_eps)`, where `log_fac` controls the
compression strength (larger value results in more compression), and `log_eps` can be used
to control the range of the compressed output values (e.g., `log_eps>=1` ensures positive
output values). The default values `log_fac=1` and `log_eps=0` correspond to plain log-compression.
Args:
log (bool, optional): Log-scale the STFT magnitudes,
or use linear scale. Default: True
log_eps (float, optional): Constant value added to the magnitudes before evaluating the logarithm.
Default: 0.0
log_fac (float, optional): Constant multiplication factor for the magnitudes before evaluating the logarithm.
Default: 1.0
distance (str, optional): Distance function ["L1", "L2"]. Default: "L1"
reduction (str, optional): Reduction of the loss elements. Default: "mean"
"""

def __init__(self, log=True, distance="L1", reduction="mean"):
def __init__(self, log=True, log_eps=0.0, log_fac=1.0, distance="L1", reduction="mean"):
super(STFTMagnitudeLoss, self).__init__()

self.log = log
self.log_eps = log_eps
self.log_fac = log_fac

if distance == "L1":
self.distance = torch.nn.L1Loss(reduction=reduction)
elif distance == "L2":
Expand All @@ -44,8 +57,8 @@ def __init__(self, log=True, distance="L1", reduction="mean"):

def forward(self, x_mag, y_mag):
if self.log:
x_mag = torch.log(x_mag)
y_mag = torch.log(y_mag)
x_mag = torch.log(self.log_fac * x_mag + self.log_eps)
y_mag = torch.log(self.log_fac * y_mag + self.log_eps)
return self.distance(x_mag, y_mag)


Expand All @@ -58,8 +71,9 @@ class STFTLoss(torch.nn.Module):
fft_size (int, optional): FFT size in samples. Default: 1024
hop_size (int, optional): Hop size of the FFT in samples. Default: 256
win_length (int, optional): Length of the FFT analysis window. Default: 1024
window (str, optional): Window to apply before FFT, options include:
['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
window (str, optional): Window to apply before FFT. Can be either one of the window functions provided by PyTorch
['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
Default: 'hann_window'
w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
Expand Down Expand Up @@ -112,12 +126,13 @@ def __init__(
reduction: str = "mean",
mag_distance: str = "L1",
device: Any = None,
**kwargs
):
super().__init__()
self.fft_size = fft_size
self.hop_size = hop_size
self.win_length = win_length
self.window = getattr(torch, window)(win_length)
self.window = get_window(window, win_length)
self.w_sc = w_sc
self.w_log_mag = w_log_mag
self.w_lin_mag = w_lin_mag
Expand All @@ -133,16 +148,20 @@ def __init__(
self.mag_distance = mag_distance
self.device = device

self.phs_used = bool(self.w_phs)

self.spectralconv = SpectralConvergenceLoss()
self.logstft = STFTMagnitudeLoss(
log=True,
reduction=reduction,
distance=mag_distance,
**kwargs
)
self.linstft = STFTMagnitudeLoss(
log=False,
reduction=reduction,
distance=mag_distance,
**kwargs
)

# setup mel filterbank
Expand Down Expand Up @@ -203,7 +222,13 @@ def stft(self, x):
x_mag = torch.sqrt(
torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps)
)
x_phs = torch.angle(x_stft)

# torch.angle is expensive, so it is only evaluated if the values are used in the loss
if self.phs_used:
x_phs = torch.angle(x_stft)
else:
x_phs = None

return x_mag, x_phs

def forward(self, input: torch.Tensor, target: torch.Tensor):
Expand All @@ -224,6 +249,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):

# compute the magnitude and phase spectra of input and target
self.window = self.window.to(input.device)

x_mag, x_phs = self.stft(input.view(-1, input.size(-1)))
y_mag, y_phs = self.stft(target.view(-1, target.size(-1)))

Expand All @@ -242,7 +268,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
sc_mag_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0
log_mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0
lin_mag_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0
phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.w_phs else 0.0
phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.phs_used else 0.0

# combine loss terms
loss = (
Expand Down
21 changes: 21 additions & 0 deletions auraloss/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import scipy.signal


def apply_reduction(losses, reduction="none"):
Expand All @@ -8,3 +9,23 @@ def apply_reduction(losses, reduction="none"):
elif reduction == "sum":
losses = losses.sum()
return losses

def get_window(win_type: str, win_length: int) -> torch.Tensor:
    """Return a window of the requested type as a 1D torch tensor.

    ``win_type`` is first looked up as a callable on the ``torch`` module and
    called with ``win_length``. If that lookup or call fails, ``win_type`` is
    forwarded to SciPy's ``scipy.signal.windows.get_window`` instead.

    Args:
        win_type (str): Window type. Can be either one of the window functions provided by PyTorch
            ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
            or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
        win_length (int): Window length

    Returns:
        win: The window as a 1D torch tensor. Windows obtained through the
            SciPy fallback are cast to torch's default floating-point dtype.

    Raises:
        Any error raised by SciPy for a window specification it does not
        recognize is propagated unchanged.
    """
    try:
        return getattr(torch, win_type)(win_length)
    except (AttributeError, TypeError):
        # AttributeError: torch has no attribute with this name.
        # TypeError: `win_type` is not a string (e.g. a SciPy
        # ``(name, param)`` tuple), or the torch attribute of that name
        # cannot be called with a plain length.
        # In both cases the request is handed to SciPy below. Anything else
        # (KeyboardInterrupt, SystemExit, ...) is deliberately not swallowed.
        pass

    scipy_win = scipy.signal.windows.get_window(win_type, win_length)
    # SciPy produces a NumPy array (float64 for the standard windows); cast it
    # so the fallback path yields the same dtype as the torch constructors.
    return torch.from_numpy(scipy_win).to(torch.get_default_dtype())

0 comments on commit 0853e9e

Please sign in to comment.