complexity.py
import argparse
import torch
from helpers.flop_count import count_macs, count_macs_transformer
from helpers.peak_memory import peak_memory_mnv3, peak_memory_cnn
from models.mn.model import get_model
from helpers.utils import NAME_TO_WIDTH
from models.preprocess import AugmentMelSTFT


def calc_complexity(args):
    # mel
    mel = AugmentMelSTFT(n_mels=args.n_mels,
                         sr=args.resample_rate,
                         win_length=args.window_size,
                         hopsize=args.hop_size,
                         n_fft=args.n_fft)

    # model
    if args.model_width:
        # manually specified settings
        width = args.model_width
        model_name = "mn{}".format(str(width).replace(".", ""))
    else:
        # model width specified via model name
        model_name = args.model_name
        width = NAME_TO_WIDTH(model_name)
    model = get_model(width_mult=width, se_dims=args.se_dims, head_type=args.head_type)
    model.eval()

    # waveform
    waveform = torch.zeros((1, args.resample_rate * 10))  # 10-second waveform
    spectrogram = mel(waveform)
    # add channel dimension
    spectrogram = spectrogram.unsqueeze(1)

    if args.complexity_type == "computation":
        # use size of spectrogram to calculate multiply-accumulate operations
        total_macs = count_macs(model, spectrogram.size())
        total_params = sum(p.numel() for p in model.parameters())
        print("Model '{}' has {:.2f} million parameters and inference of a single 10-second audio clip requires "
              "{:.2f} billion multiply-accumulate operations."
              .format(model_name, total_params/10**6, total_macs/10**9))
    elif args.complexity_type == "memory":
        if args.memory_efficient_inference:
            peak_mem = peak_memory_mnv3(model, spectrogram.size(), args.bits_per_elem)
            print("Model '{}' inference (memory efficient) of a single 10-second audio clip "
                  "has a peak memory of {:.2f} kB."
                  .format(model_name, peak_mem))
        else:
            peak_mem = peak_memory_cnn(model, spectrogram.size(), args.bits_per_elem)
            print("Model '{}' inference of a single 10-second audio clip has a peak memory of {:.2f} kB."
                  .format(model_name, peak_mem))
    else:
        raise NotImplementedError(f"Unknown complexity type: {args.complexity_type}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Calculate model complexity (computation or peak memory).')

    # either computation or memory complexity
    parser.add_argument('--complexity_type', type=str, default='computation')
    # for memory complexity
    parser.add_argument('--memory_efficient_inference', action='store_true', default=False)
    # model name determines which pre-trained model is evaluated in terms of complexity
    parser.add_argument('--model_name', type=str, default='mn10_as')
    # alternatively, specify the model configuration manually
    parser.add_argument('--model_width', type=float, default=None)
    parser.add_argument('--se_dims', type=str, default='c')
    parser.add_argument('--head_type', type=str, default='mlp')

    # preprocessing
    parser.add_argument('--resample_rate', type=int, default=32000)
    parser.add_argument('--window_size', type=int, default=800)
    parser.add_argument('--hop_size', type=int, default=320)
    parser.add_argument('--n_fft', type=int, default=1024)
    parser.add_argument('--n_mels', type=int, default=128)

    # memory: bits per element used in the peak memory estimate
    parser.add_argument('--bits_per_elem', type=int, default=16)

    args = parser.parse_args()
    calc_complexity(args)
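
A minimal usage sketch (not part of the original file): since the script is driven entirely by the argparse flags above, calc_complexity can also be called programmatically by filling an argparse.Namespace with the same defaults. The import path assumes this file is importable as the module 'complexity'; 'mn10_as' is simply the script's default model name.

# Equivalent to: python complexity.py --model_name=mn10_as --complexity_type=computation
from argparse import Namespace
from complexity import calc_complexity  # assumption: file is importable as 'complexity'

args = Namespace(
    complexity_type='computation',       # or 'memory'
    memory_efficient_inference=False,
    model_name='mn10_as',
    model_width=None,
    se_dims='c',
    head_type='mlp',
    resample_rate=32000,
    window_size=800,
    hop_size=320,
    n_fft=1024,
    n_mels=128,
    bits_per_elem=16,
)
calc_complexity(args)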