Support dynamic_dims during model prefill
Tlntin committed Jul 30, 2024
1 parent d8b2da8 commit c5ab062
Showing 8 changed files with 255 additions and 50 deletions.
18 changes: 12 additions & 6 deletions README.md
@@ -41,32 +41,38 @@
--output_model_path="./output/onnx2/qwen2_1.5b_chat.onnx"
```

4. Convert the ONNX model to an OM model: use the atc command to convert the modified ONNX to OM. **Note that om_model_path here does not include the `.om` suffix.** The run may print some warnings or subgraph-fusion errors; as long as it finally reports `success`, everything is fine. kv_cache_length must stay consistent with the length used when exporting the ONNX in step 1.
4. Convert the ONNX model to an OM model: use the atc command to convert the modified ONNX to OM. **Note that om_model_path here does not include the `.om` suffix.** The run may print some warnings or subgraph-fusion errors; as long as it finally reports `success`, everything is fine. kv_cache_length must stay consistent with the length used when exporting the ONNX in step 1. `--max_prefill_length` is the maximum number of tokens processed in a single pass during the prefill stage: the larger it is, the lower the first-token latency, but the longer the ONNX-to-OM conversion takes. It should normally be a power of 2, e.g. 2, 4, 8, 16, and so on; at inference time a suitable prefill length is matched recursively, e.g. an input of 12 is matched as [8, 4] (a sketch of this matching follows the command below). The current default is 8; setting it to 1 disables dynamic-shape inference.
```bash
python3 export/onnx2om.py \
--hf_model_dir="./download/Qwen2-1.5B-Instruct" \
--onnx_model_path="./output/onnx2/qwen2_1.5b_chat.onnx" \
--om_model_path="./output/model/qwen2_1.5b_chat" \
--kv_cache_length=1024
--kv_cache_length=1024 \
--max_prefill_length=8
```
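
The recursive matching of prefill lengths described above can be pictured as a greedy decomposition into power-of-two chunks. This is only an illustrative sketch; the helper name `split_prefill_lengths` is not part of the repository:
```python
def split_prefill_lengths(input_length: int, max_prefill_length: int) -> list:
    """Greedily split an input length into power-of-two prefill chunks,
    e.g. input_length=12, max_prefill_length=8 -> [8, 4]."""
    chunks = []
    chunk = max_prefill_length
    remaining = input_length
    while remaining > 0 and chunk >= 1:
        if remaining >= chunk:
            chunks.append(chunk)
            remaining -= chunk
        else:
            chunk //= 2
    return chunks

print(split_prefill_lengths(12, 8))  # [8, 4]
print(split_prefill_lengths(13, 8))  # [8, 4, 1]
```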


##### Step 2: Run the model
- Run the model directly with the command below
- Run the model directly with the command below; `--max_prefill_length` must match the value used during compilation above.
```bash
python3 ./cli_chat.py \
--hf_model_dir="./download/Qwen2-1.5B-Instruct" \
--om_model_path="./output/model/qwen2_1.5b_chat.om"
--om_model_path="./output/model/qwen2_1.5b_chat.om" \
--max_prefill_length=8
```

- Demo (demo model: qwen1.5-0.5b-chat)
- Demo 1 (demo model: qwen1.5-0.5b-chat, dynamic-shape inference not enabled)
![](./image/qwen1.5_0.5b_chat.gif)

- Demo 2 (demo model: qwen2-1.5b-instruct, dynamic-shape inference enabled, max_prefill_length=8)
![](./image/qwen2-1.5b-instruct.gif)


### Current features
- [x] Export ONNX and OM models
- [x] Model inference with ONNX (CPU only).
- [x] Model inference with ACL.
- [x] Model inference with CANN.
- [x] Dynamic-shape inference during CANN inference to reduce first-token latency.
- [x] Streaming output
- [ ] Build an OpenAI-compatible API
- [ ] Support function calling
20 changes: 20 additions & 0 deletions cli_chat.py
@@ -1,4 +1,5 @@
import sys
import math
import argparse
from concurrent.futures import ThreadPoolExecutor
from config import InferenceConfig
@@ -32,12 +33,27 @@
type=str,
default= os.path.join(project_dir, "output", "model", "qwen2_1.5b_chat.om")
)
parser.add_argument(
"--max_batch",
help="max batch",
type=int,
default=1,
)
parser.add_argument(
"--max_input_length",
help="max input length",
type=int,
default=512,
)
parser.add_argument(
"--max_prefill_length",
help="max prefill length in first inference. "
"Attention max_prefill_length + max_output_length <= kv_cache_length. "
"the number must by 2^xx, like 1, 2, 4, 8, 16, 32, 64, 128, 256... "
"Note! The higher this number, the longer it will take to compile.",
type=int,
default=8,
)

parser.add_argument(
"--max_output_length",
@@ -47,14 +63,18 @@
)

args = parser.parse_args()
max_prefill_log2 = int(math.log2(args.max_prefill_length))
max_prefill_length = 2 ** max_prefill_log2
config = InferenceConfig(
hf_model_dir=args.hf_model_dir,
om_model_path=args.om_model_path,
onnx_model_path=args.onnx_model_path,
session_type=args.session_type,
max_batch=args.max_batch,
max_output_length=args.max_output_length,
max_input_length=args.max_input_length,
kv_cache_length=args.max_output_length,
max_prefill_length=max_prefill_length,
)
infer_engine=Inference(config)

7 changes: 6 additions & 1 deletion config.py
@@ -13,8 +13,10 @@ def __init__(
sampling_method: str = "top_k",
sampling_value: float = 10,
temperature: float = 0.7,
max_batch: int = 1,
max_input_length: int = 512,  # maximum input length
max_output_length: int = 1024,  # maximum output length
max_prefill_length: int = 1,  # maximum length of a single inference pass in the prefill stage
kvcache_method: str = "fixsize",  # kv_cache type; supports basic, fixsize, streamllm and H2O, see kvcache.py for details
kv_cache_length: int = 1024,  # maximum kv_cache length
cache_format: str = 'huggingface-tensor',  # kv_cache format
@@ -32,6 +34,7 @@ def __init__(
self.sampling_method = sampling_method
self.sampling_value = sampling_value
self.temperature = temperature
self.max_batch = max_batch
self.max_input_length = max_input_length
self.max_output_length = max_output_length
self.kvcache_method = kvcache_method
@@ -47,8 +50,10 @@ def __init__(
self.past_key_value_shape = (
self.num_hidden_layers,
2,
1,
self.max_batch,
self.num_key_value_heads,
self.kv_cache_length,
self.per_head_dim
)
self.max_prefill_length = max_prefill_length
self.vocab_size = self.model_config.vocab_size
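
For reference, with the defaults above (`max_batch=1`, `kv_cache_length=1024`) and the published Qwen2-1.5B-Instruct configuration (28 layers, 2 KV heads, head dim 128; these model values are assumptions here, check the model's config.json), the cache tensor shape works out as sketched below:
```python
# Assumed Qwen2-1.5B-Instruct values; verify against the model's config.json.
num_hidden_layers = 28
num_key_value_heads = 2
per_head_dim = 128            # hidden_size 1536 // num_attention_heads 12
max_batch = 1
kv_cache_length = 1024

past_key_value_shape = (
    num_hidden_layers,        # one entry per transformer layer
    2,                        # key and value
    max_batch,                # batch dim (hard-coded to 1 before this commit)
    num_key_value_heads,
    kv_cache_length,
    per_head_dim,
)
print(past_key_value_shape)   # (28, 2, 1, 2, 1024, 128)
```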
68 changes: 58 additions & 10 deletions export/onnx2om.py
@@ -1,6 +1,7 @@
import os
import subprocess
import argparse
import math
from transformers.models.qwen2 import Qwen2Config

now_dir = os.path.dirname(os.path.abspath(__file__))
@@ -34,6 +35,21 @@
type=str,
default= os.path.join(model_dir, "qwen2_1.5b_chat")
)
parser.add_argument(
"--max_batch",
help="max batch",
type=int,
default=1,
)
parser.add_argument(
"--max_prefill_length",
help="max prefill length in first inference. "
"Attention max_prefill_length + max_output_length <= kv_cache_length. "
"the number must by 2^xx, like 1, 2, 4, 8, 16, 32, 64, 128, 256... "
"Note! The higher this number, the longer it will take to compile.",
type=int,
default=8,
)
parser.add_argument(
"--kv_cache_length",
help="kv-cache length",
@@ -68,27 +84,52 @@ def get_soc_version():
print("SoC Version is ", soc_version)
return soc_version


max_batch = args.max_batch
model_config = Qwen2Config.from_pretrained(args.hf_model_dir)
num_hidden_layers = model_config.num_hidden_layers
num_key_value_heads = model_config.num_key_value_heads
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
per_head_dim = hidden_size // num_attention_heads
kv_cache_length = args.kv_cache_length
batch_size = 1
seq_len = 1
all_len = seq_len + kv_cache_length
attention_mask_shape = [batch_size, all_len]
max_prefill_log2 = int(math.log2(args.max_prefill_length))
max_prefill_length = 2 ** max_prefill_log2
prefill_length_range = list(range(0, max_prefill_log2 + 1))
prefill_length_range = [2 ** idx for idx in prefill_length_range]
assert (max_prefill_length < kv_cache_length), \
    "max_prefill_length must be smaller than kv_cache_length, because max_prefill_length + max_output_length <= kv_cache_length"
input_ids_length_range = prefill_length_range
attention_length_range = [
length + kv_cache_length
for length in prefill_length_range
]
position_length_range = prefill_length_range
input_ids_shape = [
f"1~{max_batch}" if max_batch > 1 else "1",
"-1" if max_prefill_length > 1 else "1",
]
attention_mask_shape = [
f"1~{max_batch}" if max_batch > 1 else "1",
"-1" if max_prefill_length > 1 else str(1 + kv_cache_length)
]
position_ids_shape = [
f"1~{max_batch}" if max_batch > 1 else "1",
"-1" if max_prefill_length > 1 else "1"
]
dynamic_dims = []
for dynamic_dim in zip(
input_ids_length_range, attention_length_range, position_length_range
):
dynamic_dim = [str(dim) for dim in dynamic_dim]
dynamic_dims.append(",".join(dynamic_dim))
past_key_values_shape = [
num_hidden_layers,
2,
1,
f"1~{max_batch}" if max_batch > 1 else "1",
num_key_value_heads,
kv_cache_length,
per_head_dim
]
attention_mask_shape = [str(x) for x in attention_mask_shape]
past_key_values_shape = [str(x) for x in past_key_values_shape]

command_lines = [
@@ -99,10 +140,17 @@ def get_soc_version():
"--soc_version=Ascend{}".format(get_soc_version()),
"--precision_mode=must_keep_origin_dtype",
"--input_format=ND",
'--input_shape="input_ids:1,1;attention_mask:{};position_ids:1,1;past_key_values:{}"'.format(
",".join(attention_mask_shape), ",".join(past_key_values_shape)
)
'--input_shape="input_ids:{};attention_mask:{};position_ids:{};past_key_values:{}"'.format(
",".join(input_ids_shape),
",".join(attention_mask_shape),
",".join(position_ids_shape),
",".join(past_key_values_shape)
),
]
if max_prefill_length > 1:
command_lines.append(
"--dynamic_dims \"{}\"".format(";".join(dynamic_dims))
)
print("============ run command ==============")
print(" ".join(command_lines))
print("=======================================")
Binary file added image/qwen2-1.5b-instruct.gif
100 changes: 90 additions & 10 deletions utils/engine.py
@@ -77,6 +77,8 @@ def __init__(self, config: InferenceConfig, context=None,callback=None):
self.callback_interval = 1
self.exit_flag = False
self.kv_cache = None
self.max_batch = config.max_batch
self.kv_cache_length = config.kv_cache_length
self.input_dataset, self.output_dataset = None, None
self.inputs:List[Dict[str,]] = []
self.outputs:List[Dict[str,]] = []
@@ -189,47 +191,125 @@ def free_memory(self):
ret = acl.rt.free_host(item["buffer_host"])
ret = acl.mdl.destroy_dataset(self.output_dataset)

def inference(self,data) -> List[np.ndarray]:
def inference(self, input_data_list: List[np.ndarray], seq_length=1, is_dynamic=False) -> List[np.ndarray]:
"""
Run inference synchronously.
Args:
data (_type_): _description_
input_data_list (_type_): _description_
seq_length: sequence length for this inference step
is_dynamic: whether to run the dynamic-shape (dynamic_dims) path
Returns:
List[np.ndarray]: _description_
"""
start = time.time()
acl.rt.set_context(self.context)
for i in range(len(data)):
for i in range(len(input_data_list)):
if i == 3:
pass
continue
else:
bytes_data = data[i].tobytes()
input_data = input_data_list[i]
input_size = input_data.size
input_itemsize = input_data.itemsize
bytes_data = input_data.tobytes()
np_ptr = acl.util.bytes_to_ptr(bytes_data)
if is_dynamic:
input_copy_size = input_size * input_itemsize
else:
input_copy_size = self.inputs[i]["size"]
ret = acl.rt.memcpy(
self.inputs[i]["buffer"],
self.inputs[i]["size"],
np_ptr,
self.inputs[i]["size"],
input_copy_size,
ACL_MEMCPY_HOST_TO_DEVICE
)
check_ret("memcpy", ret)
check_ret("memcpy input", ret)
output_sizes = []
if is_dynamic:
# link https://www.hiascend.com/doc_center/source/zh/canncommercial/80RC1/apiref/appdevgapi/aclpythondevg_01_0159.html
index, ret = acl.mdl.get_input_index_by_name(
self.model_desc, "ascend_mbatch_shape_data"
)
check_ret("get_input_index_by_name", ret)
dynamic_dims = [
# input_ids
self.max_batch,
seq_length,
# attention_mask
self.max_batch,
seq_length + self.kv_cache_length,
# position_ids
self.max_batch,
seq_length
]
dynamic_dims += self.config.past_key_value_shape
# will set dynamic input shape
ret = acl.mdl.set_input_dynamic_dims(
self.model_id,
self.input_dataset,
index,
{
'dimCount': len(dynamic_dims),
'name': '',
'dims': dynamic_dims
}
)
check_ret("set_iniput_dynamic_dims", ret)
output_itemsize1 = np.dtype(self.outputs[0]["dtype"]).itemsize
output_itemsize2 = np.dtype(self.outputs[1]["dtype"]).itemsize
logits_size = self.max_batch * seq_length * self.config.vocab_size
logits_itemsize = logits_size * output_itemsize1
new_kv_cache_size = (
self.config.num_hidden_layers \
* 2 \
* self.max_batch \
* self.config.num_key_value_heads \
* seq_length \
* self.config.per_head_dim \
)
new_kv_cache_itemsize = new_kv_cache_size * output_itemsize2
output_sizes = [logits_size, new_kv_cache_size]
output_itemsizes = [logits_itemsize, new_kv_cache_itemsize]
logits_shape = [self.max_batch, seq_length, self.config.vocab_size]
new_kv_cache_shape = [
self.config.num_hidden_layers,
2,
self.max_batch,
self.config.num_key_value_heads,
seq_length,
self.config.per_head_dim
]
output_shapes = [logits_shape, new_kv_cache_shape]

ret = acl.mdl.execute(
self.model_id,
self.input_dataset,
self.output_dataset
)
check_ret("model_execute", ret)
inference_result = []
for out in self.outputs:

for output_idx, out in enumerate(self.outputs):
if is_dynamic:
output_itemsize = output_itemsizes[output_idx]
output_size = output_sizes[output_idx]
else:
output_itemsize = out["size"]
output_size = output_itemsize // np.dtype(out["dtype"]).itemsize
ret = acl.rt.memcpy(
out['buffer_host'],
out["size"],
out["buffer"],
out["size"],
output_itemsize,
ACL_MEMCPY_DEVICE_TO_HOST
)
check_ret("memcpy output", ret)
bytes_out = acl.util.ptr_to_bytes(out['buffer_host'], out["size"])
out_data = np.frombuffer(bytes_out, dtype=out['dtype'])
out_data = np.frombuffer(
bytes_out,
dtype=out['dtype'],
count=output_size,
).reshape(output_shapes[output_idx])
inference_result.append(out_data)
return inference_result
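
A minimal, hypothetical sketch of calling the dynamic path above; the caller, dtypes, and the `engine`/`config` objects are assumptions for illustration, not the repository's actual calling code:
```python
import numpy as np

# Assumed setup: `engine` is an initialized instance of the class above and
# `config` is its InferenceConfig; dtypes are illustrative.
seq_length = 8  # one prefill chunk of 8 tokens
input_ids = np.zeros((1, seq_length), dtype=np.int64)
attention_mask = np.ones((1, seq_length + config.kv_cache_length), dtype=np.int64)
position_ids = np.arange(seq_length, dtype=np.int64).reshape(1, -1)
past_key_values = np.zeros(config.past_key_value_shape, dtype=np.float16)

logits, new_kv_cache = engine.inference(
    [input_ids, attention_mask, position_ids, past_key_values],
    seq_length=seq_length,
    is_dynamic=True,
)
# logits.shape       -> (max_batch, seq_length, vocab_size)
# new_kv_cache.shape -> (num_hidden_layers, 2, max_batch, num_key_value_heads,
#                        seq_length, per_head_dim)
```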

7 changes: 6 additions & 1 deletion utils/inference.py
@@ -145,7 +145,12 @@ def stream_predict(
max_output_len = self.max_output_length - input_length
for i in range(max_output_len):
logits = self.session.run(input_ids)[0]
input_ids = self.sample_logits(logits[0][-1:], self.sampling_method, self.sampling_value, self.temperature)
input_ids = self.sample_logits(
logits[0][-1:],
self.sampling_method,
self.sampling_value,
self.temperature
)
input_ids = input_ids.reshape(1, -1)
if i == 0:
first_token_latency = time.time() - start