
Commit 2d44823

BigVGAN-v2 release

1 parent c79aa20 · commit 2d44823

38 files changed: +2196 −89 lines changed

.gitignore

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+*.pyc
+__pycache__/
+*/__pycache__/
+alias_free_cuda/build/
+exp/
+tmp/

LICENSE

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2022 NVIDIA CORPORATION.
+Copyright (c) 2024 NVIDIA CORPORATION.
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

README.md

Lines changed: 101 additions & 27 deletions

@@ -4,18 +4,32 @@
 <center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
 
 
-### [Paper](https://arxiv.org/abs/2206.04658)
-### [Audio demo](https://bigvgan-demo.github.io/)
+### [Paper](https://arxiv.org/abs/2206.04658) &emsp; [Project page](https://research.nvidia.com/labs/adlr/projects/bigvgan/) &emsp; [Audio demo](https://bigvgan-demo.github.io/)
+
+## News
+[Jul 2024] We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
+* Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
+* Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
+* Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
+* We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
 
 ## Installation
-Clone the repository and install dependencies.
+The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
+```shell
+conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
+conda activate bigvgan
+```
+
+Clone the repository and install dependencies:
 ```shell
-# the codebase has been tested on Python 3.8 / 3.10 with PyTorch 1.12.1 / 1.13 conda binaries
 git clone https://github.com/NVIDIA/BigVGAN
+cd BigVGAN
 pip install -r requirements.txt
 ```
 
-Create symbolic link to the root of the dataset. The codebase uses filelist with the relative path from the dataset. Below are the example commands for LibriTTS dataset.
+
+
+Create a symbolic link to the root of the dataset. The codebase uses filelists with paths relative to the dataset root. Below are example commands for the LibriTTS dataset:
 ```shell
 cd LibriTTS && \
 ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
@@ -29,24 +43,25 @@ cd ..
 ```
 
 ## Training
-Train BigVGAN model. Below is an example command for training BigVGAN using LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input.
+Train the BigVGAN model. Below is an example command for training BigVGAN-v2 using the LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input:
 ```shell
 python train.py \
---config configs/bigvgan_24khz_100band.json \
+--config configs/bigvgan_v2_24khz_100band_256x.json \
 --input_wavs_dir LibriTTS \
 --input_training_file LibriTTS/train-full.txt \
 --input_validation_file LibriTTS/val-full.txt \
 --list_input_unseen_wavs_dir LibriTTS LibriTTS \
 --list_input_unseen_validation_file LibriTTS/dev-clean.txt LibriTTS/dev-other.txt \
---checkpoint_path exp/bigvgan
+--checkpoint_path exp/bigvgan_v2_24khz_100band_256x
 ```
 
+
 ## Synthesis
 Synthesize from the BigVGAN model. Below is an example command for generating audio from the model.
 It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
 ```shell
 python inference.py \
---checkpoint_file exp/bigvgan/g_05000000 \
+--checkpoint_file exp/bigvgan_v2_24khz_100band_256x/g_03000000 \
 --input_wavs_dir /path/to/your/input_wav \
 --output_dir /path/to/your/output_wav
 ```

@@ -57,39 +72,98 @@ It loads mel spectrograms from `--input_mels_dir` and saves the generated audio
 Make sure that the STFT hyperparameters for the mel spectrogram are the same as the model's, as defined in `config.json` of the corresponding model.
 ```shell
 python inference_e2e.py \
---checkpoint_file exp/bigvgan/g_05000000 \
+--checkpoint_file exp/bigvgan_v2_24khz_100band_256x/g_03000000 \
 --input_mels_dir /path/to/your/input_mel \
 --output_dir /path/to/your/output_wav
 ```
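For reference, below is a minimal sketch of computing a model-matched mel spectrogram; it assumes the `mel_spectrogram` helper in this codebase's `meldataset.py` and the `AttrDict` wrapper in `env.py` (verify the exact signature in your checkout before relying on it):

```python
# Sketch: compute a mel spectrogram whose STFT hyperparameters match the model,
# by reading them from the model's config.json. mel_spectrogram and AttrDict
# are assumed from this repo's meldataset.py / env.py.
import json

import librosa
import torch

from env import AttrDict
from meldataset import mel_spectrogram

with open("exp/bigvgan_v2_24khz_100band_256x/config.json") as f:
    h = AttrDict(json.load(f))

wav, _ = librosa.load("/path/to/your/input_wav/sample.wav", sr=h.sampling_rate, mono=True)
wav = torch.FloatTensor(wav).unsqueeze(0)  # [1, T], float waveform in [-1, 1]

# use the exact STFT hyperparameters from the model's config.json
mel = mel_spectrogram(wav, h.n_fft, h.num_mels, h.sampling_rate,
                      h.hop_size, h.win_size, h.fmin, h.fmax)  # [1, num_mels, frames]
```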
 
-## Pretrained Models
-We provide the [pretrained models](https://drive.google.com/drive/folders/1e9wdM29d-t3EHUpBb8T4dcHrkYGAXTgq).
-One can download the checkpoints of generator (e.g., g_05000000) and discriminator (e.g., do_05000000) within the listed folders.
+## Using Custom CUDA Kernel for Synthesis
+You can apply the fast CUDA inference kernel by passing `use_cuda_kernel=True` when instantiating BigVGAN:
+
+```python
+generator = BigVGAN(h, use_cuda_kernel=True)
+```
+
+You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
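As an end-to-end illustration, here is a hedged sketch of loading a checkpoint and synthesizing with the fused kernel; the `BigVGAN` class in `models.py` and the `"generator"` checkpoint key follow the HiFi-GAN-style conventions of this codebase and should be verified against `inference.py`:

```python
# Sketch: BigVGAN inference with the fused CUDA kernel enabled. Assumes
# models.BigVGAN, env.AttrDict, and a checkpoint dict keyed by "generator"
# (HiFi-GAN convention); verify against inference.py in your checkout.
import json

import torch

from env import AttrDict
from models import BigVGAN

with open("exp/bigvgan_v2_24khz_100band_256x/config.json") as f:
    h = AttrDict(json.load(f))

generator = BigVGAN(h, use_cuda_kernel=True).to("cuda")
ckpt = torch.load("exp/bigvgan_v2_24khz_100band_256x/g_03000000", map_location="cuda")
generator.load_state_dict(ckpt["generator"])
generator.eval()
generator.remove_weight_norm()

with torch.no_grad():
    mel = torch.randn(1, h.num_mels, 128, device="cuda")  # placeholder mel input
    audio = generator(mel)  # [1, 1, T] waveform in [-1, 1]
```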
 
-|Folder Name|Sampling Rate|Mel band|fmax|Params.|Dataset|Fine-Tuned|
-|------|---|---|---|---|------|---|
-|bigvgan_24khz_100band|24 kHz|100|12000|112M|LibriTTS|No|
-|bigvgan_base_24khz_100band|24 kHz|100|12000|14M|LibriTTS|No|
-|bigvgan_22khz_80band|22 kHz|80|8000|112M|LibriTTS + VCTK + LJSpeech|No|
-|bigvgan_base_22khz_80band|22 kHz|80|8000|14M|LibriTTS + VCTK + LJSpeech|No|
+When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_cuda/build` and the model automatically loads it. The codebase has been tested using CUDA `12.1`.
 
-The paper results are based on 24kHz BigVGAN models trained on LibriTTS dataset.
+Please make sure that both `nvcc` and `ninja` are installed on your system and that the installed `nvcc` matches the CUDA version your PyTorch build is using.
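As a quick sanity check before building (a sketch, not part of the repo), you can compare the CUDA version your PyTorch build uses against the `nvcc` on your `PATH`:

```python
# Sketch: report the CUDA toolkit versions seen by PyTorch and by nvcc;
# the two should match for the kernel build to succeed.
import subprocess

import torch

print("PyTorch built with CUDA:", torch.version.cuda)
nvcc = subprocess.run(["nvcc", "--version"], capture_output=True, text=True)
print(nvcc.stdout)  # look for the "release X.Y" line
```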
+
+We recommend running `test_cuda_vs_torch_model.py` first to build the CUDA kernel and check its correctness. See the example command below and its output, which ends with `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
+
+```shell
+python test_cuda_vs_torch_model.py \
+--checkpoint_file /path/to/your/bigvgan/g_03000000
+```
+
+```shell
+loading plain Pytorch BigVGAN
+...
+loading CUDA kernel BigVGAN with auto-build
+Detected CUDA files, patching ldflags
+Emitting ninja build file /path/to/your/BigVGAN/alias_free_cuda/build/build.ninja...
+Building extension module anti_alias_activation_cuda...
+...
+Loading extension module anti_alias_activation_cuda...
+...
+Loading '/path/to/your/bigvgan/g_03000000'
+...
+[Success] test CUDA fused vs. plain torch BigVGAN inference
+ > mean_difference=0.0007238413265440613
+...
+```
+
+If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, the CUDA kernel inference is incorrect. Please check whether the `nvcc` installed on your system is compatible with your PyTorch version.
+
+
+## Pretrained Models
+We provide the [pretrained models](https://drive.google.com/drive/folders/1L2RDeJMBE7QAI8qV51n0QAf4mkSgUUeE?usp=sharing).
+One can download the generator weights (e.g., `g_(training_steps)`) and the discriminator/optimizer states (e.g., `do_(training_steps)`) within the listed folders.
+
+|Folder Name|Sampling Rate|Mel band|fmax|Upsampling Ratio|Params.|Dataset|Fine-Tuned|
+|------|---|---|---|---|---|------|---|
+|bigvgan_v2_44khz_128band_512x|44 kHz|128|22050|512|122M|Large-scale Compilation|No|
+|bigvgan_v2_44khz_128band_256x|44 kHz|128|22050|256|112M|Large-scale Compilation|No|
+|bigvgan_v2_24khz_100band_256x|24 kHz|100|12000|256|112M|Large-scale Compilation|No|
+|bigvgan_v2_22khz_80band_256x|22 kHz|80|11025|256|112M|Large-scale Compilation|No|
+|bigvgan_v2_22khz_80band_fmax8k_256x|22 kHz|80|8000|256|112M|Large-scale Compilation|No|
+|bigvgan_24khz_100band|24 kHz|100|12000|256|112M|LibriTTS|No|
+|bigvgan_base_24khz_100band|24 kHz|100|12000|256|14M|LibriTTS|No|
+|bigvgan_22khz_80band|22 kHz|80|8000|256|112M|LibriTTS + VCTK + LJSpeech|No|
+|bigvgan_base_22khz_80band|22 kHz|80|8000|256|14M|LibriTTS + VCTK + LJSpeech|No|
+
+The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on the LibriTTS dataset.
 We also provide 22kHz BigVGAN models with a band-limited setup (i.e., fmax=8000) for TTS applications.
-Note that, the latest checkpoints use ``snakebeta`` activation with log scale parameterization, which have the best overall quality.
+Note that the checkpoints use ``snakebeta`` activation with log-scale parameterization, which provides the best overall quality.
 
+You can fine-tune the models by downloading the checkpoints (both the generator weights and the discriminator/optimizer states) and resuming training using your audio dataset.
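For orientation, a hedged sketch of staging a fine-tuning run; the assumption that `train.py` resumes from the latest `g_*`/`do_*` files found in `--checkpoint_path` follows the HiFi-GAN-style convention and should be verified against `train.py` in your checkout:

```python
# Sketch: stage the downloaded checkpoints so training resumes from them.
# The g_*/do_* scan-and-resume behavior is an assumption based on the
# HiFi-GAN-style convention; verify against train.py before relying on it.
import shutil
from pathlib import Path

ckpt_dir = Path("exp/my_finetune")
ckpt_dir.mkdir(parents=True, exist_ok=True)

# copy both the generator weights and the discriminator/optimizer states
shutil.copy("downloads/bigvgan_v2_24khz_100band_256x/g_03000000", ckpt_dir)
shutil.copy("downloads/bigvgan_v2_24khz_100band_256x/do_03000000", ckpt_dir)
# then launch train.py with --checkpoint_path exp/my_finetune and your own filelists
```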
 
-## TODO
+## Training Details of BigVGAN-v2
+Compared to the original BigVGAN, the pretrained BigVGAN-v2 checkpoints were trained with `batch_size=32` and a longer `segment_size=65536` on 8 A100 GPUs.
 
-Current codebase only provides a plain PyTorch implementation for the filtered nonlinearity. We are working on a fast CUDA kernel implementation, which will be released in the future.
+Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` by default so that training fits on a single A100 GPU. You can fine-tune the models after adjusting `batch_size` to match your GPUs.
 
+When training BigVGAN-v2 from scratch with a small batch size, it can encounter the early divergence problem mentioned in the paper. In such cases, we recommend lowering the `clip_grad_norm` value (e.g., to `100`) for the early training iterations (e.g., the first 20000 steps) and then increasing it back to the default `500`.
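To make the suggested schedule concrete, a small illustration (not code from this repo) of switching the clipping threshold during training:

```python
# Illustration of the suggested clip_grad_norm schedule: a tighter threshold
# for early iterations, then the default. Values follow the text above.
import torch

def grad_clip_value(step: int, warmup_steps: int = 20000,
                    early_value: float = 100.0, default_value: float = 500.0) -> float:
    """Return the clip_grad_norm threshold to use at a given training step."""
    return early_value if step < warmup_steps else default_value

# inside the training loop, after loss.backward():
# torch.nn.utils.clip_grad_norm_(generator.parameters(), grad_clip_value(step))
```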
+
+## Evaluation Results of BigVGAN-v2
+Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements across the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
+
+|Model|Dataset|Steps|PESQ(↑)|M-STFT(↓)|MCD(↓)|Periodicity(↓)|V/UV F1(↑)|
+|-------|-----|-----|-----|-----|-----|-----|-----|
+|BigVGAN|LibriTTS|1M|4.027|0.7997|0.3745|0.1018|0.9598|
+|BigVGAN|LibriTTS|5M|4.256|0.7409|0.2988|0.0809|0.9698|
+|BigVGAN-v2|Large-scale Compilation|3M|**4.359**|**0.7134**|0.3060|**0.0621**|**0.9777**|
+
+## Acknowledgements
+We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
 
 ## References
 * [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
-
 * [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
-
 * [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
-
 * [Julius](https://github.com/adefossez/julius) (for low-pass filter)
+* [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
+* [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
+* [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)
 
-* [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)

alias_free_cuda/__init__.py

Whitespace-only changes.

alias_free_cuda/activation1d.py

Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+# Copyright (c) 2024 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+from alias_free_torch.resample import UpSample1d, DownSample1d
+# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
+from alias_free_cuda import load
+load.load()
+
+class FusedAntiAliasActivation(torch.autograd.Function):
+    """
+    Assumes filter size 12, replication padding on upsampling, and logscale alpha/beta parameters as inputs
+    """
+    @staticmethod
+    def forward(ctx, inputs, ftr, alpha, beta):
+        import anti_alias_activation_cuda
+        activation_results = anti_alias_activation_cuda.forward(inputs, ftr, alpha, beta)
+        return activation_results
+
+    @staticmethod
+    def backward(ctx, output_grads):
+        # TODO: implement bwd pass; the fused kernel currently supports inference only
+        raise NotImplementedError
+
+class Activation1d(nn.Module):
+    def __init__(self,
+                 activation,
+                 up_ratio: int = 2,
+                 down_ratio: int = 2,
+                 up_kernel_size: int = 12,
+                 down_kernel_size: int = 12,
+                 fused: bool = True
+                 ):
+        super().__init__()
+        self.up_ratio = up_ratio
+        self.down_ratio = down_ratio
+        self.act = activation
+        self.upsample = UpSample1d(up_ratio, up_kernel_size)
+        self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+        self.fused = fused  # whether to use fused CUDA kernel or not
+
+    def forward(self, x):
+        if not self.fused:
+            x = self.upsample(x)
+            x = self.act(x)
+            x = self.downsample(x)
+            return x
+        else:
+            if self.act.__class__.__name__ == "Snake":
+                beta = self.act.alpha.data  # snake uses same params for alpha and beta
+            else:
+                beta = self.act.beta.data  # snakebeta uses different params for alpha and beta
+            alpha = self.act.alpha.data
+            if not self.act.alpha_logscale:  # exp baked into cuda kernel, cancel it out with a log
+                alpha = torch.log(alpha)
+                beta = torch.log(beta)
+            x = FusedAntiAliasActivation.apply(x, self.upsample.filter, alpha, beta)
+            x = self.downsample(x)
+            return x
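For orientation, a hypothetical usage sketch of the module above; the `SnakeBeta` signature is assumed from this codebase's `activations.py` and should be verified before use:

```python
# Hypothetical usage of Activation1d with the fused CUDA kernel. Assumes
# activations.SnakeBeta(in_features, alpha_logscale=...) as in this codebase.
import torch

from activations import SnakeBeta
from alias_free_cuda.activation1d import Activation1d

channels = 512
act = Activation1d(activation=SnakeBeta(channels, alpha_logscale=True), fused=True).cuda()

x = torch.randn(2, channels, 8192, device="cuda")  # [batch, channels, time]
with torch.no_grad():  # backward is not implemented for the fused path
    y = act(x)  # upsample -> fused anti-aliased activation -> downsample
print(y.shape)  # time length preserved: up_ratio=2 followed by down_ratio=2
```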
alias_free_cuda/anti_alias_activation.cpp

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+/* coding=utf-8
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <torch/extension.h>
+#include <vector>
+
+namespace anti_alias_activation {
+
+// forward declaration of the CUDA implementation (defined in the .cu file)
+torch::Tensor fwd_cuda(torch::Tensor const& input,
+                       torch::Tensor const& filter,
+                       torch::Tensor const& alpha,
+                       torch::Tensor const& beta
+                       );
+
+torch::Tensor fwd(torch::Tensor const& input,
+                  torch::Tensor const& filter,
+                  torch::Tensor const& alpha,
+                  torch::Tensor const& beta
+                  ) {
+  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
+  //AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+  //           (input.scalar_type() == at::ScalarType::BFloat16),
+  //           "Only fp16 and bf16 are supported");
+
+  return fwd_cuda(input, filter, alpha, beta);
+}
+
+} // end namespace anti_alias_activation
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward",
+        &anti_alias_activation::fwd,
+        "Anti Alias Activation -- Forward.");
+}
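Once built, the extension can also be exercised directly; below is a sketch under the assumptions that `load.load()` has built `anti_alias_activation_cuda` and that the binding above expects a 3D input, the upsampling filter, and log-scale alpha/beta (shapes and dtypes should be verified against `test_cuda_vs_torch_model.py`):

```python
# Sketch: calling the compiled extension directly, mirroring the binding above:
# forward(input, filter, alpha, beta) with a 3D input tensor. Shapes and dtypes
# are assumptions; verify against test_cuda_vs_torch_model.py.
import torch

from alias_free_cuda import load
load.load()  # builds the extension with nvcc/ninja on first call
import anti_alias_activation_cuda

from alias_free_torch.resample import UpSample1d

channels = 64
x = torch.randn(1, channels, 4096, device="cuda")   # [B, C, T]
up = UpSample1d(ratio=2, kernel_size=12).cuda()     # provides the low-pass filter buffer
log_alpha = torch.zeros(channels, device="cuda")    # log-scale alpha, one per channel
log_beta = torch.zeros(channels, device="cuda")     # log-scale beta, one per channel

y = anti_alias_activation_cuda.forward(x, up.filter, log_alpha, log_beta)
```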
