diff --git a/auraloss/freq.py b/auraloss/freq.py index a5efe70..4393468 100644 --- a/auraloss/freq.py +++ b/auraloss/freq.py @@ -200,9 +200,7 @@ def stft(self, x): self.window, return_complex=True, ) - x_mag = torch.sqrt( - torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps) - ) + x_mag = torch.clamp(torch.abs(x_stft), min=self.eps) x_phs = torch.angle(x_stft) return x_mag, x_phs