Skip to content

Commit bba32d1

Browse files
committed
[xpu] add script to convert safetensors files for fused ops
1 parent 9424cfa commit bba32d1

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

scripts/conver_safetensor.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Rewrite Qwen2-VL-7B-Instruct safetensors shards with fused projections.

For every shard of the original checkpoint this script:
  * concatenates each layer's ``q_proj``/``k_proj``/``v_proj`` weight and
    bias into a single ``qkv_proj`` tensor, and
  * concatenates each layer's ``gate_proj``/``up_proj`` weight into a
    single ``gate_up_fused_proj`` tensor,
then saves one rewritten shard per input shard.  All other tensors are
copied through unchanged.
"""
import paddle
import safetensors.paddle
import numpy as np

# Number of dot-separated components in a per-layer parameter key,
# e.g. "model.layers.11.mlp.gate_proj.weight".
_LAYER_KEY_PARTS = 6

# NOTE(review): copied verbatim from the source checkpoint's index metadata;
# fusion only concatenates existing tensors, so the total byte size is
# unchanged — confirm against the original model.safetensors.index.json.
METADATA = {"total_size": "16582751232"}

NUM_SHARDS = 5
SRC_TEMPLATE = "/path/to/Qwen2-VL-7B-Instruct/model-0000{}-of-00005.safetensors"
DST_TEMPLATE = "/new_path/to/Qwen2-VL-7B-Instruct-fuse_qkv/model-0000{}-of-00005.safetensors"


def _fuse_tensors(theta, carry):
    """Return a new state dict for one shard with q/k/v and gate/up fused.

    Args:
        theta: state dict loaded from a single safetensors shard.  Assumes
            each layer's q/k/v tensors (and, except for layer 11, its
            gate/up tensors) live in the same shard — TODO confirm against
            the checkpoint's index file.
        carry: mutable dict carrying layer 11's ``up_proj`` weight across
            shards; layer 11's gate and up weights apparently live in
            different shard files, which is why they get a special case.

    Returns:
        dict mapping (possibly renamed) keys to fused/copied tensors.
    """
    fused = {}
    for key, val in theta.items():
        parts = key.split(".")
        if len(parts) != _LAYER_KEY_PARTS:
            fused[key] = val  # not a per-layer parameter; copy through
            continue
        suffix = ".".join(parts[4:])  # e.g. "q_proj.weight"
        tail = ".".join(parts[2:])    # e.g. "11.mlp.gate_proj.weight"

        if suffix in ("q_proj.weight", "q_proj.bias"):
            # Fuse q/k/v along the output dimension under a single key.
            fused[key.replace("q_proj", "qkv_proj")] = paddle.concat(
                [
                    val,
                    theta[key.replace("q_proj", "k_proj")],
                    theta[key.replace("q_proj", "v_proj")],
                ],
                axis=-1,
            )
        elif suffix in (
            "k_proj.weight",
            "k_proj.bias",
            "v_proj.weight",
            "v_proj.bias",
        ):
            continue  # already folded into qkv_proj by the q_proj branch
        elif tail == "11.mlp.up_proj.weight":
            # Partner gate weight is in another shard; hold until it shows up.
            carry["layer11_up_weight"] = val
        elif tail == "11.mlp.gate_proj.weight":
            up_weight = carry.get("layer11_up_weight")
            if up_weight is None:
                # Original code would pass None into paddle.concat here and
                # fail with an opaque error; fail loudly instead.
                raise RuntimeError(
                    "layer 11 gate_proj.weight encountered before its "
                    "up_proj.weight was loaded from an earlier shard"
                )
            fused[key.replace("gate_proj", "gate_up_fused_proj")] = (
                paddle.concat([val, up_weight], axis=-1)
            )
        elif suffix == "gate_proj.weight":
            # Same-shard gate/up pair: fuse directly.
            fused[key.replace("gate_proj", "gate_up_fused_proj")] = (
                paddle.concat(
                    [val, theta[key.replace("gate_proj", "up_proj")]],
                    axis=-1,
                )
            )
        elif suffix == "up_proj.weight":
            continue  # fused together with its gate_proj partner
        else:
            fused[key] = val  # e.g. layernorm/embedding tensors
    return fused


def main():
    """Convert every shard, writing the fused counterpart next to it."""
    carry = {"layer11_up_weight": None}
    for idx in range(1, NUM_SHARDS + 1):
        file_path = SRC_TEMPLATE.format(idx)
        new_file_path = DST_TEMPLATE.format(idx)
        theta = safetensors.paddle.load_file(file_path)
        new_safetensors = _fuse_tensors(theta, carry)
        safetensors.paddle.save_file(
            new_safetensors, new_file_path, metadata=METADATA
        )
        print("save new safetensors for ", new_file_path)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)