Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for modified sum and difference loss from https://arxiv.org/abs/2208.11428 #71

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions auraloss/freq.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,3 +604,175 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
return loss
elif self.output == "full":
return loss, sum_loss, diff_loss

class ModifiedSumAndDifferenceSTFTLoss(torch.nn.Module):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of creating a new class for this we should add parameters to SumAndDifferenceSTFTLoss in order to support this behavior. It seems like the major modification is the application of the pre-emphasis filter.

"""Modified Sum and difference stereo STFT loss module.

See Martínez et al., https://arxiv.org/abs/2208.11428.
Empirically found the perceptual-based pre-emphasis filter to be vital when modeling a highly perceptual task such as music mixing.
Found an easier convergence during training when using a single frame-size loss rather than a multi-resolution magnitude loss.

Args:
fft_size (int): FFT size in samples.
hop_size (int): Hop size of the FFT in samples.
win_length (int): Length of the FFT analysis window.
window (str, optional): Window function type.
w_sum (float, optional): Weight of the sum loss component. Default: 1.0
w_diff (float, optional): Weight of the difference loss component. Default: 1.0
reduction (str, optional): Specifies the reduction to apply to the output:
'none': no reduction will be applied,
'mean': the sum of the output will be divided by the number of elements in the output,
'sum': the output will be summed.
Default: 'mean'
sample_rate (float, optional): Audio sample rate. Default: None
output (str, optional): Format of the loss returned.
'loss' : Return only the raw, aggregate loss term.
'full' : Return the raw loss, plus intermediate loss terms.
Default: 'loss'
loss_type (str, optional): Type of loss to compute. Default: 'SClogL1'
'SClogL1' : Sum of Spectral convergence and log magnitude L1 loss on STFT of input and target.
'L2logL1' : Sum of L2 and log magnitude L1 loss on STFT of input and target.
"""

def __init__(
    self,
    fft_size: int = 4096,
    hop_size: int = 1024,
    win_length: int = 4096,
    window: str = "hann_window",
    w_sum: float = 1.0,
    w_diff: float = 1.0,
    output: str = "loss",
    reduction: str = "mean",
    sample_rate: float = None,
    loss_type: str = "SClogL1",
    **kwargs,
):
    """Configure the STFT, the pre-emphasis filters and the loss terms.

    Args:
        fft_size (int): FFT size in samples.
        hop_size (int): Hop size of the FFT in samples.
        win_length (int): Length of the FFT analysis window.
        window (str, optional): Name of a window factory on the ``torch``
            module (looked up with ``getattr``), e.g. "hann_window".
        w_sum (float, optional): Weight of the sum loss component.
        w_diff (float, optional): Weight of the difference loss component.
        output (str, optional): 'loss' or 'full'.
        reduction (str, optional): Reduction handed to the magnitude losses
            and applied to the aggregate loss.
        sample_rate (float): Audio sample rate used to build the filters.
            Although the signature defaults to None, a value must be given.
        loss_type (str, optional): 'SClogL1' or 'L2logL1'.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If `loss_type` or `output` is not a supported value, or
            if `sample_rate` is None.
    """
    # Fail fast on configuration errors instead of surfacing them (or, for
    # `output`, silently returning None) on the first forward pass.
    if loss_type not in ("SClogL1", "L2logL1"):
        raise ValueError(f"Invalid loss type: '{loss_type}'.")
    if output not in ("loss", "full"):
        raise ValueError(f"Invalid output: '{output}'.")
    # Passing fs=None would override FIRFilter's own default rate (44100).
    # NOTE(review): the A-weighting design presumably needs a numeric rate
    # and would otherwise fail with an obscure error -- confirm against
    # FIRFilter.
    if sample_rate is None:
        raise ValueError(
            "`sample_rate` must be supplied to build the pre-emphasis filters."
        )

    super().__init__()
    self.fft_size = fft_size
    self.hop_size = hop_size
    self.win_length = win_length
    # Plain tensor attribute (not a registered buffer); `forward` moves it
    # to the device of the incoming audio.
    self.window = getattr(torch, window)(win_length)
    # Floor applied to the power spectrum before the square root in `stft`.
    self.eps = 1e-8

    # Loss terms: log-magnitude L1 is always used; it is paired with either
    # spectral convergence or linear-magnitude L2 depending on `loss_type`.
    self.spectralconv = SpectralConvergenceLoss()
    self.l1logstft = STFTMagnitudeLoss(
        log=True,
        reduction=reduction,
        distance="L1",
    )
    self.l2linstft = STFTMagnitudeLoss(
        log=False,
        reduction=reduction,
        distance="L2",
    )
    self.sample_rate = sample_rate

    # Pre-emphasis chain applied to both signals before the STFT.
    self.awfilter = FIRFilter(filter_type="aw", fs=sample_rate)
    self.lpfilter = FIRFilter(filter_type="lp", fs=sample_rate)
    self.sd = SumAndDifference()

    self.w_sum = w_sum
    self.w_diff = w_diff
    self.output = output
    self.reduction = reduction
    self.loss_type = loss_type


