-
Notifications
You must be signed in to change notification settings - Fork 67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added support for modified sum and difference loss from https://arxiv.org/abs/2208.11428 #71
base: main
Are you sure you want to change the base?
Changes from all commits
206a433
01de139
a4a021f
74930ef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -604,3 +604,175 @@ def forward(self, input: torch.Tensor, target: torch.Tensor): | |
return loss | ||
elif self.output == "full": | ||
return loss, sum_loss, diff_loss | ||
|
||
class ModifiedSumAndDifferenceSTFTLoss(torch.nn.Module): | ||
"""Modified Sum and difference stereo STFT loss module. | ||
|
||
See [Martinèz et al., https://arxiv.org/abs/2208.11428) | ||
Empirically found the perceptual-based pre-emphasisfilter as vital when modeling a highly perceptual task suchas music mixing. | ||
Found an easier convergence during training when using a single frame-size loss ratherthan a multi-resolution magnitude loss. | ||
|
||
Args: | ||
fft_size (int): FFT size in samples. | ||
hop_size (int): Hop size of the FFT in samples. | ||
win_length (int): Length of the FFT analysis window. | ||
window (str, optional): Window function type. | ||
w_sum (float, optional): Weight of the sum loss component. Default: 1.0 | ||
w_diff (float, optional): Weight of the difference loss component. Default: 1.0 | ||
reduction (str, optional): Specifies the reduction to apply to the output: | ||
'none': no reduction will be applied, | ||
'mean': the sum of the output will be divided by the number of elements in the output, | ||
'sum': the output will be summed. | ||
Default: 'mean' | ||
sample_rate (float, optional): Audio sample rate. Default: None | ||
output (str, optional): Format of the loss returned. | ||
'loss' : Return only the raw, aggregate loss term. | ||
'full' : Return the raw loss, plus intermediate loss terms. | ||
Default: 'loss' | ||
loss_type (str, optional): Type of loss to compute. Default: 'SClogL1' | ||
'SClogL1' : Sum of Spectral convergence and log magnitude L1 loss on STFT of input and target. | ||
'L2logL1' : Sum of L2 and log magnitude L1 loss on STFT of input and target. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
fft_size: int = 4096, | ||
hop_size: int = 1024, | ||
win_length: int = 4096, | ||
window: str = "hann_window", | ||
w_sum: float = 1.0, | ||
w_diff: float = 1.0, | ||
output: str = "loss", | ||
reduction: str = "mean", | ||
sample_rate: float = None, | ||
loss_type: str = 'SClogL1', | ||
|
||
**kwargs, | ||
): | ||
super().__init__() | ||
self.fft_size = fft_size | ||
self.hop_size = hop_size | ||
self.win_length = win_length | ||
self.window = getattr(torch, window)(win_length) | ||
self.eps = 1e-8 | ||
|
||
self.spectralconv = SpectralConvergenceLoss() | ||
self.l1logstft = STFTMagnitudeLoss( | ||
log=True, | ||
reduction=reduction, | ||
distance='L1', | ||
) | ||
self.l2linstft = STFTMagnitudeLoss( | ||
log=False, | ||
reduction=reduction, | ||
distance='L2', | ||
) | ||
self.sample_rate = sample_rate | ||
|
||
self.awfilter = FIRFilter(filter_type= "aw", fs=sample_rate) | ||
self.lpfilter = FIRFilter(filter_type= "lp", fs=sample_rate) | ||
self.sd = SumAndDifference() | ||
|
||
self.w_sum = w_sum | ||
self.w_diff = w_diff | ||
self.output = output | ||
self.reduction = reduction | ||
self.loss_type = loss_type | ||
|
||
|
||
def stft(self, x): | ||
"""Perform STFT. | ||
Args: | ||
x (Tensor): Input signal tensor (B, T). | ||
|
||
Returns: | ||
Tensor: x_mag, x_phs | ||
Magnitude and phase spectra (B, fft_size // 2 + 1, frames). | ||
""" | ||
x_stft = torch.stft( | ||
x, | ||
self.fft_size, | ||
self.hop_size, | ||
self.win_length, | ||
self.window, | ||
return_complex=True, | ||
) | ||
x_mag = torch.sqrt( | ||
torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps) | ||
) | ||
x_phs = torch.angle(x_stft) | ||
return x_mag, x_phs | ||
|
||
def forward(self, input: torch.Tensor, target: torch.Tensor): | ||
"""This loss function assumes batched input of stereo audio in the time domain. | ||
|
||
Args: | ||
input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len). | ||
target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len). | ||
|
||
Returns: | ||
loss (torch.Tensor): Aggreate loss term. Only returned if output='loss'. | ||
loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor): | ||
Aggregate and intermediate loss terms. Only returned if output='full'. | ||
""" | ||
assert input.shape == target.shape # must have same shape | ||
bs, chs, seq_len = input.size() | ||
self.awfilter.to(input.device) | ||
self.lpfilter.to(input.device) | ||
self.window = self.window.to(input.device) | ||
|
||
#filter the input and target with A-weighting and lowpass filter | ||
input = input.view(bs * chs, 1, -1) | ||
target = target.view(bs * chs, 1, -1) | ||
input, target = self.awfilter(input, target) | ||
input, target = self.lpfilter(input, target) | ||
|
||
input = input.view(bs, chs, -1) | ||
target = target.view(bs, chs, -1) | ||
|
||
# compute sum and difference signals for both | ||
input_sum, input_diff = self.sd(input) | ||
target_sum, target_diff = self.sd(target) | ||
|
||
#compute magnitude and phase spectra | ||
input_sum_mag, input_sum_phs = self.stft(input_sum.view(-1, input_sum.size(-1))) | ||
input_diff_mag, input_diff_phs = self.stft(input_diff.view(-1, input_diff.size(-1))) | ||
target_sum_mag, target_sum_phs = self.stft(target_sum.view(-1, target_sum.size(-1))) | ||
target_diff_mag, target_diff_phs = self.stft(target_diff.view(-1, target_diff.size(-1))) | ||
|
||
# compute loss terms | ||
|
||
#compute the L1 log magnitude loss | ||
l1log_sum_loss = self.l1logstft(input_sum_mag, target_sum_mag) | ||
l1log_diff_loss = self.l1logstft(input_diff_mag, target_diff_mag) | ||
|
||
if self.loss_type == 'SClogL1': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like the other difference in the distance measure. This should be able to be supported by the main class. However, if it seems easier, we could consider adding a new |
||
#compute the spectral convergence loss | ||
sc_sum_loss = self.spectralconv(input_sum_mag, target_sum_mag) | ||
sc_diff_loss = self.spectralconv(input_diff_mag, target_diff_mag) | ||
#combine the loss terms | ||
|
||
loss = ((self.w_sum * (sc_sum_loss + l1log_sum_loss)) + (self.w_diff * (sc_diff_loss + l1log_diff_loss))) | ||
loss = apply_reduction(loss, reduction=self.reduction) | ||
|
||
if self.output == "loss": | ||
return loss | ||
elif self.output == "full": | ||
return loss, sc_sum_loss, l1log_sum_loss, sc_diff_loss, l1log_diff_loss | ||
|
||
|
||
elif self.loss_type == 'L2logL1': | ||
#compute the L2 linear magnitude loss | ||
l2lin_sum_loss = self.l2linstft(input_sum_mag, target_sum_mag) | ||
l2lin_diff_loss = self.l2linstft(input_diff_mag, target_diff_mag) | ||
#combine the loss terms | ||
loss = ((self.w_sum * (l2lin_sum_loss + l1log_sum_loss)) + (self.w_diff * (l2lin_diff_loss + l1log_diff_loss))) | ||
loss = apply_reduction(loss, reduction=self.reduction) | ||
|
||
if self.output == "loss": | ||
return loss | ||
elif self.output == "full": | ||
return loss, l2lin_sum_loss, l1log_sum_loss, l2lin_diff_loss, l1log_diff_loss | ||
else: | ||
raise ValueError(f"Invalid loss type: '{self.loss_type}'.") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of creating a new class for this we should add parameters to
SumAndDifferenceSTFTLoss
in order to support this behavior. It seems like the major modification is that application of the pre-emphasis filter.