From d8e85043b170227a3c05f670fa06bd98f5060a01 Mon Sep 17 00:00:00 2001 From: Tlntin Date: Wed, 6 Dec 2023 22:09:11 +0800 Subject: [PATCH] support build tensorRT engine from qwen-xxx-chat-int4 --- README.md | 25 ++++++++++++++++++++---- qwen/README.md | 52 +++++++++++++++++++++++++++++++++++++++++++------- qwen/build.py | 6 +++--- qwen/weight.py | 21 +++++++++++++++++--- 4 files changed, 87 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 337c4b41..efba13ec 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,9 @@ ### 更新说明 +#### 2023/12/06 更新 +1. 支持Qwen-xxx-Chat-Int4模型直接编译成TensorRT Engine。 +2. 修复awq多卡qkv bias部分报错。 #### 2023/11/22 更新 1. 新增chatglm3-6b-32k模型支持,chatglm3-6b-32k与chatglm3-6b相比不同之处在于位置编码的rope_ratio不同,[文档链接](./chatglm3-6b-32k/README.md) @@ -292,20 +295,21 @@ python3 build.py --use_weight_only --weight_only_precision=int8 --int8_kv_cache ##### 运行指南(int4-gptq篇) 1. 需要安装[auto-gptq](https://github.com/PanQiWei/AutoGPTQ)模块,并且升级transformers模块版本,最低要求4.32.0。(注:安装完模块后可能会提示tensorrt_llm与其他模块版本不兼容,可以忽略该警告) ```bash -pip install auto-gptq +pip install auto-gptq optimum pip install transformers -U ``` -2. 转权重获取scale相关信息,默认使用GPU进行校准,需要能够完整加载模型。(注:对于Qwen-7B-Chat V1.0,可以加上`--device=cpu`来尝试用cpu标定,但是时间会很长) +2. 手动获取标定权重(可选) +- 转权重获取scale相关信息,默认使用GPU进行校准,需要能够完整加载模型。(注:对于Qwen-7B-Chat V1.0,可以加上`--device=cpu`来尝试用cpu标定,但是时间会很长) ```bash python3 gptq_convert.py ``` -3. 编译TensorRT-LLM Engine +- 编译TensorRT-LLM Engine ```bash python build.py --use_weight_only \ --weight_only_precision int4_gptq \ --per_group ``` -4. 如果想要节省显存(注:只能用于单batch),可以试试加上这俩参数来编译Engine +- 如果想要节省显存(注:只能用于单batch),可以试试加上这俩参数来编译Engine ```bash python build.py --use_weight_only \ --weight_only_precision int4_gptq \ @@ -313,6 +317,19 @@ python build.py --use_weight_only \ --remove_input_padding \ --enable_context_fmha ``` +3. 使用官方int4权重,例如Qwen-xx-Chat-Int4模型(推荐) +- 编译模型,注意设置hf模型路径和`--quant_ckpt_path`量化后权重路径均设置为同一个路径,下面是1.8b模型的示例(其他模型也是一样操作) +```bash +python build.py --use_weight_only \ + --weight_only_precision int4_gptq \ + --per_group \ + --hf_model_dir Qwen-1_8B-Chat-Int4 \ + --quant_ckpt_path Qwen-1_8B-Chat-Int4 +``` +- 运行模型,这里需要指定一下tokenizer路径 +```bash +python3 run.py --tokenizer_dir=Qwen-1_8B-Chat-Int4 +``` ##### 运行指南(int4-awq篇) 1. 需要下载并安装[nvidia-ammo](https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.3.0.tar.gz)模块,下面是一个安装代码参考,注意不要安装cuda版,而是安装通用版,否则会有bug。 diff --git a/qwen/README.md b/qwen/README.md index 40c4b7ce..58a56521 100644 --- a/qwen/README.md +++ b/qwen/README.md @@ -13,12 +13,17 @@ The TensorRT-LLM Qwen implementation can be found in [model.py](model.py). The T ## Support Matrix * FP16 - * INT8 & INT4 Weight-Only + * INT8 & INT4 Weight-Only & INT4-AWQ & INT4-GPTQ * INT8 KV CACHE * Tensor Parallel * STRONGLY TYPED +## Support Model +- QWen 1.8b/7b/14b/72b(maybe) +- Qwen 1.8b-chat/7b-chat/14b-chat/72b-chat(maybe) +- Qwen 1.8b-chat-int4/7b-chat-int4/14b-chat-int4/72b-chat-int4(maybe) + ## Usage The TensorRT-LLM Qwen example code locates at [examples/qwen](./). It takes HF weights as input, and builds the corresponding TensorRT engines. The number of TensorRT engines depends on the number of GPUs used to run inference. @@ -253,18 +258,18 @@ python summarize.py --backend=trt_llm \ To run the GPTQ Qwen example, the following steps are required: 1. You need to install the [auto-gptq](https://github.com/PanQiWei/AutoGPTQ) module and upgrade the transformers module version, with a minimum of 4.32.0. (Note: After installing the module, it may prompt that the tensorrt_llm is not compatible with other module versions, you can ignore this warning) ```bash -pip install auto-gptq +pip install auto-gptq optimum pip install transformers -U ``` - -2. Weight quantization +2. Manually get the quanted weights (optional) +- Weight quantization ```bash python3 gptq_convert.py --hf_model_dir ./tmp/Qwen/7B \ --tokenizer_dir ./tmp/Qwen/7B \ --quant_ckpt_path ./tmp/Qwen/7B/int4-gptq ``` -3. Build TRT-LLM engine: +- Build TRT-LLM engine: ```bash python build.py --hf_model_dir ./tmp/Qwen/7B \ --quant_ckpt_path ./tmp/Qwen/7B/int4-gptq/gptq_model-4bit-128g.safetensors \ @@ -281,14 +286,14 @@ python build.py --hf_model_dir ./tmp/Qwen/7B \ --output_dir ./tmp/Qwen/7B/trt_engines/int4-gptq/1-gpu ``` -4. Run int4-gptq +- Run int4-gptq ```bash python3 run.py --max_new_tokens=50 \ --tokenizer_dir ./tmp/Qwen/7B/ \ --engine_dir=./tmp/Qwen/7B/trt_engines/int4-gptq/1-gpu ``` -5. Summarize +- Summarize ```bash python summarize.py --backend=trt_llm \ --tokenizer_dir ./tmp/Qwen/7B/ \ @@ -296,6 +301,39 @@ python summarize.py --backend=trt_llm \ --engine_dir ./tmp/Qwen/7B/trt_engines/int4-gptq/1-gpu ``` +3. Use official int4 weights, e.g. Qwen-1_8B-Chat-Int4 model(recommended) +- Build TRT-LLM engine: +```bash +python build.py --hf_model_dir Qwen-1_8B-Chat-Int4 \ + --quant_ckpt_path Qwen-1_8B-Chat-Int4 \ + --dtype float16 \ + --remove_input_padding \ + --use_gpt_attention_plugin float16 \ + --enable_context_fmha \ + --use_gemm_plugin float16 \ + --use_weight_only \ + --weight_only_precision int4_gptq \ + --per_group \ + --world_size 1 \ + --tp_size 1 \ + --output_dir ./tmp/Qwen/1.8B/trt_engines/int4-gptq/1-gpu +``` + +- Run int4-gptq +```bash +python3 run.py --max_new_tokens=50 \ + --tokenizer_dir Qwen-1_8B-Chat-Int4 \ + --engine_dir=./tmp/Qwen/1.8B/trt_engines/int4-gptq/1-gpu +``` + +- Summarize +```bash +python summarize.py --backend=trt_llm \ + --tokenizer_dir Qwen-1_8B-Chat-Int4 \ + --data_type fp16 \ + --engine_dir ./tmp/Qwen/1.8B/trt_engines/int4-gptq/1-gpu +``` + #### INT4-AWQ To run the AWQ Qwen example, the following steps are required: diff --git a/qwen/build.py b/qwen/build.py index 9b466dcf..0f867279 100644 --- a/qwen/build.py +++ b/qwen/build.py @@ -19,7 +19,7 @@ from tensorrt_llm.logger import logger from tensorrt_llm.models import ( # fp8_quantize, - smooth_quantize, + # smooth_quantize, weight_only_groupwise_quantize, weight_only_quantize, ) @@ -292,10 +292,10 @@ def parse_arguments(): parser.add_argument( "--weight_only_precision", - const="int4_awq", + const="int4_gptq", type=str, nargs="?", - default="int4_awq", + default="int4_gptq", choices=["int8", "int4", "int4_gptq", "int4_awq"], help="Define the precision for the weights when using weight-only quantization." "You must also use --use_weight_only for that argument to have an impact.", diff --git a/qwen/weight.py b/qwen/weight.py index 76265d77..c48fd882 100644 --- a/qwen/weight.py +++ b/qwen/weight.py @@ -17,6 +17,7 @@ from tensorrt_llm.quantization import QuantMode from model import QWenForCausalLM from tensorrt_llm.mapping import Mapping +from transformers import AutoModelForCausalLM def gen_suffix(rank, use_smooth_quant, quant_per_channel): @@ -733,7 +734,17 @@ def load_from_gptq_qwen( elif quant_ckpt_path.endswith(".pt"): model_params = torch.load(quant_ckpt_path, map_location=torch.device("cpu")) else: - raise ValueError("quantized checkpoint format not supported!") + if os.path.isdir(quant_ckpt_path): + model = AutoModelForCausalLM.from_pretrained( + quant_ckpt_path, + device_map="auto", + trust_remote_code=True + ).eval().cpu() + model_params = {k: v for k, v in model.state_dict().items()} + torch.cuda.empty_cache() + del model + else: + raise ValueError("quantized checkpoint format not supported!") def unpack_int32_into_int8(w_packed): # unpack inputs packed in int32/float32 into uint4 and store them in int8 format @@ -1107,8 +1118,12 @@ def process_and_assign_weight(model_params, mPrefix, mOp, tp_dim=0): process_and_assign_weight(model_params, mPrefix, mOp, 1) # Attention QKV Liner Bias - qkv_bias = model_params[prefix + "attn.c_attn.bias"].cpu().to(torch_dtype).contiguous() - tensorrt_llm_qwen.layers[layer_idx].attention.qkv.bias.value = qkv_bias.numpy() + th_bias = model_params[prefix + "attn.c_attn.bias"].cpu().to(torch_dtype).contiguous() + q_emb = th_bias.shape[0] // 3 + th_bias = th_bias.reshape(3, q_emb) + split_v = split(th_bias, mapping.tp_size, mapping.rank, dim=1) + split_v = split_v.reshape(3 * (q_emb // mapping.tp_size)) + tensorrt_llm_qwen.layers[layer_idx].attention.qkv.bias.value = np.ascontiguousarray(split_v) # Attention Dense (out_proj) Linear mPrefix = prefix + "attn.c_proj"