add pre-trained model
kkoutini committed May 11, 2023
1 parent b078aaa commit d27d832
Showing 2 changed files with 64 additions and 29 deletions.
72 changes: 44 additions & 28 deletions ex_audioset.py
@@ -30,16 +30,16 @@
# DDP=2 python ex_audioset.py with trainer.precision=16 models.net.arch=passt_deit_bd_p16_384 -p -m mongodb_server:27000:audioset21_balanced -c "PaSST base 2 GPU"

# capture the config of the trainer with the prefix "trainer"; this allows using sacred to update the PL trainer config
-get_trainer = ex.command(plTrainer, prefix="trainer") # now you can use in the cmd trainer.precision=16 for example
+# now you can use in the cmd trainer.precision=16 for example
+get_trainer = ex.command(plTrainer, prefix="trainer")
# capture the WandbLogger and prefix it with "wandb"; this allows using sacred to update the WandbLogger config from the command line
get_logger = ex.command(WandbLogger, prefix="wandb")



# define datasets and loaders
get_train_loader = ex.datasets.training.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn), train=True, batch_size=12,
-                                            num_workers=16, shuffle=None, dataset=CMD("/basedataset.get_full_training_set"),
-                                            sampler=CMD("/basedataset.get_ft_weighted_sampler"))
+                                        num_workers=16, shuffle=None, dataset=CMD("/basedataset.get_full_training_set"),
+                                        sampler=CMD("/basedataset.get_ft_weighted_sampler"))

get_validate_loader = ex.datasets.test.iter(DataLoader, static_args=dict(worker_init_fn=worker_init_fn),
validate=True, batch_size=20, num_workers=16,
@@ -48,7 +48,7 @@

@ex.config
def default_conf():
cmd = " ".join(sys.argv) # command line arguments
cmd = " ".join(sys.argv) # command line arguments
saque_cmd = os.environ.get("SAQUE_CMD", "").strip()
saque_id = os.environ.get("SAQUE_ID", "").strip()
slurm_job_id = os.environ.get("SLURM_JOB_ID", "").strip()
@@ -58,7 +58,7 @@ def default_conf():
process_id = os.getpid()
models = {
"net": DynamicIngredient("models.passt.model_ing", arch="passt_deit_bd_p16_384", n_classes=527, s_patchout_t=40,
-                                 s_patchout_f=4), # network config
+                                 s_patchout_f=4),  # network config
"mel": DynamicIngredient("models.preprocess.model_ing",
instance_cmd="AugmentMelSTFT",
n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48,
@@ -69,9 +69,9 @@
basedataset = DynamicIngredient("audioset.dataset.dataset", wavmix=1)
wandb = dict(project="passt_audioset", log_model=True)

-    trainer = dict(max_epochs=130, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0,
+    trainer = dict(max_epochs=130, gpus=1, weights_summary='full', benchmark=True, num_sanity_val_steps=0, precision=16,
reload_dataloaders_every_epoch=True)
-    lr = 0.00002 # learning rate
+    lr = 0.00002  # learning rate
use_mixup = True
mixup_alpha = 0.3

@@ -87,7 +87,8 @@ def get_scheduler_lambda(warm_up_len=5, ramp_down_start=50, ramp_down_len=50, la
return exp_warmup_linear_down(warm_up_len, ramp_down_len, ramp_down_start, last_lr_value)
if schedule_mode == "cos_cyc":
return cosine_cycle(warm_up_len, ramp_down_start, last_lr_value)
raise RuntimeError(f"schedule_mode={schedule_mode} Unknown for a lambda funtion.")
raise RuntimeError(
f"schedule_mode={schedule_mode} Unknown for a lambda funtion.")


@ex.command
@@ -154,18 +155,21 @@ def training_step(self, batch, batch_idx):
if self.use_mixup:
rn_indices, lam = my_mixup(batch_size, self.mixup_alpha)
lam = lam.to(x.device)
-            x = x * lam.reshape(batch_size, 1, 1, 1) + x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))
+            x = x * lam.reshape(batch_size, 1, 1, 1) + \
+                x[rn_indices] * (1. - lam.reshape(batch_size, 1, 1, 1))

y_hat, embed = self.forward(x)

if self.use_mixup:
-            y_mix = y * lam.reshape(batch_size, 1) + y[rn_indices] * (1. - lam.reshape(batch_size, 1))
+            y_mix = y * lam.reshape(batch_size, 1) + \
+                y[rn_indices] * (1. - lam.reshape(batch_size, 1))
samples_loss = F.binary_cross_entropy_with_logits(
y_hat, y_mix, reduction="none")
loss = samples_loss.mean()
samples_loss = samples_loss.detach()
else:
-            samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none")
+            samples_loss = F.binary_cross_entropy_with_logits(
+                y_hat, y, reduction="none")
loss = samples_loss.mean()
samples_loss = samples_loss.detach()

@@ -203,7 +207,8 @@ def validation_step(self, batch, batch_idx):
loss = samples_loss.mean()
out = torch.sigmoid(y_hat.detach())
# self.log("validation.loss", loss, prog_bar=True, on_epoch=True, on_step=False)
-            results = {**results, net_name + "val_loss": loss, net_name + "out": out, net_name + "target": y.detach()}
+            results = {**results, net_name + "val_loss": loss,
+                       net_name + "out": out, net_name + "target": y.detach()}
results = {k: v.cpu() for k, v in results.items()}
return results

@@ -212,24 +217,28 @@ def validation_epoch_end(self, outputs):
if self.do_swa:
model_name = model_name + [("swa_", self.net_swa)]
for net_name, net in model_name:
-            avg_loss = torch.stack([x[net_name + 'val_loss'] for x in outputs]).mean()
+            avg_loss = torch.stack([x[net_name + 'val_loss']
+                                    for x in outputs]).mean()
out = torch.cat([x[net_name + 'out'] for x in outputs], dim=0)
-            target = torch.cat([x[net_name + 'target'] for x in outputs], dim=0)
+            target = torch.cat([x[net_name + 'target']
+                                for x in outputs], dim=0)
try:
average_precision = metrics.average_precision_score(
target.float().numpy(), out.float().numpy(), average=None)
except ValueError:
average_precision = np.array([np.nan] * 527)
try:
-                roc = metrics.roc_auc_score(target.numpy(), out.numpy(), average=None)
+                roc = metrics.roc_auc_score(
+                    target.numpy(), out.numpy(), average=None)
except ValueError:
roc = np.array([np.nan] * 527)
logs = {net_name + 'val.loss': torch.as_tensor(avg_loss).cuda(),
net_name + 'ap': torch.as_tensor(average_precision.mean()).cuda(),
net_name + 'roc': torch.as_tensor(roc.mean()).cuda(),
'step': torch.as_tensor(self.current_epoch).cuda()}
-            torch.save(average_precision, f"ap_perclass_{average_precision.mean()}.pt")
-            print(average_precision)
+            # torch.save(average_precision,
+            #            f"ap_perclass_{average_precision.mean()}.pt")
+            # print(average_precision)
self.log_dict(logs, sync_dist=True)
if self.distributed_mode:
allout = self.all_gather(out)
@@ -243,7 +252,8 @@ def validation_epoch_end(self, outputs):
'step': torch.as_tensor(self.current_epoch).cuda()}
self.log_dict(logs, sync_dist=False)
else:
-                self.log_dict({net_name + "allap": logs[net_name + 'ap'], 'step': logs['step']}, sync_dist=True)
+                self.log_dict(
+                    {net_name + "allap": logs[net_name + 'ap'], 'step': logs['step']}, sync_dist=True)

def configure_optimizers(self):
# REQUIRED
@@ -336,9 +346,10 @@ def model_speed_test(_run, _config, _log, _rnd, _seed, speed_test_batch_size=100
torch.cuda.synchronize()
t1 = time.time()
for i in range(10):
-            with torch.cuda.amp.autocast():
+        with torch.cuda.amp.autocast():
y_hat, embed = net(x)
-            loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
+            loss = F.binary_cross_entropy_with_logits(
+                y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
@@ -350,24 +361,26 @@
print("testing speed")

for i in range(test_length):
-            with torch.cuda.amp.autocast():
+        with torch.cuda.amp.autocast():
y_hat, embed = net(x)
-            loss = F.binary_cross_entropy_with_logits(y_hat, target, reduction="none").mean()
+            loss = F.binary_cross_entropy_with_logits(
+                y_hat, target, reduction="none").mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
torch.cuda.synchronize()
t2 = time.time()
print('test done:', (t2 - t1))
print("average speed: ", (test_length * batch_size) / (t2 - t1), " specs/second")
print("average speed: ", (test_length * batch_size) /
(t2 - t1), " specs/second")


@ex.command
def evaluate_only(_run, _config, _log, _rnd, _seed):
# force overriding the config, not logged = not recommended
trainer = get_trainer(logger=get_logger())
val_loader = get_validate_loader()

modul = M(ex)
modul.val_dataloader = None
trainer.val_dataloaders = None
@@ -412,15 +425,17 @@ def multiprocessing_run(rank, word_size):
print("rank ", rank, os.getpid())
print("word_size ", word_size)
os.environ['NODE_RANK'] = str(rank)
-    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[rank]
+    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['CUDA_VISIBLE_DEVICES'].split(",")[
+        rank]
argv = sys.argv
if rank != 0:
print(f"Unobserved {os.getpid()} with rank {rank}")
argv = argv + ["-u"] # only rank 0 is observed
if "with" not in argv:
argv = argv + ["with"]

argv = argv + [f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"]
argv = argv + \
[f"trainer.num_nodes={word_size}", f"trainer.accelerator=ddp"]
print(argv)

@ex.main
@@ -442,7 +457,8 @@ def default_command():
word_size = int(word_size)
print(f"\n\nDDP TRAINING WITH WORD_SIZE={word_size}\n\n")
os.environ['MASTER_ADDR'] = '127.0.0.1'
-        os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}" # plz no collisions
+        # plz no collisions
+        os.environ['MASTER_PORT'] = f"{9999 + random.randint(0, 9999)}"
os.environ['PL_IN_DDP_SUBPROCESS'] = '1'

for rank in range(word_size):
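With these changes the trainer defaults in ex_audioset.py now include precision=16, and the newly added checkpoint can be selected through the same sacred command-line overrides illustrated at the top of the file. A hypothetical invocation following that pattern (not part of this commit; the arch name is the one registered in models/passt.py below):

# hypothetical example, mirroring the sample command at the top of ex_audioset.py
# python ex_audioset.py with trainer.precision=16 models.net.arch=passt_s_kd_p16_128_ap486 -c "PaSST-S KD pre-trained"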
21 changes: 20 additions & 1 deletion models/passt.py
@@ -165,6 +165,10 @@ def _cfg(url='', **kwargs):
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.1-audioset/passt-s-f128-p16-s10-ap.476-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
+    'passt_s_kd_p16_128_ap486': _cfg(
+        url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.9/passt-s-kd-ap.486.pt',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
+        classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_swa_p16_128_ap4761': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s10-ap.4761-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
@@ -739,6 +743,19 @@ def passt_s_swa_p16_128_ap476(pretrained=False, **kwargs):
return model


+def passt_s_kd_p16_128_ap486(pretrained=False, **kwargs):
+    """ PaSST pre-trained on AudioSet
+    """
+    print("\n\n Loading PaSST pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=486 \n\n")
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    if model_kwargs.get("stride") != (10, 10):
+        warnings.warn(
+            f"This model was pre-trained with strides {(10, 10)}, but now you set (fstride,tstride) to {model_kwargs.get('stride')}.")
+    model = _create_vision_transformer(
+        'passt_s_kd_p16_128_ap486', pretrained=pretrained, distilled=True, **model_kwargs)
+    return model


def passt_s_swa_p16_128_ap4761(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
@@ -902,7 +919,7 @@ def lighten_model(model, cut_depth=0):


@model_ing.command
def get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527, in_channels=1, fstride=10,
def get_model(arch="passt_s_kd_p16_128_ap486", pretrained=True, n_classes=527, in_channels=1, fstride=10,
tstride=10,
input_fdim=128, input_tdim=998, u_patchout=0, s_patchout_t=0, s_patchout_f=0,
):
@@ -927,6 +944,8 @@ def get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527,
stride = (fstride, tstride)
if arch == "passt_deit_bd_p16_384": # base deit
model_func = deit_base_distilled_patch16_384
elif arch == "passt_s_kd_p16_128_ap486": # pretrained
model_func = passt_s_kd_p16_128_ap486
elif arch == "passt_s_swa_p16_128_ap476": # pretrained
model_func = passt_s_swa_p16_128_ap476
elif arch == "passt_s_swa_p16_128_ap4761":
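For reference, here is a rough sketch (not part of this commit) of loading the newly registered checkpoint directly from Python. It assumes models/passt.py is importable from the repository root, that the keyword arguments mirror the get_model defaults shown above (n_classes=527, in_channels=1, fstride=tstride=10, input_fdim=128, input_tdim=998), and that the forward pass returns a (logits, embedding) pair as unpacked in ex_audioset.py:

# rough sketch under the assumptions stated above; not part of this commit
import torch
from models.passt import passt_s_kd_p16_128_ap486

# pretrained=True is assumed to fetch passt-s-kd-ap.486.pt from the release URL registered in the _cfg entry above
model = passt_s_kd_p16_128_ap486(pretrained=True, in_chans=1, num_classes=527,
                                 img_size=(128, 998), stride=(10, 10))
model.eval()

# dummy log-mel input matching the registered input_size=(1, 128, 998): (batch, channels, mel bins, time frames)
spec = torch.zeros(1, 1, 128, 998)
with torch.no_grad():
    logits, embed = model(spec)  # assumed to return clip logits over 527 AudioSet classes plus an embedding
print(logits.shape)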
