From 6d77b31b62ccb471ab8291734d751c0fdca0d7b3 Mon Sep 17 00:00:00 2001
From: myl_fm
Date: Thu, 24 Oct 2024 19:18:49 +0800
Subject: [PATCH] Fix errors when using SmoothQuant to quantize the Qwen2 model

When quantizing the Qwen2 model with SmoothQuant, the mlp.proj weight was
not split correctly for tensor parallelism: it must be split along
dimension 0, not the last dimension. In addition, the k and v tensors in
multi_query_split were sized by q's width; they are now split by their own
widths, and the even splits use torch.chunk, which takes the number of
chunks directly.

Commands to reproduce:

```
python3 convert_checkpoint.py --model_dir ./tmp/Qwen/7B/ \
                              --output_dir ./tllm_checkpoint_1gpu_sq \
                              --dtype float16 \
                              --smoothquant 0.5
trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_sq \
             --output_dir ./engine_outputs \
             --gemm_plugin float16
```
---
 tensorrt_llm/models/qwen/convert.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/tensorrt_llm/models/qwen/convert.py b/tensorrt_llm/models/qwen/convert.py
index 8cfaf58e5..917b279c7 100644
--- a/tensorrt_llm/models/qwen/convert.py
+++ b/tensorrt_llm/models/qwen/convert.py
@@ -297,8 +297,8 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank):
         q, k, v = torch.split(data, [local_dim, head_size, head_size],
                               dim=-1)
         q_split = torch.split(q, q.shape[-1] // tp_size, dim=-1)
-        k_split = torch.split(k, q.shape[-1] // tp_size, dim=-1)
-        v_split = torch.split(v, q.shape[-1] // tp_size, dim=-1)
+        k_split = torch.split(k, k.shape[-1] // tp_size, dim=-1)
+        v_split = torch.split(v, v.shape[-1] // tp_size, dim=-1)
         return [
             torch.concat((q_split[ii], k_split[ii], v_split[ii]), dim=-1)
             for ii in range(tp_size)
@@ -318,8 +318,7 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank):
             cur_weights = multi_query_split(original_weights, local_dim,
                                             head_size, tensor_parallel, rank)
         else:
-            cur_weights = torch.split(original_weights,
-                                      original_weights.shape[-1] //
+            cur_weights = torch.chunk(original_weights,
                                       tensor_parallel,
                                       dim=cat_dim)[rank]
         if is_qkv:
@@ -370,8 +369,7 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank):
             cur_weights = multi_query_split(original_weights, local_dim,
                                             head_size, tensor_parallel, rank)
         else:
-            cur_weights = torch.split(original_weights,
-                                      original_weights.shape[-1] //
+            cur_weights = torch.chunk(original_weights,
                                       tensor_parallel,
                                       dim=cat_dim)[rank]
         if is_qkv:
@@ -823,7 +821,7 @@ def convert_hf_qwen(hf_model,
                         1, intermediate_size // tensor_parallel
                     ],
                     rank=mapping.tp_rank,
-                    cat_dim=-1))
+                    cat_dim=0))
         else:
             weights.update(
                 get_tllm_linear_weight(split_v, tllm_prex + 'mlp.proj.',
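
Reviewer note: the sketch below is not part of the patch. It is a minimal,
self-contained PyTorch illustration of the failure modes being fixed, using
toy shapes; the names `tp_size`, `proj`, and `x`, and the `[in, out]` weight
layout in part 3, are assumptions made for the demo, not TensorRT-LLM's
actual storage layout.

```
# Standalone sketch (not part of the patch): toy PyTorch reproduction of the
# bugs fixed above. All shapes are illustrative values.
import torch

tp_size = 2

# 1. GQA QKV split: fused [q | k | v] weight where k/v are narrower than q
#    (grouped-query attention, as in the larger Qwen2 variants). Sizing the
#    k/v chunks by q's width yields the wrong number of shards.
local_dim, head_size = 8, 2
qkv = torch.randn(4, local_dim + 2 * head_size)
q, k, v = torch.split(qkv, [local_dim, head_size, head_size], dim=-1)

k_bad = torch.split(k, q.shape[-1] // tp_size, dim=-1)   # chunk size 4 > k's width 2
print(len(k_bad))    # 1 shard, not tp_size -> rank 1 would hit an IndexError

k_good = torch.split(k, k.shape[-1] // tp_size, dim=-1)  # split by k's own width
print(len(k_good))   # 2

# 2. torch.chunk takes the *number* of chunks directly, so the mismatched
#    chunk-size mistake above cannot happen; for evenly divisible dims it
#    matches torch.split(t, t.shape[dim] // n, dim).
w = torch.randn(4, 6)
assert all(
    torch.equal(a, b)
    for a, b in zip(torch.chunk(w, tp_size, dim=-1),
                    torch.split(w, w.shape[-1] // tp_size, dim=-1)))

# 3. Row-parallel slicing (the mlp.proj case), assuming an [in, out] weight
#    layout for this demo: the activations are split on the input-feature
#    dim, so the weight must be sliced on dim 0 to match; the per-rank
#    partial products then sum (all-reduce) to the full result.
proj = torch.randn(8, 4)                       # hypothetical [in, out] weight
x = torch.randn(1, 8)
x_shards = torch.chunk(x, tp_size, dim=-1)
w_shards = torch.chunk(proj, tp_size, dim=0)   # dim 0, as in the patched cat_dim
partial = sum(xs @ ws for xs, ws in zip(x_shards, w_shards))
assert torch.allclose(x @ proj, partial, atol=1e-5)
```

Part 3 is the block-decomposition identity x @ W == sum_i(x_i @ W_i) that
makes row-parallel layers work, and it is why mlp.proj must be partitioned
on the input-feature dimension: with a last-dim split, each rank's weight
shard no longer lines up with its slice of the incoming activations.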