Skip to content

Commit

Permalink
adding updates to README to prepare for new release (0.4.0)
Browse files Browse the repository at this point in the history
  • Loading branch information
csteinmetz1 committed Apr 21, 2023
1 parent 5daecee commit 10b9f7f
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 26 deletions.
97 changes: 74 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,47 @@ target = torch.rand(8,1,44100)
loss = mrstft(input, target)
```

**NEW**: Perceptual weighting with mel scaled spectrograms.

```python

bs = 8
chs = 1
seq_len = 131072
sample_rate = 44100

# some audio you want to compare
target = torch.rand(bs, chs, seq_len)
pred = torch.rand(bs, chs, seq_len)

# define the loss function
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
fft_sizes=[1024, 2048, 8192],
hop_sizes=[256, 512, 2048],
win_lengths=[1024, 2048, 8192],
scale="mel",
n_bins=128,
sample_rate=sample_rate,
perceptual_weighting=True,
)

# compute
loss = loss_fn(pred, target)

```

## Citation
If you use this code in your work please consider citing us.
```bibtex
@inproceedings{steinmetz2020auraloss,
title={auraloss: {A}udio focused loss functions in {PyTorch}},
author={Steinmetz, Christian J. and Reiss, Joshua D.},
booktitle={Digital Music Research Network One-day Workshop (DMRN+15)},
year={2020}
}
```


# Loss functions

We categorize the loss functions as either time-domain or frequency-domain approaches.
Expand Down Expand Up @@ -136,9 +177,11 @@ There are some more advanced things you can do based upon the `STFTLoss` class.
For example, you can compute both linear and log scaled STFT errors as in [Engel et al., 2020](https://arxiv.org/abs/2001.04643).
In this case we do not include the spectral convergence term.
```python
stft_loss = auraloss.freq.STFTLoss(w_log_mag=1.0,
w_lin_mag=1.0,
w_sc=0.0, )
stft_loss = auraloss.freq.STFTLoss(
w_log_mag=1.0,
w_lin_mag=1.0,
w_sc=0.0,
)
```

There is also a Mel-scaled STFT loss, which has some special requirements.
Expand All @@ -151,30 +194,38 @@ melstft_loss = auraloss.freq.MelSTFTLoss(sample_rate, device="cuda")
You can also build a multi-resolution Mel-scaled STFT loss with 64 bins easily.
Make sure you pass the correct device where the tensors you are comparing will be.
```python
mrmelstft_loss = auraloss.freq.MultiResolutionSTFTLoss(scale="mel",
n_bins=64,
sample_rate=sample_rate,
device="cuda")
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
scale="mel",
n_bins=64,
sample_rate=sample_rate,
device="cuda"
)
```

# Development
If you are computing a loss on stereo audio you may want to consider the sum and difference (mid/side) loss.
Below we have shown an example of using this loss function with the perceptual weighting and mel scaling for
further perceptual relevance.

We currently have no tests, but those will also be coming soon, so use caution at the moment.
Future loss functions to be included will target neural network based perceptual losses,
which tend to be a bit more sophisticated than those we have included so far.
```python

If you are interested in adding a loss function please make a pull request.
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)

## Loss functions to be added
- [Spectral Energy Distance](https://arxiv.org/abs/2008.01160)
- [TFGAN Losses](https://arxiv.org/abs/2011.12206)
loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss(
fft_sizes=[1024, 2048, 8192],
hop_sizes=[256, 512, 2048],
win_lengths=[1024, 2048, 8192],
perceptual_weighting=True,
sample_rate=44100,
scale="mel",
n_bins=128,
)

# Cite
If you use this code in your work please consider citing us.
```
@inproceedings{steinmetz2020auraloss,
title={auraloss: {A}udio focused loss functions in {PyTorch}},
author={Steinmetz, Christian J. and Reiss, Joshua D.},
booktitle={Digital Music Research Network One-day Workshop (DMRN+15)},
year={2020}}
loss = loss_fn(pred, target)
```

# Development

Run tests locally with pytest.

```python -m pytest```
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "auraloss"
version = "0.3.0"
version = "0.4.0"
description = "Collection of audio-focused loss functions in PyTorch."
authors = [
{ name = "Christian Steinmetz" },
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
EMAIL = "[email protected]"
AUTHOR = "Christian Steinmetz"
REQUIRES_PYTHON = ">=3.6.0"
VERSION = "0.3.0"
VERSION = "0.4.0"

HERE = Path(__file__).parent

Expand All @@ -32,7 +32,7 @@
url=URL,
packages=["auraloss"],
install_requires=["torch", "numpy"],
extras_require={"test": ["resampy"], "all": ["matplotlib", "librosa", "scipy"]},
extras_require={"test": ["resampy"], "all": ["matplotlib", "librosa", "scipy"]},
include_package_data=True,
license="Apache License 2.0",
classifiers=[
Expand Down

0 comments on commit 10b9f7f

Please sign in to comment.