
Commit

fix bug
zhaohb committed Nov 24, 2023
1 parent 9551161 commit 7b42608
Showing 1 changed file with 6 additions and 6 deletions.
qwen/weight.py — 12 changes: 6 additions & 6 deletions
@@ -823,7 +823,7 @@ def preprocess_groupwise_weight_params(
             split_qkv_suf.append(split_qkv)
 
         idx = layer - mapping.pp_rank * layers_per_pipeline_stage
-        th_bias = model_params[prefix + "c_attn.bias"].to(torch.float16).cpu().contiguous()
+        th_bias = model_params[prefix + "c_attn.bias"].to(torch_dtype).cpu().contiguous()
         q_emb = th_bias.shape[0] // 3
         th_bias = th_bias.reshape(3, q_emb)
         split_v = split(th_bias, mapping.tp_size, mapping.rank, dim=1)
@@ -838,7 +838,7 @@ def preprocess_groupwise_weight_params(
         )
         tensorrt_llm_qwen.layers[idx].attention.qkv.qweight.value = th_qweight.numpy()
         tensorrt_llm_qwen.layers[idx].attention.qkv.zero.value = th_zero.numpy()
-        tensorrt_llm_qwen.layers[idx].attention.qkv.scale.value = th_scale.to(torch.float16).numpy()
+        tensorrt_llm_qwen.layers[idx].attention.qkv.scale.value = th_scale.to(torch_dtype).numpy()
 
     torch_dtype = str_dtype_to_torch(dtype)
 
@@ -885,7 +885,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].attention.dense.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].attention.dense.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].attention.dense.scale.value = th_scale.to(torch.float16).numpy()
+            tensorrt_llm_qwen.layers[idx].attention.dense.scale.value = th_scale.to(torch_dtype).numpy()
         elif "mlp.w1.qweight" in k:
             split_v_suf = []
             for suf in suffixs:
@@ -899,7 +899,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].mlp.w1.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].mlp.w1.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].mlp.w1.scale.value = th_scale.to(torch.float16).numpy()
+            tensorrt_llm_qwen.layers[idx].mlp.w1.scale.value = th_scale.to(torch_dtype).numpy()
         elif "mlp.c_proj.qweight" in k:
             split_v_suf = []
             for suf in suffixs:
@@ -913,7 +913,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].mlp.c_proj.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].mlp.c_proj.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].mlp.c_proj.scale.value = th_scale.to(torch.float16).numpy()
+            tensorrt_llm_qwen.layers[idx].mlp.c_proj.scale.value = th_scale.to(torch_dtype).numpy()
         elif "mlp.w2.qweight" in k:
             split_v_suf = []
             for suf in suffixs:
@@ -927,7 +927,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].mlp.w2.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].mlp.w2.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].mlp.w2.scale.value = th_scale.to(torch.float16).numpy()
+            tensorrt_llm_qwen.layers[idx].mlp.w2.scale.value = th_scale.to(torch_dtype).numpy()
 
     tok = time.time()
     t = time.strftime("%h:%m:%s", time.gmtime(tok - tik))
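
All six hunks make the same change: the loader converts the attention bias and every GPTQ dequantization scale tensor before assigning them to the TensorRT-LLM modules, and the conversion target was hardcoded to torch.float16. When the engine is built with a different dtype, those tensors were silently downcast to half precision. The fix routes each conversion through torch_dtype, derived from the user-supplied dtype string via str_dtype_to_torch.

Below is a minimal, self-contained sketch of the pattern, not the repository's actual loader: str_dtype_to_torch is re-implemented inline (the real one lives in tensorrt_llm._utils), and convert_scale is a hypothetical helper contrasting the old and new behavior.

import torch

# Stand-in for tensorrt_llm._utils.str_dtype_to_torch, inlined here only
# so the sketch runs on its own.
_STR_TO_TORCH = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}

def str_dtype_to_torch(dtype: str) -> torch.dtype:
    return _STR_TO_TORCH[dtype]

def convert_scale(th_scale: torch.Tensor, dtype: str):
    # Old behavior: always downcast to half precision, regardless of the
    # dtype the engine is being built with.
    old = th_scale.to(torch.float16).numpy()
    # New behavior: follow the requested build dtype.
    torch_dtype = str_dtype_to_torch(dtype)
    new = th_scale.to(torch_dtype).numpy()
    return old, new

if __name__ == "__main__":
    scale = torch.randn(4, 8, dtype=torch.float32)
    old, new = convert_scale(scale, dtype="float32")
    print(old.dtype, new.dtype)  # float16 before the fix, float32 after

One caveat on the sketch's scope: NumPy has no bfloat16 dtype, so .numpy() raises an error on bfloat16 tensors; the demo therefore uses float32 to show the precision difference.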
