Skip to content

Commit

Permalink
Add more pre-trained models
Browse files Browse the repository at this point in the history
  • Loading branch information
kkoutini committed Oct 28, 2021
1 parent b567ad4 commit ca5eb8c
Show file tree
Hide file tree
Showing 6 changed files with 471 additions and 40 deletions.
119 changes: 111 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,101 @@ conda env export --no-builds | grep -v "prefix" > environment.yml
pip list > pip_list.txt
```

# Training on Audioset
Download and prepare the dataset as explained in the [audioset page](audioset/)
The base PaSST model can be trained for example like this:
```bash
python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base"
# Getting started
Each dataset has an experiment file such as `ex_audioset.py` and `ex_openmic.py` and a dataset folder with a readme file.
In general, you can prob the experiment file for help:
```shell
python ex_audioset.py help
```

you can override any of the configuration using the [sacred syntax](https://sacred.readthedocs.io/en/stable/command_line.html).
In order to see the available options either use [omniboard](https://github.com/vivekratnavel/omniboard) or use:
```shell
python ex_audioset.py print_config
```
There are many pre-defined configuration options in `config_updates.py`. These include different models, setups etc...
You can list these configurations with:
```shell
python ex_audioset.py print_named_configs
```
The overall configurations looks like this:
```yaml
...
seed = 542198583 # the random seed for this experiment
slurm_job_id = ''
speed_test_batch_size = 100
swa = True
swa_epoch_start = 50
swa_freq = 5
use_mixup = True
warm_up_len = 5
weight_decay = 0.0001
basedataset:
base_dir = 'audioset_hdf5s/' # base directory of the dataset, change it or make a link
eval_hdf5 = 'audioset_hdf5s/mp3/eval_segments_mp3.hdf'
wavmix = 1
....
roll_conf:
axis = 1
shift = None
shift_range = 50
datasets:
test:
batch_size = 20
dataset = {CMD!}'/basedataset.get_test_set'
num_workers = 16
validate = True
training:
batch_size = 12
dataset = {CMD!}'/basedataset.get_full_training_set'
num_workers = 16
sampler = {CMD!}'/basedataset.get_ft_weighted_sampler'
shuffle = None
train = True
models:
mel:
freqm = 48
timem = 192
hopsize = 320
htk = False
n_fft = 1024
n_mels = 128
norm = 1
sr = 32000
...
net:
arch = 'passt_s_swa_p16_128_ap476'
fstride = 10
in_channels = 1
input_fdim = 128
input_tdim = 998
n_classes = 527
s_patchout_f = 4
s_patchout_t = 40
tstride = 10
u_patchout = 0
...
trainer:
accelerator = None
accumulate_grad_batches = 1
amp_backend = 'native'
amp_level = 'O2'
auto_lr_find = False
auto_scale_batch_size = False
...
```
There are many things that can be updated from the command line.
In short:
- All the configuration options under `trainer` are pytorch lightning trainer [api](https://pytorch-lightning.readthedocs.io/en/1.4.1/common/trainer.html#trainer-class-api).
- `models.net` are the passt options.
- `models.mel` are the preprocessing options.
- All the configuration options under `trainer` are pytorch lightning trainer [api](https://pytorch-lightning.readthedocs.io/en/1.4.1/common/trainer.html#trainer-class-api). For example, to turn off cuda benchmarking add `trainer.benchmark=False` to the command line.
- `models.net` are the PaSST (or the chosen NN) options.
- `models.mel` are the preprocessing options (mel spectrograms).

# Training on Audioset
Download and prepare the dataset as explained in the [audioset page](audioset/)
The base PaSST model can be trained for example like this:
```bash
python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base"
```

For example using only unstructured patchout of 400:
```bash
Expand All @@ -68,6 +148,7 @@ Multi-gpu training can be enabled by setting the environment variable `DDP`, for
DDP=2 python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base 2 GPU"
```


# Pre-trained models
Please check the [releases page](releases/), to download pre-trained models.
In general, you can get a pretrained model on Audioset using
Expand All @@ -79,6 +160,28 @@ model = get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=
```
this will get automatically download pretrained PaSST on audioset with with mAP of ```0.476```. the model was trained with ```s_patchout_t=40, s_patchout_f=4``` but you can change these to better fit your task/ computational needs.

There are several pretrained models availble with different strides (overlap) and with/without using SWA: `passt_s_p16_s16_128_ap468, passt_s_swa_p16_s16_128_ap473, passt_s_swa_p16_s14_128_ap471, passt_s_p16_s14_128_ap469, passt_s_swa_p16_s12_128_ap473, passt_s_p16_s12_128_ap470`.
For example, In `passt_s_swa_p16_s16_128_ap473`: `p16` mean patch size is `16x16`, `s16` means no overlap (stride=16), 128 mel bands, `ap473` refers to the performance of this model on Audioset mAP=0.479.

In general, you can get a this pretrained model using:
```python
from models.passt import get_model
passt = get_model(arch="passt_s_swa_p16_s16_128_ap473", fstride=16, tstride=16)
```
Using the framework, you can evaluate this model using:
```shell
python ex_audioset.py evaluate_only with passt_s_swa_p16_s16_128_ap473
```

Two ensemble of these models are provided as well:
A large ensemble giving `mAP=.4956`
```shell
python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
```
An ensemble of models with `stride=10` giving `mAP=.4864`
```shell
python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s10
```

# Contact
The repo will be updated, in the mean time if you have any questions or problems feel free to open an issue on GitHub, or contact the authors directly.
6 changes: 3 additions & 3 deletions audioset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def roll_func(b):


@dataset.command
def get_training_set(normalize, roll, wavmix=False, freqm=48, timem= 192,mel_bins=128):
def get_training_set(normalize, roll, wavmix=False):
ds = get_base_training_set()
get_ir_sample()
if normalize:
Expand All @@ -346,7 +346,7 @@ def get_training_set(normalize, roll, wavmix=False, freqm=48, timem= 192,mel_bin


@dataset.command
def get_full_training_set(normalize, roll, wavmix=False, freqm=48, timem= 192, mel_bins=128):
def get_full_training_set(normalize, roll, wavmix=False):
ds = get_base_full_training_set()
get_ir_sample()
if normalize:
Expand All @@ -362,7 +362,7 @@ def get_full_training_set(normalize, roll, wavmix=False, freqm=48, timem= 192, m


@dataset.command
def get_test_set(normalize, roll, mel_bins=128):
def get_test_set(normalize):
ds = get_base_test_set()
if normalize:
print("normalized test!")
Expand Down
157 changes: 152 additions & 5 deletions config_updates.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,178 @@
from sacred.config_helpers import DynamicIngredient, CMD


def add_configs(ex):
'''
This functions add generic configuration for the experiments, such as mix-up, architectures, etc...
@param ex: Ba3l Experiment
@return:
'''

@ex.named_config
def nomixup():
'Don\'t apply mix-up (spectrogram level).'
use_mixup = False
mixup_alpha = 0.3

@ex.named_config
def mixup():
' Apply mix-up (spectrogram level).'
use_mixup = True
mixup_alpha = 0.3

@ex.named_config
def mini_train():
# just to debug
'limit training/validation to 5 batches for debbuging.'
trainer = dict(limit_train_batches=5, limit_val_batches=5)


@ex.named_config
def passt():
'use PaSST model'
models = {
"net": DynamicIngredient("models.passt.model_ing")
}

@ex.named_config
def passt_s_ap476():
'use PaSST model pretrained on Audioset (with SWA) ap=476'
# python ex_audioset.py evaluate_only with passt_s_ap476
models = {
"net": DynamicIngredient("models.vit.passt.model_ing")
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_128_ap476", fstride=10,
tstride=10)
}

@ex.named_config
def passt_s_ap4763():
'use PaSST model pretrained on Audioset (with SWA) ap=4763'
# test with: python ex_audioset.py evaluate_only with passt_s_ap4763
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_128_ap4763", fstride=10,
tstride=10)
}

@ex.named_config
def passt_s_ap472():
'use PaSST model pretrained on Audioset (no SWA) ap=472'
# test with: python ex_audioset.py evaluate_only with passt_s_ap472
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_128_ap472", fstride=10,
tstride=10)
}

@ex.named_config
def passt_s_p16_s16_128_ap468():
'use PaSST model pretrained on Audioset (no SWA) ap=468 NO overlap'
# test with: python ex_audioset.py evaluate_only with passt_s_p16_s16_128_ap468
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s16_128_ap468", fstride=16,
tstride=16)
}

@ex.named_config
def passt_s_swa_p16_s16_128_ap473():
'use PaSST model pretrained on Audioset (SWA) ap=473 NO overlap'
# test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s16_128_ap473
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s16_128_ap473", fstride=16,
tstride=16)
}

@ex.named_config
def passt_s_swa_p16_s14_128_ap471():
'use PaSST model pretrained on Audioset stride=14 (SWA) ap=471 '
# test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s14_128_ap471
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s14_128_ap471", fstride=14,
tstride=14)
}

@ex.named_config
def passt_s_p16_s14_128_ap469():
'use PaSST model pretrained on Audioset stride=14 (No SWA) ap=469 '
# test with: python ex_audioset.py evaluate_only with passt_s_p16_s14_128_ap469
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s14_128_ap469", fstride=14,
tstride=14)
}

@ex.named_config
def passt_s_swa_p16_s12_128_ap473():
'use PaSST model pretrained on Audioset stride=12 (SWA) ap=473 '
# test with: python ex_audioset.py evaluate_only with passt_s_swa_p16_s12_128_ap473
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_swa_p16_s12_128_ap473", fstride=12,
tstride=12)
}

@ex.named_config
def passt_s_p16_s12_128_ap470():
'use PaSST model pretrained on Audioset stride=12 (No SWA) ap=4670 '
# test with: python ex_audioset.py evaluate_only with passt_s_p16_s12_128_ap470
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_p16_s12_128_ap470", fstride=12,
tstride=12)
}

@ex.named_config
def ensemble_s10():
'use ensemble of PaSST models pretrained on Audioset with S10 mAP=.4864'
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s10
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_s10", fstride=None,
tstride=None, instance_cmd="get_ensemble_model",
# don't call get_model but rather get_ensemble_model
arch_list=[
("passt_s_swa_p16_128_ap476", 10, 10),
("passt_s_swa_p16_128_ap4761", 10, 10),
("passt_s_p16_128_ap472", 10, 10),
]
)
}
@ex.named_config
def ensemble_many():
'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.4956'
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_many
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_many", fstride=None,
tstride=None, instance_cmd="get_ensemble_model",
# don't call get_model but rather get_ensemble_model
arch_list=[
("passt_s_swa_p16_128_ap476", 10, 10),
("passt_s_swa_p16_128_ap4761", 10, 10),
("passt_s_p16_128_ap472", 10, 10),
("passt_s_p16_s12_128_ap470", 12, 12),
("passt_s_swa_p16_s12_128_ap473", 12, 12),
("passt_s_p16_s14_128_ap469", 14, 14),
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
("passt_s_p16_s16_128_ap468", 16, 16),
]
)
}
@ex.named_config
def ensemble_s16_14():
'use ensemble of PaSST models pretrained on Audioset with stride 16 and 14 mAP=.4863'
# test with: python ex_audioset.py evaluate_only with trainer.precision=16 ensemble_s16_14
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="ensemble_s16", fstride=None,
tstride=None, instance_cmd="get_ensemble_model",
# don't call get_model but rather get_ensemble_model
arch_list=[
("passt_s_p16_s14_128_ap469", 14, 14),
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
("passt_s_p16_s16_128_ap468", 16, 16),
]
)
}

@ex.named_config
def dynamic_roll():
# dynamically roll the spectrograms/waveforms
# updates the dataset config
basedataset = dict(roll=True, roll_conf=dict(axis=1, shift_range=10000)
)


# extra commands

@ex.command
Expand All @@ -46,4 +194,3 @@ def test_loaders_train_speed():
print(f"{i}/{len(itr)}", end="\r")
end = time.time()
print("totoal time:", end - start)

Loading

0 comments on commit ca5eb8c

Please sign in to comment.