def stft(self, x):
    """Perform STFT and return magnitude and phase spectra.

    Args:
        x (Tensor): Input signal tensor (B, T).

    Returns:
        Tensor: x_mag, x_phs
            Magnitude and phase spectra (B, fft_size // 2 + 1, frames).
    """
    # Use a window on the same device as the signal so this method also
    # works when called directly, without `forward` having moved
    # `self.window` first. `.to` is a no-op when the devices already match.
    window = self.window.to(x.device)
    x_stft = torch.stft(
        x,
        self.fft_size,
        self.hop_size,
        self.win_length,
        window,
        return_complex=True,
    )
    # Clamp the power spectrum before the square root so the magnitude
    # never drops below sqrt(eps).
    power = (x_stft.real**2) + (x_stft.imag**2)
    x_mag = torch.sqrt(torch.clamp(power, min=self.eps))
    x_phs = torch.angle(x_stft)
    return x_mag, x_phs

def forward(self, input: torch.Tensor, target: torch.Tensor):
    """Compute the sum/difference STFT loss on pre-emphasized stereo audio.

    This loss function assumes batched input of stereo audio in the time domain.

    Args:
        input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len).
        target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len).

    Returns:
        loss (torch.Tensor): Aggregate loss term. Only returned if output='loss'.
        loss, aux_sum_loss, l1log_sum_loss, aux_diff_loss, l1log_diff_loss:
            Aggregate and intermediate loss terms, where the ``aux`` terms are
            spectral convergence for loss_type='SClogL1' and linear-magnitude
            L2 for loss_type='L2logL1'. Only returned if output='full'.

    Raises:
        ValueError: If the shapes differ, the tensors are not 3-D, or
            `loss_type` / `output` hold an unsupported value.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if input.shape != target.shape:
        raise ValueError(
            f"input and target must have the same shape, "
            f"got {tuple(input.shape)} and {tuple(target.shape)}."
        )
    # Reject a bad configuration before doing any filtering or STFT work.
    if self.loss_type not in ("SClogL1", "L2logL1"):
        raise ValueError(f"Invalid loss type: '{self.loss_type}'.")
    if self.output not in ("loss", "full"):
        raise ValueError(f"Invalid output: '{self.output}'.")

    bs, chs, _ = input.shape
    self.awfilter.to(input.device)
    self.lpfilter.to(input.device)
    self.window = self.window.to(input.device)

    # Filter the input and target with the A-weighting and lowpass filters.
    # The filters are called on single-channel signals, so channels are
    # folded into the batch axis. `reshape` (unlike `view`) also accepts
    # non-contiguous tensors.
    x = input.reshape(bs * chs, 1, -1)
    y = target.reshape(bs * chs, 1, -1)
    x, y = self.awfilter(x, y)
    x, y = self.lpfilter(x, y)
    x = x.reshape(bs, chs, -1)
    y = y.reshape(bs, chs, -1)

    # Compute sum and difference signals for both.
    input_sum, input_diff = self.sd(x)
    target_sum, target_diff = self.sd(y)

    # Only the magnitude spectra enter the loss; phase is discarded.
    input_sum_mag, _ = self.stft(input_sum.reshape(-1, input_sum.size(-1)))
    input_diff_mag, _ = self.stft(input_diff.reshape(-1, input_diff.size(-1)))
    target_sum_mag, _ = self.stft(target_sum.reshape(-1, target_sum.size(-1)))
    target_diff_mag, _ = self.stft(target_diff.reshape(-1, target_diff.size(-1)))

    # The L1 log-magnitude loss is shared by both loss types.
    l1log_sum_loss = self.l1logstft(input_sum_mag, target_sum_mag)
    l1log_diff_loss = self.l1logstft(input_diff_mag, target_diff_mag)

    # NOTE(review, PR owner): Looks like the other difference is in the
    # distance measure. This should be able to be supported by the main
    # class. However, if it seems easier, we could consider adding a new
    # ModifiedSumAndDifferenceLoss class but where it inherits from the main
    # class so that we don't get all this repeated code (e.g. stft, etc.)
    if self.loss_type == "SClogL1":
        aux_loss = self.spectralconv  # spectral convergence
    else:
        aux_loss = self.l2linstft  # L2 on linear magnitudes
    aux_sum_loss = aux_loss(input_sum_mag, target_sum_mag)
    aux_diff_loss = aux_loss(input_diff_mag, target_diff_mag)

    # Combine the loss terms.
    loss = (self.w_sum * (aux_sum_loss + l1log_sum_loss)) + (
        self.w_diff * (aux_diff_loss + l1log_diff_loss)
    )
    loss = apply_reduction(loss, reduction=self.reduction)

    if self.output == "loss":
        return loss
    return loss, aux_sum_loss, l1log_sum_loss, aux_diff_loss, l1log_diff_loss

5 changes: 5 additions & 0 deletions auraloss/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class FIRFilter(torch.nn.Module):
A-weighting filter - "aw"
First-order highpass - "hp"
Folded differentiator - "fd"
Lowpass filter - "lp"

Note that the default coefficient value of 0.85 is optimized for
a sampling rate of 44.1 kHz; consider adjusting this value at different sampling rates.
Expand All @@ -73,6 +74,10 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False)
self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
self.fir.weight.requires_grad = False
self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1)
elif filter_type == "lp":
self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
self.fir.weight.requires_grad = False
self.fir.weight.data = torch.tensor([1, coef, 0]).view(1, 1, -1)
elif filter_type == "fd":
self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
self.fir.weight.requires_grad = False
Expand Down