add FSD50k
kkoutini committed Mar 29, 2022
1 parent 48b59ff commit d7049e7
Showing 4 changed files with 135 additions and 7 deletions.
34 changes: 29 additions & 5 deletions config_updates.py
@@ -32,6 +32,26 @@ def passt():
"net": DynamicIngredient("models.passt.model_ing")
}

@ex.named_config
def passt_s_20sec():
'use PaSST model pretrained on Audioset (with SWA) ap=474; time encodings for up to 20 seconds'
# python ex_audioset.py evaluate_only with passt_s_20sec
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_f128_20sec_p16_s10_ap474", fstride=10,
tstride=10, input_tdim=2000)
}
basedataset = dict(clip_length=20)

@ex.named_config
def passt_s_30sec():
'use PaSST model pretrained on Audioset (with SWA) ap=473; time encodings for up to 30 seconds'
# python ex_audioset.py evaluate_only with passt_s_30sec
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_s_f128_30sec_p16_s10_ap473", fstride=10,
tstride=10, input_tdim=3000)
}
basedataset = dict(clip_length=30)

@ex.named_config
def passt_s_ap476():
'use PaSST model pretrained on Audioset (with SWA) ap=476'
@@ -125,9 +145,10 @@ def ensemble_s10():
("passt_s_swa_p16_128_ap476", 10, 10),
("passt_s_swa_p16_128_ap4761", 10, 10),
("passt_s_p16_128_ap472", 10, 10),
]
)
}

@ex.named_config
def ensemble_many():
'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.4956'
@@ -146,9 +167,10 @@ def ensemble_many():
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
("passt_s_p16_s16_128_ap468", 16, 16),
]
)
}

@ex.named_config
def ensemble_4():
'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.4926'
@@ -162,9 +184,10 @@ def ensemble_4():
("passt_s_swa_p16_s12_128_ap473", 12, 12),
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
]
)
}

@ex.named_config
def ensemble_5():
'use ensemble of PaSST models pretrained on Audioset with different strides mAP=.49459'
@@ -179,9 +202,10 @@ def ensemble_5():
("passt_s_swa_p16_s12_128_ap473", 12, 12),
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
]
)
}

@ex.named_config
def ensemble_s16_14():
'use ensemble of two PaSST models pretrained on Audioset with stride 16 and 14 mAP=.48579'
@@ -193,7 +217,7 @@ def ensemble_s16_14():
arch_list=[
("passt_s_swa_p16_s14_128_ap471", 14, 14),
("passt_s_swa_p16_s16_128_ap473", 16, 16),
]
)
}

