|
import paddle
import safetensors.paddle
import numpy as np

# Shard layout of the source checkpoint.
NUM_SHARDS = 5
SRC_DIR = "/path/to/Qwen2-VL-7B-Instruct"
DST_DIR = "/new_path/to/Qwen2-VL-7B-Instruct-fuse_qkv"
# NOTE(review): total_size is copied from the original checkpoint index;
# verify it still matches the fused checkpoint before shipping.
METADATA = {"total_size": "16582751232"}


def _shard_name(idx):
    """Return the safetensors file name for 1-based shard *idx*."""
    return f"model-{idx:05d}-of-{NUM_SHARDS:05d}.safetensors"


def _fuse_shard(theta, fused, carry):
    """Copy one shard's tensors into *fused*, fusing q/k/v and gate/up.

    Language-model keys have six dot-separated parts
    (e.g. ``model.layers.0.self_attn.q_proj.weight``); any other key is
    copied through unchanged.  ``q/k/v_proj`` weights and biases are
    concatenated along the last axis under ``qkv_proj``; ``gate/up_proj``
    weights under ``gate_up_fused_proj``.

    Layer 11 is special-cased because its ``up_proj.weight`` and
    ``gate_proj.weight`` live in different shard files: the up tensor is
    parked in *carry* (a dict with ``'gate'``/``'up'`` slots) so it
    survives across calls for consecutive shards.

    Raises:
        KeyError: a companion tensor (k/v/up) expected in the same shard
            is missing.
        RuntimeError: layer 11's gate weight is seen before its up weight
            (would otherwise fail opaquely inside ``paddle.concat``).
    """
    for key, val in theta.items():
        parts = key.split('.')
        if len(parts) != 6:
            fused[key] = val  # e.g. vision-tower / embedding tensors
            continue
        leaf = '.'.join(parts[4:])        # e.g. 'q_proj.weight'
        layer_leaf = '.'.join(parts[2:])  # e.g. '11.mlp.up_proj.weight'
        if leaf in ('q_proj.weight', 'q_proj.bias'):
            # Fuse q/k/v; k and v must live in the same shard as q.
            fused[key.replace('q_proj', 'qkv_proj')] = paddle.concat(
                [val,
                 theta[key.replace('q_proj', 'k_proj')],
                 theta[key.replace('q_proj', 'v_proj')]],
                axis=-1,
            )
        elif leaf in ('k_proj.weight', 'k_proj.bias',
                      'v_proj.weight', 'v_proj.bias'):
            pass  # already consumed by the q_proj branch
        elif layer_leaf == '11.mlp.up_proj.weight':
            carry['up'] = val  # gate lives in a later shard; park up here
        elif layer_leaf == '11.mlp.gate_proj.weight':
            if carry['up'] is None:
                raise RuntimeError(
                    'layer 11 gate_proj.weight seen before up_proj.weight')
            carry['gate'] = val
            fused[key.replace('gate_proj', 'gate_up_fused_proj')] = (
                paddle.concat([carry['gate'], carry['up']], axis=-1))
        elif leaf == 'gate_proj.weight':
            # Generic case: up_proj lives in the same shard as gate_proj.
            fused[key.replace('gate_proj', 'gate_up_fused_proj')] = (
                paddle.concat(
                    [val, theta[key.replace('gate_proj', 'up_proj')]],
                    axis=-1,
                ))
        elif leaf == 'up_proj.weight':
            pass  # consumed by a gate_proj branch
        else:
            fused[key] = val  # layernorms, o_proj, down_proj, ...


def _convert():
    """Rewrite every checkpoint shard with fused qkv / gate-up tensors."""
    fused = {}
    carry = {'gate': None, 'up': None}  # layer-11 tensors span two shards
    for idx in range(1, NUM_SHARDS + 1):
        file_path = f"{SRC_DIR}/{_shard_name(idx)}"
        new_file_path = f"{DST_DIR}/{_shard_name(idx)}"
        theta = safetensors.paddle.load_file(file_path)
        _fuse_shard(theta, fused, carry)
        safetensors.paddle.save_file(fused, new_file_path, metadata=METADATA)
        print("save new safetensors for ", new_file_path)
        fused.clear()  # reuse the dict; keys are unique per shard


_convert()
0 commit comments