update kv-cache to 2048, reduce memory use, speed onnx runtime

Tlntin committed Oct 24, 2024
1 parent d56c319 commit b6b11e6
Showing 12 changed files with 365 additions and 89 deletions.
43 changes: 34 additions & 9 deletions README.md
@@ -61,22 +61,31 @@
cd qwen-ascend-llm
pip install -r ./requirements.txt
```
2. Export the onnx model. The default kv-cache length is 1024; you can set a larger value according to your own RAM/VRAM.
2. Export the onnx model. The kv-cache length I currently set is 2048; you can raise it according to your own RAM/VRAM, but it is best not to exceed `max_position_embeddings`, which you can look up in the model's config.json. For Qwen2-1.5B-Instruct this value is `32768` (see the snippet after the command below).
```bash
python3 export/export_onnx.py \
--device_str=npu \
--dtype=float16 \
--hf_model_dir="./download/Qwen2-1.5B-Instruct" \
--onnx_model_path="./output/onnx/qwen2_1.5b_chat.onnx" \
--kv_cache_length=1024
--kv_cache_length=2048
```
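
To check `max_position_embeddings` without opening config.json by hand, a minimal sketch (assuming the download path used in the command above) is:
```python
# Minimal sketch: read max_position_embeddings from the downloaded model's config.json.
# The path below assumes the download location used in the commands above.
import json

with open("./download/Qwen2-1.5B-Instruct/config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

# For Qwen2-1.5B-Instruct this prints 32768, the upper bound you should not
# exceed when choosing --kv_cache_length.
print(config["max_position_embeddings"])
```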

3. Validate the onnx model: go back to the project root and run cli_chat.py to check that chatting with the onnx model works (note: it runs on CPU, so it is slow; please be patient).
- `--max_input_length` is the maximum input length for a single turn; it must be smaller than the `--kv_cache_length` used when exporting the onnx.
- `--max_output_length` must match the `--kv_cache_length` used when exporting the onnx, otherwise the onnx output will be abnormal.
- Note: the maximum number of tokens that can be generated = `max_output_length` - min(max_input_length, actual number of input tokens); see the sketch after the command below.
- For onnx exported on npu, use dtype float16; for onnx exported on cpu, use dtype float32.
- Set `--cpu_thread` according to the number of CPU threads on your machine; the default is 4.
```bash
python3 ./cli_chat.py \
--session_type=onnx \
--hf_model_dir="./download/Qwen2-1.5B-Instruct" \
--onnx_model_path="./output/onnx/qwen2_1.5b_chat.onnx"
--onnx_model_path="./output/onnx/qwen2_1.5b_chat.onnx" \
--dtype="float16" \
--cpu_thread=4 \
--max_input_length=1024 \
--max_output_length=2048
```
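
As a sanity check on the token budget described above, here is a minimal sketch of the formula (plain Python, not part of the project code):
```python
# Sketch of the budget rule above:
# max new tokens = max_output_length - min(max_input_length, actual input tokens)
def max_new_tokens(max_output_length: int, max_input_length: int, real_input_tokens: int) -> int:
    return max_output_length - min(max_input_length, real_input_tokens)

# With the values used in this README, a 300-token prompt leaves room
# for up to 2048 - 300 = 1748 newly generated tokens.
print(max_new_tokens(max_output_length=2048, max_input_length=1024, real_input_tokens=300))
```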

4. Modify the onnx graph: the exported Trilu and Cast operators currently have issues that the atc command cannot handle, so the graph structure needs to be adjusted.
@@ -86,24 +95,38 @@
--output_model_path="./output/onnx2/qwen2_1.5b_chat.onnx"
```

5. Convert the onnx model to an om model: use the atc command to convert the modified onnx to om, **note that om_model_path here has no `.om` suffix**. The run may print some warnings or subgraph-fusion errors; as long as the final result says `success`, it is fine. kv_cache_length must match the length used in the earlier onnx export step. `--max_prefill_length` is the maximum length handled in a single pass during the prefill stage; the larger it is, the lower the first-token latency, but converting onnx to om also takes longer. It should normally be a power of 2, e.g. 2, 4, 8, 16, and so on; at inference time a suitable prefill length is matched recursively, e.g. an input of 12 is matched as [8, 4]. The current default is 16; if set to 1, dynamic-shape inference is disabled. The script auto-detects your NPU type; if you want to specify it manually, add `--soc_version=xxxx`, e.g. `--soc_version=Ascend310B1`.
5. Convert the onnx model to an om model: use the atc command to convert the modified onnx to om. **Note that om_model_path here has no `.om` suffix.**
- The run may print some warnings or subgraph-fusion errors; as long as the final result says `success`, it is fine.
- kv_cache_length must match the length used in the earlier onnx export step.
- `--max_prefill_length` is the maximum length handled in a single pass during the prefill stage. The larger it is, the lower the first-token latency, but converting onnx to om also takes longer. It should normally be a power of 2, e.g. 2, 4, 8, 16, and so on; at inference time a suitable prefill length is matched recursively, e.g. an input of 12 is matched as [8, 4] (see the sketch after the command below). The current default is 4; if set to 1, dynamic-shape inference is disabled.
- The script auto-detects your NPU type; if you want to specify it manually, add `--soc_version=xxxx`, e.g. `--soc_version=Ascend310B1`.
- The value of `--kv_cache_length` must match the `--kv_cache_length` specified earlier when exporting the onnx, otherwise the conversion will most likely fail.
- `--cpu_thread` is the number of CPU threads used when converting onnx to om; the default is 1 compilation thread. If you have plenty of RAM (each thread holds its own copy of memory, so this is memory-hungry), you can raise it.
```bash
python3 export/onnx2om.py \
--hf_model_dir="./download/Qwen2-1.5B-Instruct" \
--onnx_model_path="./output/onnx2/qwen2_1.5b_chat.onnx" \
--om_model_path="./output/model/qwen2_1.5b_chat" \
--kv_cache_length=1024 \
--max_prefill_length=16
--kv_cache_length=2048 \
--cpu_thread=1 \
--max_prefill_length=4
```
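
To illustrate the recursive prefill matching mentioned above, here is a small sketch of the idea; it is not the repository's exact implementation:
```python
# Illustrative sketch (not the project's exact code): greedily split an input
# of arbitrary length into power-of-2 prefill chunks no larger than
# max_prefill_length, e.g. 12 tokens -> [8, 4] when max_prefill_length=8.
def split_prefill(n_tokens: int, max_prefill_length: int) -> list:
    chunks = []
    remaining = n_tokens
    size = max_prefill_length
    while remaining > 0:
        while size > remaining:
            size //= 2  # fall back to the next smaller power of 2
        chunks.append(size)
        remaining -= size
    return chunks

print(split_prefill(12, 8))  # [8, 4]
print(split_prefill(12, 4))  # [4, 4, 4]
```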


##### Step 2: Run the model in the terminal to chat
- Run the model directly with the command below; `--max_prefill_length` must match the value used when compiling above.
- Run the model directly with the command below.
- `--max_prefill_length` must match the value used when compiling the om model above.
- `--max_input_length` is the maximum input length for a single turn; it must be smaller than the `--kv_cache_length` used when exporting the onnx.
- `--max_output_length` must match the `--kv_cache_length` used when exporting the onnx, otherwise the onnx output will be abnormal.
- Note: the maximum number of tokens that can be generated = `max_output_length` - min(max_input_length, actual number of input tokens).
```bash
python3 ./cli_chat.py \
--session_type="acl" \
--hf_model_dir="./download/Qwen2-1.5B-Instruct" \
--om_model_path="./output/model/qwen2_1.5b_chat.om" \
--max_prefill_length=16
--max_input_length=1024 \
--max_output_length=2048 \
--max_prefill_length=4
```

- Demo 1 (demo model: qwen1.5-0.5b-chat, dynamic-shape inference disabled)
@@ -119,7 +142,9 @@
python3 ./api.py \
--hf_model_dir="./download/Qwen2-1.5B-Instruct" \
--om_model_path="./output/model/qwen2_1.5b_chat.om" \
--max_prefill_length=16
--max_input_length=1024 \
--max_output_length=2048 \
--max_prefill_length=4
```

- Go into the client directory; you can run the scripts there to send requests to the server.
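
If you want to call the server from your own script instead, a hypothetical sketch is shown below; the actual route and payload are defined by api.py and the scripts under client/, so treat the URL, port, and JSON fields here as assumptions and check those files first.
```python
# Hypothetical client sketch: URL, port, route and payload fields are assumptions,
# not taken from api.py; see the scripts under client/ for the real request format.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/chat/completions",  # assumed address and route
    json={
        "model": "qwen2-1.5b-instruct",           # assumed model name
        "messages": [{"role": "user", "content": "你好,请介绍一下你自己。"}],
        "stream": False,
    },
    timeout=60,
)
print(resp.json())
```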
38 changes: 33 additions & 5 deletions cli_chat.py
@@ -20,7 +20,34 @@ def parser_args():
type=str,
default="acl",
help="acl or onnx",
choices=["acl", "onnx"],
choices=["acl", "onnx", "pytorch"],
)
parser.add_argument(
"--dtype" ,
type=str,
help="support float16/float32, if use CPU, only support fp32",
choices=["float16", "float32"],
default="float32",
)
parser.add_argument(
"--torch_dtype",
type=str,
help="support float16/float32, if use CPU, only support fp32",
choices=["float16", "float32"],
default="float32",
)
parser.add_argument(
"--device_str",
type=str,
help="support cpu, cuda, npu, only activate when sesstion_type is pytorch",
choices=["cpu", "cuda", "npu"],
default="cpu",
)
parser.add_argument(
"--cpu_thread" ,
type=int,
help="num of cpu thread when run onnx sesstion",
default=4,
)
parser.add_argument(
'--onnx_model_path',
@@ -44,7 +71,7 @@ def parser_args():
"--max_input_length",
help="max input length",
type=int,
default=512,
default=1024,
)
parser.add_argument(
"--max_prefill_length",
@@ -53,14 +80,13 @@
"the number must by 2^xx, like 1, 2, 4, 8, 16, 32, 64, 128, 256... "
"Note! The higher this number, the longer it will take to compile.",
type=int,
default=16,
default=4,
)

parser.add_argument(
"--max_output_length",
help="max output length (contain input + new token)",
type=int,
default=1024,
default=2048,
)
return parser.parse_args()

@@ -110,12 +136,14 @@ def inference_cli():
hf_model_dir=args.hf_model_dir,
om_model_path=args.om_model_path,
onnx_model_path=args.onnx_model_path,
cpu_thread=args.cpu_thread,
session_type=args.session_type,
max_batch=args.max_batch,
max_output_length=args.max_output_length,
max_input_length=args.max_input_length,
kv_cache_length=args.max_output_length,
max_prefill_length=max_prefill_length,
dtype=args.dtype
)
# main()
inference_cli()
21 changes: 17 additions & 4 deletions config.py
@@ -1,4 +1,5 @@
import os
import torch
from transformers.models.qwen2 import Qwen2Config, Qwen2Tokenizer


@@ -9,28 +10,33 @@ def __init__(
om_model_path: str,
onnx_model_path: str,
cpu_thread: int = 4, # number of CPU threads
session_type: str = "acl", # supports two options: acl and onnx; acl is Ascend C Language
session_type: str = "acl", # supports three options: acl, onnx, and pytorch; acl is Ascend C Language
device_id: int = 0,
sampling_method: str = "top_p", # supports greedy, top_p, top_k
sampling_value: float = 0.8,
temperature: float = 0.7,
max_batch: int = 1,
max_input_length: int = 512, # maximum input length
max_output_length: int = 1024, # maximum output length
max_input_length: int = 1024, # maximum input length
max_output_length: int = 2048, # maximum output length
max_prefill_length: int = 1, # maximum length handled per pass in the prefill stage
kvcache_method: str = "fixsize", # kv_cache type; supports basic, fixsize, streamllm and H2O, see kvcache.py for details
kv_cache_length: int = 1024, # maximum kv_cache length
kv_cache_length: int = 2048, # maximum kv_cache length
cache_format: str = 'huggingface-tensor', # kv_cache format
dtype:str="float16",
torch_dtype: str = "float16",
device_str = "cpu",
):
self.tokenizer_dir = hf_model_dir
self.session_type = session_type
if self.session_type == "acl":
assert os.path.exists(om_model_path), print(om_model_path, "not exists")
elif self.session_type == "onnx":
assert os.path.exists(onnx_model_path), print(onnx_model_path, "not exists")
elif self.session_type == "pytorch":
assert os.path.exists(hf_model_dir), print(hf_model_dir, "not exists")
self.om_model_path = om_model_path
self.onnx_model_path = onnx_model_path
self.hf_model_dir = hf_model_dir
self.cpu_thread = cpu_thread
self.device_id = device_id
self.sampling_method = sampling_method
@@ -43,6 +49,13 @@ def __init__(
self.kv_cache_length = kv_cache_length # max_cache_size
self.cache_format = cache_format
self.dtype = dtype
if torch_dtype == "float16":
self.torch_dtype = torch.float16
elif torch_dtype == "float32":
self.torch_dtype = torch.float32
else:
self.torch_dtype = "auto"
self.device_str = device_str
self.model_config = Qwen2Config.from_pretrained(hf_model_dir)
self.num_hidden_layers = self.model_config.num_hidden_layers # n_layer
self.num_key_value_heads = self.model_config.num_key_value_heads # head_num
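
These config values (num_hidden_layers, num_key_value_heads) are also what determine the kv-cache footprint. A rough sketch of the float16 cache size for kv_cache_length=2048, using the past_key_values layout from export/onnx2om.py; the per_head_dim derivation is an assumption, not taken from this commit:
```python
# Rough sketch: estimate the float16 kv-cache size for kv_cache_length=2048,
# using the past_key_values layout from export/onnx2om.py:
# [batch, kv_cache_length, num_hidden_layers * 2 * num_key_value_heads, per_head_dim].
from transformers.models.qwen2 import Qwen2Config

cfg = Qwen2Config.from_pretrained("./download/Qwen2-1.5B-Instruct")
per_head_dim = cfg.hidden_size // cfg.num_attention_heads  # assumed derivation
kv_cache_length = 2048
n_elems = 1 * kv_cache_length * cfg.num_hidden_layers * 2 * cfg.num_key_value_heads * per_head_dim
print(f"kv-cache: {n_elems * 2 / 1024 ** 2:.1f} MiB in float16")  # 2 bytes per fp16 element
```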
3 changes: 3 additions & 0 deletions export/change_node.py
@@ -17,6 +17,9 @@
new_onnx_dir = os.path.join(output_dir, "onnx2")
if not os.path.exists(new_onnx_dir):
os.mkdir(new_onnx_dir)
else:
for file in os.listdir(new_onnx_dir):
os.remove(os.path.join(new_onnx_dir, file))

now_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.dirname(now_dir)
4 changes: 2 additions & 2 deletions export/export_onnx.py
@@ -64,7 +64,7 @@ def parser_arguments():
"--kv_cache_length",
help="kv-cache length",
type=int,
default=1024,
default=2048,
)
return parser.parse_args()

@@ -123,7 +123,7 @@ def export_onnx(
output_names = ["logits", "out_key_values"]
dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_length"},
"attention_mask": {0: "batch_size", 1: "seq_length+kv_len"},
"attention_mask": {0: "batch_size", 1: "seq_length + kv_len"},
"position_ids": {0: "batch_size", 1: "seq_length"},
"past_key_values": {0: "batch_size", 1: "kv_len"},
}
43 changes: 31 additions & 12 deletions export/onnx2om.py
@@ -48,20 +48,26 @@
type=int,
default=1,
)
parser.add_argument(
"--cpu_thread" ,
type=int,
help="num of cpu thread when convert onnx to om",
default=1,
)
parser.add_argument(
"--max_prefill_length",
help="max prefill length in first inference. "
"Attention max_prefill_length + max_output_length <= kv_cache_length. "
"the number must by 2^xx, like 1, 2, 4, 8, 16, 32, 64, 128, 256... "
"Note! The higher this number, the longer it will take to compile.",
type=int,
default=16
default=8
)
parser.add_argument(
"--kv_cache_length",
help="kv-cache length",
type=int,
default=1024,
default=2048,
)


@@ -114,10 +120,10 @@ def get_soc_version():
assert (max_prefill_length < kv_cache_length), \
print("max_input_length max be smaller than kv_cache_length, because max_input_length + max_output_length <= kv_cache")
input_ids_length_range = prefill_length_range
attention_length_range = [
length + kv_cache_length
for length in prefill_length_range
]
# attention_length_range = [
# length + kv_cache_length
# for length in prefill_length_range
# ]
position_length_range = prefill_length_range
input_ids_shape = [
f"1~{max_batch}" if max_batch > 1 else "1",
Expand All @@ -132,14 +138,21 @@ def get_soc_version():
"-1" if max_prefill_length > 1 else "1"
]
dynamic_dims = []
for dynamic_dim in zip(
input_ids_length_range, attention_length_range, position_length_range
):
dynamic_dim = [str(dim) for dim in dynamic_dim]
dynamic_dims.append(",".join(dynamic_dim))
for dynamic_kv_cache_length in [
kv_cache_length // 2,
kv_cache_length
]:
for dynamic_dim in zip(input_ids_length_range, position_length_range):
new_dynamic_dim = [
str(dynamic_dim[0]), # input_ids
str(dynamic_dim[0] + dynamic_kv_cache_length), # attention_mask_shape
str(dynamic_dim[1]), # position_ids
str(dynamic_kv_cache_length), # past_key_values
]
dynamic_dims.append(",".join(new_dynamic_dim))
past_key_values_shape = [
f"1~{max_batch}" if max_batch > 1 else "1",
kv_cache_length,
"-1" if max_prefill_length > 1 else kv_cache_length,
num_hidden_layers * 2 * num_key_value_heads,
per_head_dim
]
@@ -152,8 +165,14 @@ def get_soc_version():
else:
soc_version = args.soc_version
command_lines = [
# reduce memory usage
"export MS_DEV_FORCE_ACL=1 && ",
"export MS_ENABLE_GE=1 && ",
"export TE_PARALLEL_COMPILER={} &&".format(args.cpu_thread),
"export MAX_COMPILE_CORE_NUMBER={} &&".format(args.cpu_thread),
"atc",
"--framework=5",
"--host_env_cpu=aarch64",
'--model="{}"'.format(args.onnx_model_path),
'--output="{}"'.format(args.om_model_path),
"--soc_version={}".format(soc_version),
14 changes: 12 additions & 2 deletions export/test_onnx_run.py
@@ -1,4 +1,5 @@
import os
import time
import numpy as np
import onnxruntime
import argparse
@@ -28,6 +29,12 @@
type=str,
default=os.path.join(project_dir, "output", "onnx", "qwen2_1.5b_chat.onnx")
)
parser.add_argument(
"--kv_cache_length",
help="kv-cache length",
type=int,
default=2048,
)
args = parser.parse_args()

if args.dtype == "float16":
@@ -38,7 +45,7 @@
raise Exception("not support dtype, only support float16/float32")


def create_kv_cache(config: Qwen2Config, kv_cache_length=1024):
def create_kv_cache(config: Qwen2Config, kv_cache_length=args.kv_cache_length):
return np.zeros(
[
1,
@@ -50,7 +57,7 @@ def create_kv_cache(config: Qwen2Config, kv_cache_length=1024):
)


def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size: int = 1024):
def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size: int = args.kv_cache_length):
"""
获取指定长度的kv_cache, 顺便生成mask和position_id
Args:
@@ -122,12 +129,15 @@ def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size
print("now_kv_cache shape: ", now_kv_cache.shape)
print("attention_mask shape: ", attn_mask.shape)
print("position_ids shape: ", position_ids.shape)
st = time.time()
outputs = llm_session.run(None, {
"input_ids": input_ids[:, :1],
"attention_mask": attn_mask,
"position_ids": position_ids,
"past_key_values": now_kv_cache,
})
et = time.time()
print("duration: ", et - st)
print("==== onnx runtime ====")
print("output length: ", len(outputs))
logits = outputs[0]