From 95511613595c2fca8cd9071faf3695efe0d4da4a Mon Sep 17 00:00:00 2001
From: zhaohb
Date: Thu, 23 Nov 2023 15:47:03 +0000
Subject: [PATCH] Fix some bugs for GPTQ and AWQ

---
 qwen/quantize.py |  6 +++---
 qwen/weight.py   | 19 ++++++++++++-------
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/qwen/quantize.py b/qwen/quantize.py
index 59e79205..a7eb9425 100644
--- a/qwen/quantize.py
+++ b/qwen/quantize.py
@@ -38,7 +38,7 @@
-def get_calib_dataloader(data="cnn_dailymail",
+def get_calib_dataloader(data="ccdv/cnn_dailymail",
                          tokenizer=None,
                          batch_size=1,
                          calib_size=512,
@@ -50,8 +50,8 @@ def get_calib_dataloader(data="cnn_dailymail",
             data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
             split="train")
         dataset = dataset["text"][:calib_size]
-    elif data == "cnn_dailymail":
-        dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
+    elif data == "ccdv/cnn_dailymail":
+        dataset = load_dataset("ccdv/cnn_dailymail", name="3.0.0", split="train")
         dataset = dataset["article"][:calib_size]
     else:
         raise NotImplementedError
diff --git a/qwen/weight.py b/qwen/weight.py
index 700a9098..7a9dc752 100644
--- a/qwen/weight.py
+++ b/qwen/weight.py
@@ -823,8 +823,13 @@ def preprocess_groupwise_weight_params(
                 split_qkv_suf.append(split_qkv)

             idx = layer - mapping.pp_rank * layers_per_pipeline_stage
-            th_bias = model_params[prefix + "c_attn.bias"].cpu().contiguous()
-            tensorrt_llm_qwen.layers[idx].attention.qkv.bias.value = th_bias.numpy()
+            th_bias = model_params[prefix + "c_attn.bias"].to(torch.float16).cpu().contiguous()
+            q_emb = th_bias.shape[0] // 3
+            th_bias = th_bias.reshape(3, q_emb)
+            split_v = split(th_bias, mapping.tp_size, mapping.rank, dim=1)
+            split_v = split_v.reshape(3 * (q_emb // mapping.tp_size))
+
+            tensorrt_llm_qwen.layers[idx].attention.qkv.bias.value = np.ascontiguousarray(split_v)

             th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params(
                 None,
                 split_qkv_suf[0],
@@ -833,7 +838,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].attention.qkv.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].attention.qkv.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].attention.qkv.scale.value = th_scale.numpy()
+            tensorrt_llm_qwen.layers[idx].attention.qkv.scale.value = th_scale.to(torch.float16).numpy()

             torch_dtype = str_dtype_to_torch(dtype)

@@ -880,7 +885,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].attention.dense.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].attention.dense.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].attention.dense.scale.value = th_scale.numpy()
+            tensorrt_llm_qwen.layers[idx].attention.dense.scale.value = th_scale.to(torch.float16).numpy()
         elif "mlp.w1.qweight" in k:
             split_v_suf = []
             for suf in suffixs:
@@ -894,7 +899,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].mlp.w1.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].mlp.w1.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].mlp.w1.scale.value = th_scale.numpy()
+            tensorrt_llm_qwen.layers[idx].mlp.w1.scale.value = th_scale.to(torch.float16).numpy()
         elif "mlp.c_proj.qweight" in k:
             split_v_suf = []
             for suf in suffixs:
@@ -908,7 +913,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].mlp.c_proj.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].mlp.c_proj.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].mlp.c_proj.scale.value = th_scale.numpy()
+            tensorrt_llm_qwen.layers[idx].mlp.c_proj.scale.value = th_scale.to(torch.float16).numpy()
         elif "mlp.w2.qweight" in k:
             split_v_suf = []
             for suf in suffixs:
@@ -922,7 +927,7 @@ def preprocess_groupwise_weight_params(
             )
             tensorrt_llm_qwen.layers[idx].mlp.w2.qweight.value = th_qweight.numpy()
             tensorrt_llm_qwen.layers[idx].mlp.w2.zero.value = th_zero.numpy()
-            tensorrt_llm_qwen.layers[idx].mlp.w2.scale.value = th_scale.numpy()
+            tensorrt_llm_qwen.layers[idx].mlp.w2.scale.value = th_scale.to(torch.float16).numpy()

     tok = time.time()
     t = time.strftime("%h:%m:%s", time.gmtime(tok - tik))
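
Note: the sketch below reproduces, in isolation, the tensor-parallel split of the fused QKV bias that the qwen/weight.py hunk introduces. The tp_size, tp_rank and toy tensor shapes here are illustrative assumptions, and torch.chunk stands in for the split() helper used in weight.py; this is not code taken from the patch itself.

# Minimal sketch of the per-rank QKV bias split (assumed values, not from the patch).
import numpy as np
import torch

tp_size, tp_rank = 2, 0                      # assumed tensor-parallel config
hidden = 8                                   # toy hidden size
th_bias = torch.arange(3 * hidden, dtype=torch.float16)   # fused q/k/v bias

q_emb = th_bias.shape[0] // 3                # size of each of q, k, v
th_bias = th_bias.reshape(3, q_emb)          # one row per q, k, v
# keep only this rank's column slice of every row (what split() does in weight.py)
split_v = torch.chunk(th_bias, tp_size, dim=1)[tp_rank]
split_v = split_v.reshape(3 * (q_emb // tp_size))          # flatten back to 1-D
qkv_bias = np.ascontiguousarray(split_v.numpy())
print(qkv_bias.shape)                        # (12,) here: 3 * (8 // 2)

Splitting along dim=1 of the (3, q_emb) view and flattening leaves each rank with its q, k and v slices concatenated in that order, which is the layout the column-partitioned fused qkv projection expects for its bias.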