From 7b4260879ccbd12131928c3ebc48ec1fb36b25ee Mon Sep 17 00:00:00 2001
From: zhaohb
Date: Fri, 24 Nov 2023 09:47:33 +0800
Subject: [PATCH] fix bug

---
 qwen/weight.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/qwen/weight.py b/qwen/weight.py
index 7a9dc752..6df0264d 100644
--- a/qwen/weight.py
+++ b/qwen/weight.py
@@ -823,7 +823,7 @@ def preprocess_groupwise_weight_params(
                 split_qkv_suf.append(split_qkv)
 
             idx = layer - mapping.pp_rank * layers_per_pipeline_stage
-            th_bias = model_params[prefix + "c_attn.bias"].to(torch.float16).cpu().contiguous()
+            th_bias = model_params[prefix + "c_attn.bias"].to(torch_dtype).cpu().contiguous()
             q_emb = th_bias.shape[0] // 3
             th_bias = th_bias.reshape(3, q_emb)
             split_v = split(th_bias, mapping.tp_size, mapping.rank, dim=1)
@@ -838,7 +838,7 @@
             )
             tensorrt_llm_qwen.layers[idx].attention.qkv.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].attention.qkv.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].attention.qkv.scale.value = th_scale.to(torch.float16).numpy()
+            tensorrt_llm_qwen.layers[idx].attention.qkv.scale.value = th_scale.to(torch_dtype).numpy()
 
     torch_dtype = str_dtype_to_torch(dtype)
 
@@ -885,7 +885,7 @@
             )
             tensorrt_llm_qwen.layers[idx].attention.dense.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].attention.dense.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].attention.dense.scale.value = th_scale.to(torch.float16).numpy()
+            tensorrt_llm_qwen.layers[idx].attention.dense.scale.value = th_scale.to(torch_dtype).numpy()
         elif "mlp.w1.qweight" in k:
             split_v_suf = []
             for suf in suffixs:
@@ -899,7 +899,7 @@
             )
             tensorrt_llm_qwen.layers[idx].mlp.w1.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].mlp.w1.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].mlp.w1.scale.value = th_scale.to(torch.float16).numpy()
+            tensorrt_llm_qwen.layers[idx].mlp.w1.scale.value = th_scale.to(torch_dtype).numpy()
         elif "mlp.c_proj.qweight" in k:
             split_v_suf = []
             for suf in suffixs:
@@ -913,7 +913,7 @@
             )
             tensorrt_llm_qwen.layers[idx].mlp.c_proj.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].mlp.c_proj.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].mlp.c_proj.scale.value = th_scale.to(torch.float16).numpy()
+            tensorrt_llm_qwen.layers[idx].mlp.c_proj.scale.value = th_scale.to(torch_dtype).numpy()
         elif "mlp.w2.qweight" in k:
             split_v_suf = []
             for suf in suffixs:
@@ -927,7 +927,7 @@
             )
             tensorrt_llm_qwen.layers[idx].mlp.w2.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].mlp.w2.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].mlp.w2.scale.value = th_scale.to(torch.float16).numpy()
+            tensorrt_llm_qwen.layers[idx].mlp.w2.scale.value = th_scale.to(torch_dtype).numpy()
 
     tok = time.time()
     t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
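Note on the fix: the six replaced lines hard-coded .to(torch.float16), so the
c_attn bias and the groupwise (GPTQ) dequantization scales were always emitted
as fp16 even when the engine was built with dtype="bfloat16" or "float32".
Casting through torch_dtype, which str_dtype_to_torch(dtype) derives from the
dtype string (in TensorRT-LLM this helper comes from tensorrt_llm._utils),
keeps these tensors consistent with the rest of the weights. A minimal sketch
of that pattern, using a hypothetical stand-in for str_dtype_to_torch rather
than the library's actual implementation:

    import torch

    # Hypothetical stand-in for tensorrt_llm._utils.str_dtype_to_torch:
    # map the dtype string passed at build time to a torch.dtype.
    _STR_TO_TORCH = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }

    def str_dtype_to_torch(dtype: str) -> torch.dtype:
        return _STR_TO_TORCH[dtype]

    # With the old code, a bfloat16 build still produced fp16 scales:
    torch_dtype = str_dtype_to_torch("bfloat16")
    th_scale = torch.rand(4, 8)                              # fp32 scales
    assert th_scale.to(torch.float16).dtype != torch_dtype   # the bug
    assert th_scale.to(torch_dtype).dtype == torch.bfloat16  # the fix

Routing every cast through one lookup means a single build flag controls the
dtype of all exported tensors, instead of each call site repeating it.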