inference.py
import argparse
import torch
import librosa
import numpy as np
from torch import autocast
from contextlib import nullcontext
from models.mn.model import get_model as get_mobilenet
from models.dymn.model import get_model as get_dymn
from models.ensemble import get_ensemble_model
from models.preprocess import AugmentMelSTFT
from helpers.utils import NAME_TO_WIDTH, labels


def audio_tagging(args):
    """
    Run audio tagging inference on a single audio clip.
    """
    model_name = args.model_name
    device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu')
    audio_path = args.audio_path
    sample_rate = args.sample_rate
    window_size = args.window_size
    hop_size = args.hop_size
    n_mels = args.n_mels

    # load pre-trained model (either a single model or an ensemble of models)
    if len(args.ensemble) > 0:
        model = get_ensemble_model(args.ensemble)
    else:
        if model_name.startswith("dymn"):
            model = get_dymn(width_mult=NAME_TO_WIDTH(model_name), pretrained_name=model_name,
                             strides=args.strides)
        else:
            model = get_mobilenet(width_mult=NAME_TO_WIDTH(model_name), pretrained_name=model_name,
                                  strides=args.strides, head_type=args.head_type)
    model.to(device)
    model.eval()

    # model to preprocess the waveform into mel spectrograms
    mel = AugmentMelSTFT(n_mels=n_mels, sr=sample_rate, win_length=window_size, hopsize=hop_size)
    mel.to(device)
    mel.eval()

    # load the audio clip as a mono waveform at the target sample rate
    (waveform, _) = librosa.core.load(audio_path, sr=sample_rate, mono=True)
    waveform = torch.from_numpy(waveform[None, :]).to(device)

    # our models are trained in half precision mode (torch.float16)
    # run on cuda with torch.float16 to get the best performance
    # running on cpu with torch.float32 gives similar performance, using torch.bfloat16 is worse
    with torch.no_grad(), autocast(device_type=device.type) if args.cuda else nullcontext():
        spec = mel(waveform)
        preds, features = model(spec.unsqueeze(0))
    # sigmoid converts the logits into independent per-class probabilities (multi-label tagging)
    preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy()

    sorted_indexes = np.argsort(preds)[::-1]

    # print the top-10 audio tagging probabilities
    print("************* Acoustic Event Detected: *****************")
    for k in range(10):
        print('{}: {:.3f}'.format(labels[sorted_indexes[k]],
                                  preds[sorted_indexes[k]]))
    print("********************************************************")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run audio tagging on a single audio clip.')

    # model_name decides which pre-trained model is loaded
    parser.add_argument('--model_name', type=str, default='mn10_as')
    parser.add_argument('--strides', nargs=4, default=[2, 2, 2, 2], type=int)
    parser.add_argument('--head_type', type=str, default="mlp")
    parser.add_argument('--cuda', action='store_true', default=False)
    parser.add_argument('--audio_path', type=str, required=True)

    # preprocessing
    parser.add_argument('--sample_rate', type=int, default=32000)
    parser.add_argument('--window_size', type=int, default=800)
    parser.add_argument('--hop_size', type=int, default=320)
    parser.add_argument('--n_mels', type=int, default=128)

    # pass a list of model names via '--ensemble' to evaluate an ensemble instead of a single model
    parser.add_argument('--ensemble', nargs='+', default=[])

    args = parser.parse_args()
    audio_tagging(args)
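
# Example invocations (a sketch, not part of the original script): the audio file path below is
# a placeholder and must point to an existing file on your machine; '--cuda' assumes a
# CUDA-capable GPU, and the model names are assumed to be available pre-trained checkpoints.
#
#   python inference.py --cuda --model_name=mn10_as --audio_path="path/to/clip.wav"
#
#   # evaluate an ensemble of (assumed) pre-trained models instead of a single model:
#   python inference.py --cuda --ensemble mn40_as_ext mn40_as --audio_path="path/to/clip.wav"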