1 change: 1 addition & 0 deletions ex_openmic.py
@@ -222,6 +222,7 @@ def validation_epoch_end(self, outputs):
y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])
except ValueError:
average_precision = np.array([np.nan] * y_true.shape[1])
#torch.save(average_precision, f"ap_openmic_perclass_{average_precision.mean()}.pt")
try:
roc = np.array([metrics.roc_auc_score(
y_true[:, i], y_pred[:, i], sample_weight=y_mask[:, i]) for i in range(y_true.shape[1])])
40 changes: 40 additions & 0 deletions fsd50k/README.md
@@ -0,0 +1,40 @@
# Experiments on FSD50K
The FSD50K dataset ([Zenodo](https://zenodo.org/record/4060432)) consists of 51K audio clips annotated
with 200 sound event classes from the Audioset ontology. The dataset contains 100 hours of audio and is the
second-largest publicly available general-purpose sound event recognition dataset after Audioset.
Furthermore, the FSD50K evaluation set is of high quality: each evaluation label was double-checked and
assessed by two to five independent annotators.

# Setup
1. Download the dataset from [Zenodo](https://zenodo.org/record/4060432) and unzip it.
2. Convert wav files to mp3s:
```shell
cd fsd50k/prepare_scripts/

python convert_to_mp3.py path/to/fsd50k
```
This will create a folder with the mp3 files inside the FSD50K directory.
3. Pack the mp3 to HDF5 files:
```shell
cd fsd50k/prepare_scripts/
python create_h5pymp3_dataset.py path/to/fsd50k
```
You should now have three new files inside `../../audioset_hdf5s/mp3/`: `FSD50K.eval_mp3.hdf`, `FSD50K.val_mp3.hdf`, and `FSD50K.train_mp3.hdf`. A rough sketch of what the packing step does is shown below.


# Running Experiments

Similar to the runs on Audioset, PaSST-S can be trained on FSD50K as follows:

```shell
# Example call with all the default config:
python ex_fsd50k.py with trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "FSD50k PaSST-S"
```

```shell
# Example call without overlap:
python ex_fsd50k.py with passt_s_swa_p16_s16_128_ap473 models.net.s_patchout_t=10 models.net.s_patchout_f=1 trainer.precision=16 -p -m mongodb_server:27000:audioset21_balanced -c "FSD50k PaSST-S"
```

67 changes: 65 additions & 2 deletions models/passt.py
@@ -205,6 +205,14 @@ def _cfg(url='', **kwargs):
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-stfthop160-p16-s10-ap.473-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt-s-f128-20sec-p16-s10-ap474-swa': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-20sec-p16-s10-ap.474-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 2000), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt-s-f128-30sec-p16-s10-ap473-swa': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.3-audioset/passt-s-f128-30sec-p16-s10-ap.473-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3000), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'openmic2008_passt_u_f128_p16_s10_ap85_swa': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.4-openmic/openmic2008.passt-u-f128-p16-s10-ap.85-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 3200), crop_pct=1.0,
@@ -476,9 +484,19 @@ def forward_features(self, x):
# Adding Time/Freq information
if first_RUN: print(" self.time_new_pos_embed.shape", self.time_new_pos_embed.shape)
time_new_pos_embed = self.time_new_pos_embed
if x.shape[-1] < time_new_pos_embed.shape[-1]:
if self.training:
toffset = torch.randint(1 + time_new_pos_embed.shape[-1] - x.shape[-1], (1,)).item()
if first_RUN: print(f" CUT with randomoffset={toffset} time_new_pos_embed.shape",
time_new_pos_embed.shape)
time_new_pos_embed = time_new_pos_embed[:, :, :, toffset:toffset + x.shape[-1]]
else:
time_new_pos_embed = time_new_pos_embed[:, :, :, :x.shape[-1]]
if first_RUN: print(" CUT time_new_pos_embed.shape", time_new_pos_embed.shape)
else:
warnings.warn(
f"the patches shape:{x.shape} are larger than the expected time encodings {time_new_pos_embed.shape}, x will be cut")
x = x[:, :, :, :time_new_pos_embed.shape[-1]]
x = x + time_new_pos_embed
if first_RUN: print(" self.freq_new_pos_embed.shape", self.freq_new_pos_embed.shape)
x = x + self.freq_new_pos_embed
@@ -760,6 +778,22 @@ def passt_s_p16_s12_128_ap470(pretrained=False, **kwargs):
return model


def passt_s_f128_20sec_p16_s10_ap474_swa(pretrained=False, **kwargs):
print("\n\n Loading PASST TRAINED ON AUDISET with 20 Second time encodings, with STFT hop of 160 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'passt-s-f128-20sec-p16-s10-ap474-swa', pretrained=pretrained, distilled=True, **model_kwargs)
return model


def passt_s_f128_30sec_p16_s10_ap473_swa(pretrained=False, **kwargs):
print("\n\n Loading PASST TRAINED ON AUDISET with 30 Second time encodings, with STFT hop of 160 \n\n")
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'passt-s-f128-30sec-p16-s10-ap473-swa', pretrained=pretrained, distilled=True, **model_kwargs)
return model


def passt_s_swa_p16_s12_128_ap473(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
@@ -842,6 +876,30 @@ def fix_embedding_layer(model, embed="default"):
model.patch_embed = PatchEmbedAdaptiveMeanKeepConv(replace=model.patch_embed)
return model

@model_ing.command
def lighten_model(model, cut_depth=0):
    if cut_depth == 0:
        return model
    if cut_depth < 0:
        print(f"\n Reducing model depth by keeping every {-cut_depth}th inner layer \n\n")
    else:
        print(f"\n Reducing model depth by {cut_depth} \n\n")
        if len(model.blocks) < cut_depth + 2:
            raise ValueError(f"cut_depth for a ViT with {len(model.blocks)} layers "
                             f"should be between 1 and {len(model.blocks) - 2}")
    print(f"\n Before cutting: {len(model.blocks)} blocks \n\n")

    old_blocks = list(model.blocks.children())
    if cut_depth < 0:
        # keep the first and last block and every (-cut_depth)-th block in between
        old_blocks = [old_blocks[0]] + old_blocks[1:-1:-cut_depth] + [old_blocks[-1]]
    else:
        # drop blocks 1..cut_depth, keep the first block and the remainder
        old_blocks = [old_blocks[0]] + old_blocks[cut_depth + 1:]
    model.blocks = nn.Sequential(*old_blocks)
    print(f"\n After cutting: {len(model.blocks)} blocks \n\n")
    return model


@model_ing.command
def get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527, in_channels=1, fstride=10,
@@ -887,13 +945,18 @@ def get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527, in_channels=1, fstride=10,
model_func = passt_s_swa_p16_s12_128_ap473
elif arch == "passt_s_p16_s12_128_ap470":
model_func = passt_s_p16_s12_128_ap470
elif arch == "passt_s_f128_20sec_p16_s10_ap474":
model_func = passt_s_f128_20sec_p16_s10_ap474_swa
elif arch == "passt_s_f128_30sec_p16_s10_ap473":
model_func = passt_s_f128_30sec_p16_s10_ap473_swa

if model_func is None:
raise RuntimeError(f"Unknown model {arch}")
model = model_func(pretrained=pretrained, num_classes=n_classes, in_chans=in_channels,
img_size=input_size, stride=stride, u_patchout=u_patchout,
s_patchout_t=s_patchout_t, s_patchout_f=s_patchout_f)
model = fix_embedding_layer(model)
model = lighten_model(model)
print(model)
return model

