Skip to content

Commit c5ae319

Browse files
committed
[xpu] Add script to process safetensors checkpoint files for the fused ops
1 parent 9424cfa commit c5ae319

File tree

3 files changed

+64
-2
lines changed

3 files changed

+64
-2
lines changed

paddlemix/examples/qwen2_vl/shell/baseline_7b_bs32_1e8.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ TRAINING_MODEL_RESUME="None"
3636
TRAINER_INSTANCES='127.0.0.1'
3737
MASTER='127.0.0.1:8080'
3838

39-
meta_path="paddlemix/examples/qwen2_vl/configs/baseline_6data_330k.json"
39+
# meta_path="paddlemix/examples/qwen2_vl/configs/baseline_6data_330k.json"
40+
meta_path="paddlemix/examples/qwen2_vl/configs/demo_chartqa_500.json"
41+
export PYTHONPATH="/path/to/PaddleMIX/ppdiffusers:/path/to/PaddleNLP:${PYTHONPATH}"
4042

4143
### XPU ###
4244
export XPU_CDNN_CLUSTER_PARALLEL=1

paddlemix/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig
5353

5454
if get_env_device() == "xpu":
55-
from paddle_xpu.layers.linear_utils.Linear import xpu_matmul
55+
from paddle_xpu.layers.nn.linear import xpu_matmul
5656
else:
5757
xpu_matmul = None
5858
try:

scripts/conver_safetensor.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
"""Fuse Qwen2-VL projection weights in a sharded safetensors checkpoint.

For every transformer layer this script concatenates:
  * q_proj / k_proj / v_proj weights and biases  -> qkv_proj
  * gate_proj / up_proj weights                  -> gate_up_fused_proj
and writes the result to a new set of shard files, copying all other
tensors through unchanged.  The fused layout is what the XPU fused
attention / fused FFN ops expect.

NOTE(review): layer 11's gate_proj/up_proj pair appears to be split
across two shard files in this particular checkpoint, so those two
tensors are buffered across shard iterations -- confirm against the
actual shard layout before reusing this for another checkpoint.
"""
import paddle
import safetensors.paddle

# Edit these to the real checkpoint locations before running.
SRC_TEMPLATE = "/path/to/Qwen2-VL-7B-Instruct/model-0000{idx}-of-00005.safetensors"
DST_TEMPLATE = "/new_path/to/Qwen2-VL-7B-Instruct-fuse_qkv/model-0000{idx}-of-00005.safetensors"
NUM_SHARDS = 5
# Concatenation only regroups existing parameters, so the checkpoint's byte
# total is unchanged by the fusion.
METADATA = {"total_size": "16582751232"}

# Leaf names consumed by a fusion that is triggered from a sibling key
# (q_proj drives the qkv fuse, gate_proj drives the gate/up fuse), so they
# must not be copied to the output on their own.
_CONSUMED_LEAVES = {
    "k_proj.weight", "k_proj.bias",
    "v_proj.weight", "v_proj.bias",
    "up_proj.weight",
}


def main():
    """Rewrite every shard, fusing projection tensors as described above."""
    # Buffers for layer 11's cross-shard gate/up pair (see module docstring).
    layer11_gate_weight = None
    layer11_up_weight = None

    for idx in range(1, NUM_SHARDS + 1):
        src_path = SRC_TEMPLATE.format(idx=idx)
        dst_path = DST_TEMPLATE.format(idx=idx)
        theta = safetensors.paddle.load_file(src_path)
        fused = {}

        for key, val in theta.items():
            parts = key.split(".")
            if len(parts) != 6:
                # Not a per-layer projection tensor: copy through unchanged.
                fused[key] = val
                continue
            leaf = ".".join(parts[4:])  # e.g. "q_proj.weight"
            tail = ".".join(parts[2:])  # e.g. "11.mlp.up_proj.weight"

            if leaf in ("q_proj.weight", "q_proj.bias"):
                # Fuse q/k/v along the output dimension (last axis of the
                # weight, the only axis of the bias).
                qkv = paddle.concat(
                    [
                        val,
                        theta[key.replace("q_proj", "k_proj")],
                        theta[key.replace("q_proj", "v_proj")],
                    ],
                    axis=-1,
                )
                fused[key.replace("q_proj", "qkv_proj")] = qkv
            elif tail == "11.mlp.up_proj.weight":
                # Held until the matching gate_proj shows up, possibly in a
                # later shard.
                layer11_up_weight = val
            elif tail == "11.mlp.gate_proj.weight":
                layer11_gate_weight = val
                # Original code crashed with an opaque TypeError inside
                # concat if the up weight had not been buffered yet; fail
                # loudly with a clear message instead.
                assert layer11_up_weight is not None, (
                    "layer 11 gate_proj seen before its up_proj was buffered"
                )
                fused[key.replace("gate_proj", "gate_up_fused_proj")] = paddle.concat(
                    [layer11_gate_weight, layer11_up_weight], axis=-1
                )
            elif leaf == "gate_proj.weight":
                # Same-shard gate/up fusion for every other layer.
                fused[key.replace("gate_proj", "gate_up_fused_proj")] = paddle.concat(
                    [val, theta[key.replace("gate_proj", "up_proj")]], axis=-1
                )
            elif leaf in _CONSUMED_LEAVES:
                continue
            else:
                fused[key] = val

        safetensors.paddle.save_file(fused, dst_path, metadata=METADATA)
        print("save new safetensors for ", dst_path)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)