From 206a43364104b596e5583d9c34181e86c449e75b Mon Sep 17 00:00:00 2001
From: Soumya Sai Vanka <79447355+sai-soum@users.noreply.github.com>
Date: Thu, 8 Feb 2024 21:40:42 +0000
Subject: [PATCH 1/4] Update perceptual.py with a first-order lowpass FIR filter option

---
 auraloss/perceptual.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/auraloss/perceptual.py b/auraloss/perceptual.py
index 1cedeb3..9590d1d 100644
--- a/auraloss/perceptual.py
+++ b/auraloss/perceptual.py
@@ -50,6 +50,7 @@ class FIRFilter(torch.nn.Module):
         A-weighting filter - "aw"
         First-order highpass - "hp"
         Folded differentiator - "fd"
+        Lowpass filter - "lp"

     Note that the default coefficeint value of 0.85 is optimized for
     a sampling rate of 44.1 kHz, considering adjusting this value at differnt sampling rates.
@@ -69,7 +70,11 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False)
         if ntaps % 2 == 0:
             raise ValueError(f"ntaps must be odd (ntaps={ntaps}).")

-        if filter_type == "hp":
+        if filter_type == "lp":
+            self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
+            self.fir.weight.requires_grad = False
+            self.fir.weight.data = torch.tensor([1, coef, 0]).view(1, 1, -1)
+        elif filter_type == "hp":
             self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
             self.fir.weight.requires_grad = False
             self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1)

From 01de139899166cdf6503802972f4bb7973fa9ac1 Mon Sep 17 00:00:00 2001
From: Soumya Sai Vanka <79447355+sai-soum@users.noreply.github.com>
Date: Thu, 8 Feb 2024 21:41:48 +0000
Subject: [PATCH 2/4] Update freq.py with the modified sum and difference STFT
 loss used in the Martínez et al. out-of-domain mixing paper
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 auraloss/freq.py | 171 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 171 insertions(+)

diff --git a/auraloss/freq.py b/auraloss/freq.py
index a5efe70..4420e87 100644
--- a/auraloss/freq.py
+++ b/auraloss/freq.py
@@ -604,3 +604,174 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
             return loss
         elif self.output == "full":
             return loss, sum_loss, diff_loss
+
+class ModifiedSumAndDifferenceSTFTLoss(torch.nn.Module):
+    """Modified sum and difference stereo STFT loss module.
+
+    See [Martínez et al., 2022](https://arxiv.org/abs/2208.11428)
+    The authors found the perceptual pre-emphasis filter to be vital when modeling a highly perceptual task such as music mixing.
+    They also found easier convergence during training when using a single frame-size loss rather than a multi-resolution magnitude loss.
+
+    Args:
+        fft_size (int): FFT size in samples.
+        hop_size (int): Hop size of the FFT in samples.
+        win_length (int): Length of the FFT analysis window.
+        window (str, optional): Window function type.
+        w_sum (float, optional): Weight of the sum loss component. Default: 1.0
+        w_diff (float, optional): Weight of the difference loss component. Default: 1.0
+        reduction (str, optional): Specifies the reduction to apply to the output:
+            'none': no reduction will be applied,
+            'mean': the sum of the output will be divided by the number of elements in the output,
+            'sum': the output will be summed.
+            Default: 'mean'
+        sample_rate (float, optional): Audio sample rate. Default: None
+        output (str, optional): Format of the loss returned.
+            'loss' : Return only the raw, aggregate loss term.
+            'full' : Return the raw loss, plus intermediate loss terms.
+            Default: 'loss'
+        loss_type (str, optional): Type of loss to compute. Default: 'SClogL1'
+            'SClogL1' : Sum of spectral convergence and log magnitude L1 losses on the STFTs of input and target.
+            'L2logL1' : Sum of linear L2 and log magnitude L1 losses on the STFTs of input and target.
+    """
+
+    def __init__(
+        self,
+        fft_size: int = 4096,
+        hop_size: int = 1024,
+        win_length: int = 4096,
+        window: str = "hann_window",
+        w_sum: float = 1.0,
+        w_diff: float = 1.0,
+        output: str = "loss",
+        reduction: str = "mean",
+        sample_rate: float = None,
+        loss_type: str = 'SClogL1',
+
+        **kwargs,
+    ):
+        super().__init__()
+        self.fft_size = fft_size
+        self.hop_size = hop_size
+        self.win_length = win_length
+        self.window = getattr(torch, window)(win_length)
+        self.eps = 1e-8
+
+        self.spectralconv = SpectralConvergenceLoss()
+        self.l1logstft = STFTMagnitudeLoss(
+            log=True,
+            reduction=reduction,
+            distance='L1',
+        )
+        self.l2linstft = STFTMagnitudeLoss(
+            log=False,
+            reduction=reduction,
+            distance='L2',
+        )
+        self.sample_rate = sample_rate
+
+        self.awfilter = FIRFilter(filter_type="aw", fs=sample_rate)
+        self.lpfilter = FIRFilter(filter_type="lp", fs=sample_rate)
+        self.sd = SumAndDifference()
+
+        self.w_sum = w_sum
+        self.w_diff = w_diff
+        self.output = output
+        self.reduction = reduction
+        self.loss_type = loss_type
+
+
+    def stft(self, x):
+        """Perform STFT.
+        Args:
+            x (Tensor): Input signal tensor (B, T).
+
+        Returns:
+            Tensor: x_mag, x_phs
+                Magnitude and phase spectra (B, fft_size // 2 + 1, frames).
+        """
+        x_stft = torch.stft(
+            x,
+            self.fft_size,
+            self.hop_size,
+            self.win_length,
+            self.window,
+            return_complex=True,
+        )
+        x_mag = torch.sqrt(
+            torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps)
+        )
+        x_phs = torch.angle(x_stft)
+        return x_mag, x_phs
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor):
+        """This loss function assumes batched input of stereo audio in the time domain.
+
+        Args:
+            input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len).
+            target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len).
+
+        Returns:
+            loss (torch.Tensor): Aggregate loss term. Only returned if output='loss'.
+            loss, sum loss terms, diff loss terms (torch.Tensor):
+                Aggregate and intermediate loss terms. Only returned if output='full'.
+ """ + assert input.shape == target.shape # must have same shape + bs, chs, seq_len = input.size() + self.awfilter.to(input.device) + self.lpfilter.to(input.device) + + #filter the input and target with A-weighting and lowpass filter + input = input.view(bs * chs, 1, -1) + target = target.view(bs * chs, 1, -1) + input, target = self.awfilter(input, target) + input, target = self.lpfilter(input, target) + + input = input.view(bs, chs, -1) + target = target.view(bs, chs, -1) + + # compute sum and difference signals for both + input_sum, input_diff = self.sd(input) + target_sum, target_diff = self.sd(target) + + #compute magnitude and phase spectra + input_sum_mag, input_sum_phs = self.stft(input_sum.view(-1, input_sum.size(-1))) + input_diff_mag, input_diff_phs = self.stft(input_diff.view(-1, input_diff.size(-1))) + target_sum_mag, target_sum_phs = self.stft(target_sum.view(-1, target_sum.size(-1))) + target_diff_mag, target_diff_phs = self.stft(target_diff.view(-1, target_diff.size(-1))) + + # compute loss terms + + #compute the L1 log magnitude loss + l1log_sum_loss = self.l1logstft(input_sum_mag, target_sum_mag) + l1log_diff_loss = self.l1logstft(input_diff_mag, target_diff_mag) + + if self.loss_type == 'SClogL1': + #compute the spectral convergence loss + sc_sum_loss = self.spectralconv(input_sum_mag, target_sum_mag) + sc_diff_loss = self.spectralconv(input_diff_mag, target_diff_mag) + #combine the loss terms + + loss = ((self.w_sum * (sc_sum_loss + l1log_sum_loss)) + (self.w_diff * (sc_diff_loss + l1log_diff_loss))) + loss = apply_reduction(loss, reduction=self.reduction) + + if self.output == "loss": + return loss + elif self.output == "full": + return loss, sc_sum_loss, l1log_sum_loss, sc_diff_loss, l1log_diff_loss + + + elif self.loss_type == 'L2logL1': + #compute the L2 linear magnitude loss + l2lin_sum_loss = self.l2linstft(input_sum_mag, target_sum_mag) + l2lin_diff_loss = self.l2linstft(input_diff_mag, target_diff_mag) + #combine the loss terms + loss = ((self.w_sum * (l2lin_sum_loss + l1log_sum_loss)) + (self.w_diff * (l2lin_diff_loss + l1log_diff_loss))) + loss = apply_reduction(loss, reduction=self.reduction) + + if self.output == "loss": + return loss + elif self.output == "full": + return loss, l2lin_sum_loss, l1log_sum_loss, l2lin_diff_loss, l1log_diff_loss + else: + raise ValueError(f"Invalid loss type: '{self.loss_type}'.") + From a4a021f9bdf5e963590ba63826c61368eae094fe Mon Sep 17 00:00:00 2001 From: Soumya Sai Vanka <79447355+sai-soum@users.noreply.github.com> Date: Thu, 8 Feb 2024 21:43:29 +0000 Subject: [PATCH 3/4] Update perceptual.py --- auraloss/perceptual.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/auraloss/perceptual.py b/auraloss/perceptual.py index 9590d1d..56f865d 100644 --- a/auraloss/perceptual.py +++ b/auraloss/perceptual.py @@ -70,14 +70,14 @@ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False) if ntaps % 2 == 0: raise ValueError(f"ntaps must be odd (ntaps={ntaps}).") - if filter_type == "lp": + if filter_type == "hp": self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1) self.fir.weight.requires_grad = False - self.fir.weight.data = torch.tensor([1, coef, 0]).view(1, 1, -1) - elif filter_type == "hp": + self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1) + elif filter_type == "lp": self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1) self.fir.weight.requires_grad = False - self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1) + 
         elif filter_type == "fd":
             self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
             self.fir.weight.requires_grad = False

From 74930ef7c6ee336cfb6f80e573724d610a171119 Mon Sep 17 00:00:00 2001
From: Soumya Sai Vanka <79447355+sai-soum@users.noreply.github.com>
Date: Thu, 8 Feb 2024 22:10:04 +0000
Subject: [PATCH 4/4] Add window.to(device) in ModifiedSumAndDifferenceSTFTLoss
 forward

---
 auraloss/freq.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/auraloss/freq.py b/auraloss/freq.py
index 4420e87..0c03103 100644
--- a/auraloss/freq.py
+++ b/auraloss/freq.py
@@ -719,6 +719,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
         bs, chs, seq_len = input.size()
         self.awfilter.to(input.device)
         self.lpfilter.to(input.device)
+        self.window = self.window.to(input.device)

         # filter the input and target with the A-weighting and lowpass filters
         input = input.view(bs * chs, 1, -1)
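
Usage sketch (not part of the patch series): a minimal example of the new loss
with all four patches applied. The constructor arguments are the ones defined
in PATCH 2/4; sample_rate must be supplied, since the internal A-weighting
FIRFilter needs a sampling rate to design its taps.

    import torch
    from auraloss.freq import ModifiedSumAndDifferenceSTFTLoss

    # single frame-size sum/difference loss with perceptual pre-filtering
    loss_fn = ModifiedSumAndDifferenceSTFTLoss(
        fft_size=4096,
        hop_size=1024,
        win_length=4096,
        sample_rate=44100,
        loss_type="SClogL1",  # or "L2logL1"
        output="loss",        # "full" also returns the intermediate terms
    )

    pred = torch.rand(4, 2, 44100)    # (batch, stereo channels, samples)
    target = torch.rand(4, 2, 44100)
    loss = loss_fn(pred, target)      # scalar loss under 'mean' reduction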