From 10b9f7fd3ed48130645ed84b117bdadbde4b4971 Mon Sep 17 00:00:00 2001 From: "Christian J. Steinmetz" Date: Fri, 21 Apr 2023 09:14:55 +0000 Subject: [PATCH] adding updates to README to prepare for new release (0.4.0) --- README.md | 97 ++++++++++++++++++++++++++++++++++++++------------ pyproject.toml | 2 +- setup.py | 4 +-- 3 files changed, 77 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index fa7df5a..2d0b3c0 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,47 @@ target = torch.rand(8,1,44100) loss = mrstft(input, target) ``` +**NEW**: Perceptual weighting with mel scaled spectrograms. + +```python + +bs = 8 +chs = 1 +seq_len = 131072 +sample_rate = 44100 + +# some audio you want to compare +target = torch.rand(bs, chs, seq_len) +pred = torch.rand(bs, chs, seq_len) + +# define the loss function +loss_fn = auraloss.freq.MultiResolutionSTFTLoss( + fft_sizes=[1024, 2048, 8192], + hop_sizes=[256, 512, 2048], + win_lengths=[1024, 2048, 8192], + scale="mel", + n_bins=128, + sample_rate=sample_rate, + perceptual_weighting=True, +) + +# compute +loss = loss_fn(pred, target) + +``` + +## Citation +If you use this code in your work please consider citing us. +```bibtex +@inproceedings{steinmetz2020auraloss, + title={auraloss: {A}udio focused loss functions in {PyTorch}}, + author={Steinmetz, Christian J. and Reiss, Joshua D.}, + booktitle={Digital Music Research Network One-day Workshop (DMRN+15)}, + year={2020} +} +``` + + # Loss functions We categorize the loss functions as either time-domain or frequency-domain approaches. @@ -136,9 +177,11 @@ There are some more advanced things you can do based upon the `STFTLoss` class. For example, you can compute both linear and log scaled STFT errors as in [Engel et al., 2020](https://arxiv.org/abs/2001.04643). In this case we do not include the spectral convergence term. ```python -stft_loss = auraloss.freq.STFTLoss(w_log_mag=1.0, - w_lin_mag=1.0, - w_sc=0.0, ) +stft_loss = auraloss.freq.STFTLoss( + w_log_mag=1.0, + w_lin_mag=1.0, + w_sc=0.0, +) ``` There is also a Mel-scaled STFT loss, which has some special requirements. @@ -151,30 +194,38 @@ melstft_loss = auraloss.freq.MelSTFTLoss(sample_rate, device="cuda") You can also build a multi-resolution Mel-scaled STFT loss with 64 bins easily. Make sure you pass the correct device where the tensors you are comparing will be. ```python -mrmelstft_loss = auraloss.freq.MultiResolutionSTFTLoss(scale="mel", - n_bins=64, - sample_rate=sample_rate, - device="cuda") +loss_fn = auraloss.freq.MultiResolutionSTFTLoss( + scale="mel", + n_bins=64, + sample_rate=sample_rate, + device="cuda" +) ``` -# Development +If you are computing a loss on stereo audio you may want to consider the sum and difference (mid/side) loss. +Below we have shown an example of using this loss function with the perceptual weighting and mel scaling for +further perceptual relevance. -We currently have no tests, but those will also be coming soon, so use caution at the moment. -Future loss functions to be included will target neural network based perceptual losses, -which tend to be a bit more sophisticated than those we have included so far. +```python -If you are interested in adding a loss function please make a pull request. +target = torch.rand(8, 2, 44100) +pred = torch.rand(8, 2, 44100) -## Loss functions to be added -- [Spectral Energy Distance](https://arxiv.org/abs/2008.01160) -- [TFGAN Losses](https://arxiv.org/abs/2011.12206) +loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss( + fft_sizes=[1024, 2048, 8192], + hop_sizes=[256, 512, 2048], + win_lengths=[1024, 2048, 8192], + perceptual_weighting=True, + sample_rate=44100, + scale="mel", + n_bins=128, +) -# Cite -If you use this code in your work please consider citing us. -``` -@inproceedings{steinmetz2020auraloss, - title={auraloss: {A}udio focused loss functions in {PyTorch}}, - author={Steinmetz, Christian J. and Reiss, Joshua D.}, - booktitle={Digital Music Research Network One-day Workshop (DMRN+15)}, - year={2020}} +loss = loss_fn(pred, target) ``` + +# Development + +Run tests locally with pytest. + +```python -m pytest``` diff --git a/pyproject.toml b/pyproject.toml index 09180fd..8a3af74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "auraloss" -version = "0.3.0" +version = "0.4.0" description = "Collection of audio-focused loss functions in PyTorch." authors = [ { name = "Christian Steinmetz" }, diff --git a/setup.py b/setup.py index 3ffb8e8..9c42c4e 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ EMAIL = "c.j.steinmetz@qmul.ac.uk" AUTHOR = "Christian Steinmetz" REQUIRES_PYTHON = ">=3.6.0" -VERSION = "0.3.0" +VERSION = "0.4.0" HERE = Path(__file__).parent @@ -32,7 +32,7 @@ url=URL, packages=["auraloss"], install_requires=["torch", "numpy"], - extras_require={"test": ["resampy"], "all": ["matplotlib", "librosa", "scipy"]}, + extras_require={"test": ["resampy"], "all": ["matplotlib", "librosa", "scipy"]}, include_package_data=True, license="Apache License 2.0", classifiers=[