Fix some bugs for GPTQ and AWQ
zhaohb committed Nov 23, 2023
1 parent fff22e2 commit 9551161
Showing 2 changed files with 15 additions and 10 deletions.
6 changes: 3 additions & 3 deletions qwen/quantize.py
@@ -38,7 +38,7 @@
 
 
-def get_calib_dataloader(data="cnn_dailymail",
+def get_calib_dataloader(data="ccdv/cnn_dailymail",
                          tokenizer=None,
                          batch_size=1,
                          calib_size=512,
@@ -50,8 +50,8 @@ def get_calib_dataloader(data="cnn_dailymail",
             data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
             split="train")
         dataset = dataset["text"][:calib_size]
-    elif data == "cnn_dailymail":
-        dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
+    elif data == "ccdv/cnn_dailymail":
+        dataset = load_dataset("ccdv/cnn_dailymail", name="3.0.0", split="train")
         dataset = dataset["article"][:calib_size]
     else:
         raise NotImplementedError
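The quantize.py change simply renames the calibration dataset to its Hugging Face Hub id. For reference, a minimal sketch of the resulting calibration load (mirroring the load_dataset call in the diff; the 512 slice stands in for calib_size, and tokenization/batching are omitted):

    from datasets import load_dataset

    # Pull the CNN/DailyMail train split under its hub id and keep the
    # first calib_size articles as raw calibration text.
    dataset = load_dataset("ccdv/cnn_dailymail", name="3.0.0", split="train")
    articles = dataset["article"][:512]
    print(len(articles), articles[0][:80])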
19 changes: 12 additions & 7 deletions qwen/weight.py
@@ -823,8 +823,13 @@ def preprocess_groupwise_weight_params(
                 split_qkv_suf.append(split_qkv)
 
             idx = layer - mapping.pp_rank * layers_per_pipeline_stage
-            th_bias = model_params[prefix + "c_attn.bias"].cpu().contiguous()
-            tensorrt_llm_qwen.layers[idx].attention.qkv.bias.value = th_bias.numpy()
+            th_bias = model_params[prefix + "c_attn.bias"].to(torch.float16).cpu().contiguous()
+            q_emb = th_bias.shape[0] // 3
+            th_bias = th_bias.reshape(3, q_emb)
+            split_v = split(th_bias, mapping.tp_size, mapping.rank, dim=1)
+            split_v = split_v.reshape(3 * (q_emb // mapping.tp_size))
+
+            tensorrt_llm_qwen.layers[idx].attention.qkv.bias.value = np.ascontiguousarray(split_v)
             th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params(
                 None,
                 split_qkv_suf[0],
@@ -833,7 +838,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].attention.qkv.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].attention.qkv.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].attention.qkv.scale.value = th_scale.numpy()
+            tensorrt_llm_qwen.layers[idx].attention.qkv.scale.value = th_scale.to(torch.float16).numpy()
 
             torch_dtype = str_dtype_to_torch(dtype)
 
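The substantive fix above changes the fused QKV bias handling: instead of assigning the full c_attn.bias on every rank, the bias is cast to float16, viewed as (3, hidden) rows for Q, K and V, column-split across tensor-parallel ranks, and flattened back. A self-contained sketch of that logic, assuming torch.chunk in place of the repo's split() helper and illustrative names throughout:

    import numpy as np
    import torch

    def split_qkv_bias(th_bias: torch.Tensor, tp_size: int, tp_rank: int) -> np.ndarray:
        # Cast to fp16 and view the fused (3 * hidden,) bias as rows Q, K, V.
        th_bias = th_bias.to(torch.float16).cpu().contiguous()
        q_emb = th_bias.shape[0] // 3
        th_bias = th_bias.reshape(3, q_emb)
        # Keep this rank's column slice of each of Q, K and V ...
        chunk = torch.chunk(th_bias, tp_size, dim=1)[tp_rank]
        # ... and flatten the three slices back into one bias vector.
        return np.ascontiguousarray(chunk.reshape(3 * (q_emb // tp_size)).numpy())

    # Hidden size 8 with 2-way tensor parallelism: rank 0 keeps the first
    # half of each of Q, K and V.
    bias = torch.arange(24, dtype=torch.float32)
    print(split_qkv_bias(bias, tp_size=2, tp_rank=0))
    # [ 0.  1.  2.  3.  8.  9. 10. 11. 16. 17. 18. 19.]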
@@ -880,7 +885,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].attention.dense.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].attention.dense.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].attention.dense.scale.value = th_scale.numpy()
+            tensorrt_llm_qwen.layers[idx].attention.dense.scale.value = th_scale.to(torch.float16).numpy()
         elif "mlp.w1.qweight" in k:
             split_v_suf = []
             for suf in suffixs:
@@ -894,7 +899,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].mlp.w1.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].mlp.w1.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].mlp.w1.scale.value = th_scale.numpy()
+            tensorrt_llm_qwen.layers[idx].mlp.w1.scale.value = th_scale.to(torch.float16).numpy()
         elif "mlp.c_proj.qweight" in k:
             split_v_suf = []
             for suf in suffixs:
@@ -908,7 +913,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].mlp.c_proj.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].mlp.c_proj.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].mlp.c_proj.scale.value = th_scale.numpy()
+            tensorrt_llm_qwen.layers[idx].mlp.c_proj.scale.value = th_scale.to(torch.float16).numpy()
         elif "mlp.w2.qweight" in k:
             split_v_suf = []
             for suf in suffixs:
@@ -922,7 +927,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].mlp.w2.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].mlp.w2.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].mlp.w2.scale.value = th_scale.numpy()
+            tensorrt_llm_qwen.layers[idx].mlp.w2.scale.value = th_scale.to(torch.float16).numpy()
 
     tok = time.time()
     t = time.strftime("%h:%m:%s", time.gmtime(tok - tik))
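The remaining weight.py hunks repeat one pattern: each group-wise scale tensor (attention.dense, mlp.w1, mlp.c_proj, mlp.w2) is cast to float16 before the .numpy() export, presumably so the scale dtype matches the engine's fp16 weights. A minimal sketch of the pattern, with th_scale as a stand-in tensor:

    import torch

    # Stand-in for a group-wise scale tensor from GPTQ/AWQ preprocessing.
    th_scale = torch.rand(128, 32)

    # Cast before exporting; handing the engine a float32 scale array
    # where fp16 is expected is the dtype mismatch this avoids.
    scale_fp16 = th_scale.to(torch.float16).numpy()
    assert scale_fp16.dtype.name == "float16"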
