
Commit

support build tensorRT engine from qwen-xxx-chat-int4
Tlntin committed Dec 6, 2023
1 parent d49bab0 commit d8e8504
Showing 4 changed files with 87 additions and 17 deletions.
25 changes: 21 additions & 4 deletions README.md
@@ -25,6 +25,9 @@
</details>

### Update Notes
#### 2023/12/06 Update
1. Support building a TensorRT Engine directly from Qwen-xxx-Chat-Int4 models.
2. Fix an error in the QKV bias handling when running AWQ on multiple GPUs.

#### 2023/11/22 Update
1. Add support for the chatglm3-6b-32k model. Compared with chatglm3-6b, it differs only in the rope_ratio used for positional encoding. [Documentation](./chatglm3-6b-32k/README.md)
@@ -292,27 +295,41 @@ python3 build.py --use_weight_only --weight_only_precision=int8 --int8_kv_cache
##### Usage Guide (int4-gptq)
1. Install the [auto-gptq](https://github.com/PanQiWei/AutoGPTQ) module and upgrade the transformers module to at least version 4.32.0. (Note: after installing these modules you may see a warning that tensorrt_llm is incompatible with other module versions; this warning can be ignored.)
```bash
pip install auto-gptq
pip install auto-gptq optimum
pip install transformers -U
```
2. Convert the weights to obtain the scale information. Calibration uses the GPU by default and must be able to load the full model. (Note: for Qwen-7B-Chat V1.0 you can add `--device=cpu` to try calibrating on the CPU, but it will take a very long time.)
2. Manually generate the calibration weights (optional)
- Convert the weights to obtain the scale information. Calibration uses the GPU by default and must be able to load the full model. (Note: for Qwen-7B-Chat V1.0 you can add `--device=cpu` to try calibrating on the CPU, but it will take a very long time.)
```bash
python3 gptq_convert.py
```
3. Build the TensorRT-LLM Engine
- Build the TensorRT-LLM Engine
```bash
python build.py --use_weight_only \
--weight_only_precision int4_gptq \
--per_group
```
4. To save GPU memory (note: single-batch only), you can try adding these two extra flags when building the engine
- To save GPU memory (note: single-batch only), you can try adding these two extra flags when building the engine
```bash
python build.py --use_weight_only \
--weight_only_precision int4_gptq \
--per_group \
--remove_input_padding \
--enable_context_fmha
```
3. Use the official INT4 weights, e.g. the Qwen-xx-Chat-Int4 models (recommended)
- Build the model. Note that the HF model path and the quantized-weight path `--quant_ckpt_path` must both be set to the same directory. Below is an example for the 1.8B model (the other models work the same way).
```bash
python build.py --use_weight_only \
--weight_only_precision int4_gptq \
--per_group \
--hf_model_dir Qwen-1_8B-Chat-Int4 \
--quant_ckpt_path Qwen-1_8B-Chat-Int4
```
- Run the model; the tokenizer path needs to be specified here
```bash
python3 run.py --tokenizer_dir=Qwen-1_8B-Chat-Int4
```
##### Usage Guide (int4-awq)
1. Download and install the [nvidia-ammo](https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.3.0.tar.gz) module. An installation reference is shown below; make sure to install the generic build rather than the CUDA-specific one, otherwise you will run into bugs.
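The actual reference commands are truncated in this diff view. As a rough sketch of the intended flow (the exact wheel filename inside the archive is an assumption and may differ; pick the generic, non-CUDA wheel matching your Python version):
```bash
# Illustrative sketch only — wheel filenames inside the archive may differ.
wget https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.3.0.tar.gz
tar -xzf nvidia_ammo-0.3.0.tar.gz
cd nvidia_ammo-0.3.0
# install the generic (non-CUDA-specific) wheel for your Python version
pip install nvidia_ammo-0.3.0-cp310-cp310-linux_x86_64.whl
```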
52 changes: 45 additions & 7 deletions qwen/README.md
@@ -13,12 +13,17 @@ The TensorRT-LLM Qwen implementation can be found in [model.py](model.py). The T
## Support Matrix
* FP16
<!-- * FP8 -->
* INT8 & INT4 Weight-Only
* INT8 & INT4 Weight-Only & INT4-AWQ & INT4-GPTQ
<!-- * FP8 KV CACHE -->
* INT8 KV CACHE
* Tensor Parallel
* STRONGLY TYPED

## Supported Models
- Qwen 1.8b / 7b / 14b / 72b (72b untested)
- Qwen 1.8b-chat / 7b-chat / 14b-chat / 72b-chat (72b-chat untested)
- Qwen 1.8b-chat-int4 / 7b-chat-int4 / 14b-chat-int4 / 72b-chat-int4 (72b-chat-int4 untested)

## Usage

The TensorRT-LLM Qwen example code is located in [examples/qwen](./). It takes HF weights as input and builds the corresponding TensorRT engines. The number of TensorRT engines depends on the number of GPUs used to run inference.
@@ -253,18 +258,18 @@ python summarize.py --backend=trt_llm \
To run the GPTQ Qwen example, the following steps are required:
1. You need to install the [auto-gptq](https://github.com/PanQiWei/AutoGPTQ) module and upgrade the transformers module to at least version 4.32.0. (Note: after installing these modules you may see a warning that tensorrt_llm is incompatible with other module versions; this warning can be ignored.)
```bash
pip install auto-gptq
pip install auto-gptq optimum
pip install transformers -U
```

2. Weight quantization
2. Manually generate the quantized weights (optional)
- Weight quantization
```bash
python3 gptq_convert.py --hf_model_dir ./tmp/Qwen/7B \
--tokenizer_dir ./tmp/Qwen/7B \
--quant_ckpt_path ./tmp/Qwen/7B/int4-gptq
```

3. Build TRT-LLM engine:
- Build TRT-LLM engine:
```bash
python build.py --hf_model_dir ./tmp/Qwen/7B \
--quant_ckpt_path ./tmp/Qwen/7B/int4-gptq/gptq_model-4bit-128g.safetensors \
@@ -281,21 +286,54 @@ python build.py --hf_model_dir ./tmp/Qwen/7B \
--output_dir ./tmp/Qwen/7B/trt_engines/int4-gptq/1-gpu
```

4. Run int4-gptq
- Run int4-gptq
```bash
python3 run.py --max_new_tokens=50 \
--tokenizer_dir ./tmp/Qwen/7B/ \
--engine_dir=./tmp/Qwen/7B/trt_engines/int4-gptq/1-gpu
```

5. Summarize
- Summarize
```bash
python summarize.py --backend=trt_llm \
--tokenizer_dir ./tmp/Qwen/7B/ \
--data_type fp16 \
--engine_dir ./tmp/Qwen/7B/trt_engines/int4-gptq/1-gpu
```

3. Use the official INT4 weights, e.g. the Qwen-1_8B-Chat-Int4 model (recommended)
- Build TRT-LLM engine:
```bash
python build.py --hf_model_dir Qwen-1_8B-Chat-Int4 \
--quant_ckpt_path Qwen-1_8B-Chat-Int4 \
--dtype float16 \
--remove_input_padding \
--use_gpt_attention_plugin float16 \
--enable_context_fmha \
--use_gemm_plugin float16 \
--use_weight_only \
--weight_only_precision int4_gptq \
--per_group \
--world_size 1 \
--tp_size 1 \
--output_dir ./tmp/Qwen/1.8B/trt_engines/int4-gptq/1-gpu
```

- Run int4-gptq
```bash
python3 run.py --max_new_tokens=50 \
--tokenizer_dir Qwen-1_8B-Chat-Int4 \
--engine_dir=./tmp/Qwen/1.8B/trt_engines/int4-gptq/1-gpu
```

- Summarize
```bash
python summarize.py --backend=trt_llm \
--tokenizer_dir Qwen-1_8B-Chat-Int4 \
--data_type fp16 \
--engine_dir ./tmp/Qwen/1.8B/trt_engines/int4-gptq/1-gpu
```


#### INT4-AWQ
To run the AWQ Qwen example, the following steps are required:
6 changes: 3 additions & 3 deletions qwen/build.py
@@ -19,7 +19,7 @@
from tensorrt_llm.logger import logger
from tensorrt_llm.models import (
# fp8_quantize,
smooth_quantize,
# smooth_quantize,
weight_only_groupwise_quantize,
weight_only_quantize,
)
@@ -292,10 +292,10 @@ def parse_arguments():

parser.add_argument(
"--weight_only_precision",
const="int4_awq",
const="int4_gptq",
type=str,
nargs="?",
default="int4_awq",
default="int4_gptq",
choices=["int8", "int4", "int4_gptq", "int4_awq"],
help="Define the precision for the weights when using weight-only quantization."
"You must also use --use_weight_only for that argument to have an impact.",
21 changes: 18 additions & 3 deletions qwen/weight.py
@@ -17,6 +17,7 @@
from tensorrt_llm.quantization import QuantMode
from model import QWenForCausalLM
from tensorrt_llm.mapping import Mapping
from transformers import AutoModelForCausalLM


def gen_suffix(rank, use_smooth_quant, quant_per_channel):
@@ -733,7 +734,17 @@ def load_from_gptq_qwen(
elif quant_ckpt_path.endswith(".pt"):
model_params = torch.load(quant_ckpt_path, map_location=torch.device("cpu"))
else:
raise ValueError("quantized checkpoint format not supported!")
if os.path.isdir(quant_ckpt_path):
    # Official GPTQ checkpoints (e.g. Qwen-xx-Chat-Int4) ship as a regular HF
    # model directory, so load it with transformers and reuse its state_dict
    # as the quantized parameters.
    model = AutoModelForCausalLM.from_pretrained(
        quant_ckpt_path,
        device_map="auto",
        trust_remote_code=True
    ).eval().cpu()
    model_params = {k: v for k, v in model.state_dict().items()}
    torch.cuda.empty_cache()
    del model
else:
    raise ValueError("quantized checkpoint format not supported!")

def unpack_int32_into_int8(w_packed):
# unpack inputs packed in int32/float32 into uint4 and store them in int8 format
@@ -1107,8 +1118,12 @@ def process_and_assign_weight(model_params, mPrefix, mOp, tp_dim=0):
process_and_assign_weight(model_params, mPrefix, mOp, 1)

# Attention QKV Linear Bias
qkv_bias = model_params[prefix + "attn.c_attn.bias"].cpu().to(torch_dtype).contiguous()
tensorrt_llm_qwen.layers[layer_idx].attention.qkv.bias.value = qkv_bias.numpy()
# Split the fused QKV bias across tensor-parallel ranks: view it as [3, hidden],
# slice the hidden dimension for this rank, then flatten it back to
# [3 * hidden / tp_size] before assigning.
th_bias = model_params[prefix + "attn.c_attn.bias"].cpu().to(torch_dtype).contiguous()
q_emb = th_bias.shape[0] // 3
th_bias = th_bias.reshape(3, q_emb)
split_v = split(th_bias, mapping.tp_size, mapping.rank, dim=1)
split_v = split_v.reshape(3 * (q_emb // mapping.tp_size))
tensorrt_llm_qwen.layers[layer_idx].attention.qkv.bias.value = np.ascontiguousarray(split_v)

# Attention Dense (out_proj) Linear
mPrefix = prefix + "attn.c_proj"
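For context on the qkv bias fix above, here is a small standalone sketch (not part of the commit) of what the tensor-parallel split of the fused QKV bias does; the sizes and variable names are illustrative only.
```python
import numpy as np

# Toy sizes chosen for readability; real models use hidden sizes in the thousands.
hidden, tp_size, rank = 8, 2, 0
qkv_bias = np.arange(3 * hidden, dtype=np.float32)      # stand-in for attn.c_attn.bias

per_proj = qkv_bias.reshape(3, hidden)                  # rows: q, k, v
shard = hidden // tp_size
local = per_proj[:, rank * shard:(rank + 1) * shard]    # this rank's slice of q, k, v
local = np.ascontiguousarray(local.reshape(3 * shard))  # rank-local fused bias, shape (12,)
print(local)                                            # rank 0 gets elements 0-3, 8-11, 16-19
```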
