-
Notifications
You must be signed in to change notification settings - Fork 2
/
convert_decoder256.py
75 lines (64 loc) · 2.66 KB
/
convert_decoder256.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import whisper
import torch
import coremltools as ct
import os
import sys
import numpy as np
from timeit import default_timer as timer
print("----------------")
print("🐳 Decoder256 🐳")
print("----------------")

# Model selection: first CLI argument (e.g. "small", "tiny.en"); default "small".
modelName = sys.argv[1] if len(sys.argv) > 1 else "small"
model = whisper.load_model(modelName).cpu()

# Strip a ".en"-style suffix to obtain the base size key.
# NOTE(review): hyphenated names like "large-v2" would not reduce to "large"
# here and would KeyError below — confirm which model names are supported.
modelSize = modelName.split(".")[0]

# (embedding width n_state, decoder layer count n_layer) per model size.
_SIZE_DIMS = {
    'tiny':   (384, 4),
    'base':   (512, 6),
    'small':  (768, 12),
    'medium': (1024, 24),
    'large':  (1280, 32),
}
n_state, n_layer = _SIZE_DIMS[modelSize]
n_head = n_state // 64  # attention heads are 64-dim each

decoder = model.decoder
decoder.eval()

inType = np.float16
outType = np.float16
bs = 1  # beam_size

# max token len for first time = max_prefix_len(224) + sot_len(3)
max_n_ctx = decoder.max_n_ctx_for_1st

# Dummy tensors for tracing — only the shapes matter, not the values.
x = torch.ones((bs, max_n_ctx, n_state))
qk_mask = torch.zeros((max_n_ctx, max_n_ctx))
cross_k_caches = torch.ones((n_layer, n_head, 64, 1500))
cross_v_caches = torch.ones((n_layer, n_head, 1500, 64))
import warnings

# Keyword inputs fed to forwardBlocks during tracing; their shapes are
# baked into the traced graph.
trace_kwargs = {
    "x": x,
    "qk_mask": qk_mask,
    "cross_k_caches": cross_k_caches,
    "cross_v_caches": cross_v_caches,
}
with warnings.catch_warnings():
    # Tracer warnings are expected here; keep the log quiet.
    warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)
    traced_decoder = torch.jit.trace_module(
        decoder,
        {'forwardBlocks': trace_kwargs},
        example_inputs_is_kwarg=True,
    )

# ct.convert only looks at the forward func — alias it to the traced method.
traced_decoder.forward = traced_decoder.forwardBlocks
# Float16 tensor specs for the CoreML converter, one per traced input,
# with shapes taken straight from the dummy trace tensors.
inputs = [
    ct.TensorType(name, tensor.shape, dtype=inType)
    for name, tensor in (
        ("x", x),
        ("qk_mask", qk_mask),
        ("cross_k_caches", cross_k_caches),
        ("cross_v_caches", cross_v_caches),
    )
]
# Output names only; shapes are inferred by the converter.
outputs = [
    ct.TensorType(name, dtype=outType)
    for name in ("out_x", "out_cross_head_weights", "out_new_masked_kv_caches")
]
# Convert the traced decoder to a CoreML ML Program and save the package.
startT = timer()
# Renamed from `decoder` — the original rebinding shadowed the torch module
# above with the converted CoreML model.
mlmodel = ct.convert(
    traced_decoder,
    convert_to="mlprogram",
    inputs=inputs,
    outputs=outputs,
    compute_units=ct.ComputeUnit.CPU_AND_NE,
    minimum_deployment_target=ct.target.iOS16,  # make fp16 input and output available
    skip_model_load=True,  # conversion host need not load/validate the model
)
print(f"{modelName} decoder256 conversion time: {timer()-startT:.3f}s")

folder_path = f"coreml/{modelName}"
# FIX: os.mkdir raised FileNotFoundError when the parent "coreml/" directory
# was missing, and the exists-then-mkdir pair was racy; makedirs with
# exist_ok=True creates the whole path idempotently.
os.makedirs(folder_path, exist_ok=True)
mlmodel.save(f"{folder_path}/Decoder256.mlpackage")