diff --git a/README.md b/README.md
index c69442a..ffa67c7 100644
--- a/README.md
+++ b/README.md
@@ -41,32 +41,38 @@
     --output_model_path="./output/onnx2/qwen2_1.5b_chat.onnx"
 ```
 
-4. 转onnx为om模型, 将修改后的onnx利用atc命令导出到onnx,**注意此处的om_model_path不带`.om`后缀**。运行过程可能会有一些警告,或者子图融合报错,只要结果是提示`success`就说明没啥问题。kv_cache_length长度和第一步导出onnx时的长度保持一致。
+4. 转onnx为om模型, 将修改后的onnx利用atc命令导出为om模型,**注意此处的om_model_path不带`.om`后缀**。运行过程可能会有一些警告,或者子图融合报错,只要最终提示`success`就说明转换成功。kv_cache_length长度和第一步导出onnx时的长度保持一致。`--max_prefill_length`为prefill阶段单次能处理的最大长度,该数值越大越能降低首字延迟,但相应的onnx转om的时间也会变长。设置该数值时,一般为2的指数,例如2、4、8、16等等,推理时会利用递归自动匹配合适的prefill长度,例如输入12,会匹配[8, 4]。当前默认数值为8,如果设置为1,则不会开启动态shape推理功能。
 
 ```bash
 python3 export/onnx2om.py \
     --hf_model_dir="./download/Qwen2-1.5B-Instruct" \
     --onnx_model_path="./output/onnx2/qwen2_1.5b_chat.onnx" \
     --om_model_path="./output/model/qwen2_1.5b_chat" \
-    --kv_cache_length=1024
+    --kv_cache_length=1024 \
+    --max_prefill_length=8
 ```
 
 ##### 步骤2:运行模型
-- 使用下面的命令直接运行模型
+- 使用下面的命令直接运行模型,`--max_prefill_length`需要和上面编译的时候使用的数值相同。
 
 ```bash
 python3 ./cli_chat.py \
     --hf_model_dir="./download/Qwen2-1.5B-Instruct" \
-    --om_model_path="./output/model/qwen2_1.5b_chat.om"
+    --om_model_path="./output/model/qwen2_1.5b_chat.om" \
+    --max_prefill_length=8
 ```
 
-- demo展示(演示模型,qwen1.5-0.5b-chat)
+- demo展示1(演示模型,qwen1.5-0.5b-chat,未开启动态shape推理)
 ![](./image/qwen1.5_0.5b_chat.gif)
 
+- demo展示2(演示模型,qwen2-1.5b-instruct,开启动态shape推理, max_prefill_length=8)
+![](./image/qwen2-1.5b-instruct.gif)
+
 ### 当前功能
 - [x] 导出onnx, om模型
 - [x] 模型推理,支持onnx推理(仅支持CPU)。
-- [x] 模型推理,支持acl推理。
+- [x] 模型推理,支持CANN推理。
+- [x] CANN推理时使用动态shape推理以降低首字延迟。
 - [x] 流式传输
 - [ ] 兼容OpenAI的api搭建
 - [ ] 支持functional call
diff --git a/cli_chat.py b/cli_chat.py
index 986e335..d6809fc 100644
--- a/cli_chat.py
+++ b/cli_chat.py
@@ -1,4 +1,5 @@
 import sys
+import math
 import argparse
 from concurrent.futures import ThreadPoolExecutor
 from config import InferenceConfig
@@ -32,12 +33,27 @@
     type=str,
     default= os.path.join(project_dir, "output", "model", "qwen2_1.5b_chat.om")
 )
+parser.add_argument(
+    "--max_batch",
+    help="max batch",
+    type=int,
+    default=1,
+)
 parser.add_argument(
     "--max_input_length",
     help="max input length",
    type=int,
     default=512,
 )
+parser.add_argument(
+    "--max_prefill_length",
+    help="max prefill length in the first inference. "
+         "Note that max_prefill_length + max_output_length <= kv_cache_length. "
+         "The value must be a power of 2, e.g. 1, 2, 4, 8, 16, 32, 64, 128, 256. "
+         "The higher this number, the longer the om compilation will take.",
+    type=int,
+    default=8,
+)
 parser.add_argument(
     "--max_output_length",
@@ -47,14 +63,18 @@
 )
 
 args = parser.parse_args()
+max_prefill_log2 = int(math.log2(args.max_prefill_length))
+max_prefill_length = 2 ** max_prefill_log2
 config = InferenceConfig(
     hf_model_dir=args.hf_model_dir,
     om_model_path=args.om_model_path,
     onnx_model_path=args.onnx_model_path,
     session_type=args.session_type,
+    max_batch=args.max_batch,
     max_output_length=args.max_output_length,
     max_input_length=args.max_input_length,
     kv_cache_length=args.max_output_length,
+    max_prefill_length=max_prefill_length,
 )
 infer_engine=Inference(config)
 
diff --git a/config.py b/config.py
index 5181cbd..4e9348e 100644
--- a/config.py
+++ b/config.py
@@ -13,8 +13,10 @@ def __init__(
         sampling_method: str = "top_k",
         sampling_value: float = 10,
         temperature: float = 0.7,
+        max_batch: int = 1,
         max_input_length: int = 512, # 输入长度的最大数值
         max_output_length: int = 1024, # 输出长度的最大值
+        max_prefill_length: int = 1, # prefill阶段,单次最大推理长度
         kvcache_method: str = "fixsize", # kv_cache类型,支持basic,fixsize,streamllm,H2O四种,具体可以去kvcache.py查看
         kv_cache_length: int = 1024, # kvcache的最大长度
         cache_format: str = 'huggingface-tensor', # kv_cache的格式
@@ -32,6 +34,7 @@ def __init__(
         self.sampling_method = sampling_method
         self.sampling_value = sampling_value
         self.temperature = temperature
+        self.max_batch = max_batch
         self.max_input_length = max_input_length
         self.max_output_length = max_output_length
         self.kvcache_method = kvcache_method
@@ -47,8 +50,10 @@ def __init__(
         self.past_key_value_shape = (
             self.num_hidden_layers,
             2,
-            1,
+            self.max_batch,
             self.num_key_value_heads,
             self.kv_cache_length,
             self.per_head_dim
         )
+        self.max_prefill_length = max_prefill_length
+        self.vocab_size = self.model_config.vocab_size
diff --git a/export/onnx2om.py b/export/onnx2om.py
index 84dd96c..cd388fe 100644
--- a/export/onnx2om.py
+++ b/export/onnx2om.py
@@ -1,6 +1,7 @@
 import os
 import subprocess
 import argparse
+import math
 from transformers.models.qwen2 import Qwen2Config
 
 now_dir = os.path.dirname(os.path.abspath(__file__))
@@ -34,6 +35,21 @@
     type=str,
     default= os.path.join(model_dir, "qwen2_1.5b_chat")
 )
+parser.add_argument(
+    "--max_batch",
+    help="max batch",
+    type=int,
+    default=1,
+)
+parser.add_argument(
+    "--max_prefill_length",
+    help="max prefill length in the first inference. "
+         "Note that max_prefill_length + max_output_length <= kv_cache_length. "
+         "The value must be a power of 2, e.g. 1, 2, 4, 8, 16, 32, 64, 128, 256. "
+         "The higher this number, the longer the om compilation will take.",
+    type=int,
+    default=8,
+)
 parser.add_argument(
     "--kv_cache_length",
     help="kv-cache length",
@@ -68,7 +84,7 @@ def get_soc_version():
     print("SoC Version is ", soc_version)
     return soc_version
 
-
+max_batch = args.max_batch
 model_config = Qwen2Config.from_pretrained(args.hf_model_dir)
 num_hidden_layers = model_config.num_hidden_layers
 num_key_value_heads = model_config.num_key_value_heads
@@ -76,19 +92,44 @@
 num_attention_heads = model_config.num_attention_heads
 per_head_dim = hidden_size // num_attention_heads
 kv_cache_length = args.kv_cache_length
-batch_size = 1
-seq_len = 1
-all_len = seq_len + kv_cache_length
-attention_mask_shape = [batch_size, all_len]
+max_prefill_log2 = int(math.log2(args.max_prefill_length))
+max_prefill_length = 2 ** max_prefill_log2
+prefill_length_range = list(range(0, max_prefill_log2 + 1))
+prefill_length_range = [2 ** idx for idx in prefill_length_range]
+assert (max_prefill_length < kv_cache_length), \
+    "max_prefill_length must be smaller than kv_cache_length, because max_prefill_length + max_output_length <= kv_cache_length"
+input_ids_length_range = prefill_length_range
+attention_length_range = [
+    length + kv_cache_length
+    for length in prefill_length_range
+]
+position_length_range = prefill_length_range
+input_ids_shape = [
+    f"1~{max_batch}" if max_batch > 1 else "1",
+    "-1" if max_prefill_length > 1 else "1",
+]
+attention_mask_shape = [
+    f"1~{max_batch}" if max_batch > 1 else "1",
+    "-1" if max_prefill_length > 1 else str(1 + kv_cache_length)
+]
+position_ids_shape = [
+    f"1~{max_batch}" if max_batch > 1 else "1",
+    "-1" if max_prefill_length > 1 else "1"
+]
+dynamic_dims = []
+for dynamic_dim in zip(
+    input_ids_length_range, attention_length_range, position_length_range
+):
+    dynamic_dim = [str(dim) for dim in dynamic_dim]
+    dynamic_dims.append(",".join(dynamic_dim))
 past_key_values_shape = [
     num_hidden_layers,
     2,
-    1,
+    f"1~{max_batch}" if max_batch > 1 else "1",
     num_key_value_heads,
     kv_cache_length,
     per_head_dim
 ]
-attention_mask_shape = [str(x) for x in attention_mask_shape]
 past_key_values_shape = [str(x) for x in past_key_values_shape]
 
 command_lines = [
@@ -99,10 +140,17 @@
     "--soc_version=Ascend{}".format(get_soc_version()),
     "--precision_mode=must_keep_origin_dtype",
     "--input_format=ND",
-    '--input_shape="input_ids:1,1;attention_mask:{};position_ids:1,1;past_key_values:{}"'.format(
-        ",".join(attention_mask_shape), ",".join(past_key_values_shape)
-    )
+    '--input_shape="input_ids:{};attention_mask:{};position_ids:{};past_key_values:{}"'.format(
+        ",".join(input_ids_shape),
+        ",".join(attention_mask_shape),
+        ",".join(position_ids_shape),
+        ",".join(past_key_values_shape)
+    ),
 ]
+if max_prefill_length > 1:
+    command_lines.append(
+        "--dynamic_dims \"{}\"".format(";".join(dynamic_dims))
+    )
 print("============ run command ==============")
 print(" ".join(command_lines))
 print("=======================================")
diff --git a/image/qwen2-1.5b-instruct.gif b/image/qwen2-1.5b-instruct.gif
new file mode 100644
index 0000000..d450126
Binary files /dev/null and b/image/qwen2-1.5b-instruct.gif differ
diff --git a/utils/engine.py b/utils/engine.py
index 2db22f6..4c62b1f 100644
--- a/utils/engine.py
+++ b/utils/engine.py
@@ -77,6 +77,8 @@ def __init__(self, config: InferenceConfig, context=None,callback=None):
         self.callback_interval = 1
         self.exit_flag = False
         self.kv_cache = None
+        self.max_batch = config.max_batch
+        self.kv_cache_length = config.kv_cache_length
         self.input_dataset, self.output_dataset = None, None
         self.inputs:List[Dict[str,]] = []
         self.outputs:List[Dict[str,]] = []
@@ -189,47 +191,125 @@ def free_memory(self):
             ret = acl.rt.free_host(item["buffer_host"])
         ret = acl.mdl.destroy_dataset(self.output_dataset)
 
-    def inference(self,data) -> List[np.ndarray]:
+    def inference(self, input_data_list: List[np.ndarray], seq_length=1, is_dynamic=False) -> List[np.ndarray]:
         """
         执行推理,同步方式
         Args:
-            data (_type_): _description_
+            input_data_list (_type_): _description_
+            seq_length: 推理长度
 
         Returns:
             List[np.ndarray]: _description_
         """
         start = time.time()
         acl.rt.set_context(self.context)
-        for i in range(len(data)):
+        for i in range(len(input_data_list)):
             if i == 3:
-                pass
+                continue
             else:
-                bytes_data = data[i].tobytes()
+                input_data = input_data_list[i]
+                input_size = input_data.size
+                input_itemsize = input_data.itemsize
+                bytes_data = input_data.tobytes()
                 np_ptr = acl.util.bytes_to_ptr(bytes_data)
+                if is_dynamic:
+                    input_copy_size = input_size * input_itemsize
+                else:
+                    input_copy_size = self.inputs[i]["size"]
                 ret = acl.rt.memcpy(
                     self.inputs[i]["buffer"],
                     self.inputs[i]["size"],
                     np_ptr,
-                    self.inputs[i]["size"],
+                    input_copy_size,
                     ACL_MEMCPY_HOST_TO_DEVICE
                 )
-            check_ret("memcpy", ret)
+                check_ret("memcpy input", ret)
+        output_sizes = []
+        if is_dynamic:
+            # link https://www.hiascend.com/doc_center/source/zh/canncommercial/80RC1/apiref/appdevgapi/aclpythondevg_01_0159.html
+            index, ret = acl.mdl.get_input_index_by_name(
+                self.model_desc, "ascend_mbatch_shape_data"
+            )
+            check_ret("get_input_index_by_name", ret)
+            dynamic_dims = [
+                # input_ids
+                self.max_batch,
+                seq_length,
+                # attention_mask
+                self.max_batch,
+                seq_length + self.kv_cache_length,
+                # position_ids
+                self.max_batch,
+                seq_length
+            ]
+            dynamic_dims += self.config.past_key_value_shape
+            # set the actual input dims for this dynamic-shape execution
+            ret = acl.mdl.set_input_dynamic_dims(
+                self.model_id,
+                self.input_dataset,
+                index,
+                {
+                    'dimCount': len(dynamic_dims),
+                    'name': '',
+                    'dims': dynamic_dims
+                }
+            )
+            check_ret("set_input_dynamic_dims", ret)
+            output_itemsize1 = np.dtype(self.outputs[0]["dtype"]).itemsize
+            output_itemsize2 = np.dtype(self.outputs[1]["dtype"]).itemsize
+            logits_size = self.max_batch * seq_length * self.config.vocab_size
+            logits_itemsize = logits_size * output_itemsize1
+            new_kv_cache_size = (
+                self.config.num_hidden_layers \
+                * 2 \
+                * self.max_batch \
+                * self.config.num_key_value_heads \
+                * seq_length \
+                * self.config.per_head_dim \
+            )
+            new_kv_cache_itemsize = new_kv_cache_size * output_itemsize2
+            output_sizes = [logits_size, new_kv_cache_size]
+            output_itemsizes = [logits_itemsize, new_kv_cache_itemsize]
+            logits_shape = [self.max_batch, seq_length, self.config.vocab_size]
+            new_kv_cache_shape = [
+                self.config.num_hidden_layers,
+                2,
+                self.max_batch,
+                self.config.num_key_value_heads,
+                seq_length,
+                self.config.per_head_dim
+            ]
+            output_shapes = [logits_shape, new_kv_cache_shape]
+
         ret = acl.mdl.execute(
             self.model_id,
             self.input_dataset,
             self.output_dataset
         )
+        check_ret("model_execute", ret)
         inference_result = []
-        for out in self.outputs:
+
+        for output_idx, out in enumerate(self.outputs):
+            if is_dynamic:
+                output_itemsize = output_itemsizes[output_idx]
+                output_size = output_sizes[output_idx]
+            else:
+                output_itemsize = out["size"]
+                output_size = output_itemsize // np.dtype(out["dtype"]).itemsize
             ret = acl.rt.memcpy(
                 out['buffer_host'],
                 out["size"],
                 out["buffer"],
-                out["size"],
+                output_itemsize,
                 ACL_MEMCPY_DEVICE_TO_HOST
             )
+            check_ret("memcpy output", ret)
             bytes_out = acl.util.ptr_to_bytes(out['buffer_host'], out["size"])
-            out_data = np.frombuffer(bytes_out, dtype=out['dtype'])
+            out_data = np.frombuffer(
+                bytes_out, dtype=out['dtype'], count=output_size
+            )
+            if is_dynamic:
+                out_data = out_data.reshape(output_shapes[output_idx])
             inference_result.append(out_data)
         return inference_result
 
diff --git a/utils/inference.py b/utils/inference.py
index 7e45732..f2a410a 100644
--- a/utils/inference.py
+++ b/utils/inference.py
@@ -145,7 +145,12 @@ def stream_predict(
         max_output_len = self.max_output_length - input_length
         for i in range(max_output_len):
             logits = self.session.run(input_ids)[0]
-            input_ids = self.sample_logits(logits[0][-1:], self.sampling_method, self.sampling_value, self.temperature)
+            input_ids = self.sample_logits(
+                logits[0][-1:],
+                self.sampling_method,
+                self.sampling_value,
+                self.temperature
+            )
             input_ids = input_ids.reshape(1, -1)
             if i == 0:
                 first_token_latency = time.time() - start
diff --git a/utils/session.py b/utils/session.py
index 436d520..53be187 100644
--- a/utils/session.py
+++ b/utils/session.py
@@ -2,6 +2,7 @@
 from utils.kvcache import create_kv_cache
 import numpy as np
 from typing import List
+import math
 import time
 import sys
 from utils.engine import ACLModel, init_resource, destroy_resource
@@ -112,38 +113,66 @@ def __init__(self, config:InferenceConfig):
         self.device_id = config.device_id
         self.context = init_resource(self.device_id)
         self.model = ACLModel(config, self.context)
+        self.max_batch = config.max_batch
         self.input_ids = np.zeros((1,16),dtype=np.int64)
         self.kv_cache.kv_cache = self.model.kv_cache
+        self.max_prefill_length = config.max_prefill_length
+        self.prefill_log2_number = int(math.log2(self.max_prefill_length))
+        self.prefill_log2_list = list(range(self.prefill_log2_number, -1, -1))
+        self.prefill_log2_list = [2**index for index in self.prefill_log2_list]
+
     def __del__(self):
         destroy_resource(self.device_id, self.context)
+
+    def decompose_number(self, n, start_index=0):
+        """
+        将数字n分解成若干个2的指数的和,并返回这些2的指数构成的列表。
+        参数:
+        n -- 要分解的数字
+        返回:
+        分解后的列表,例如 [8, 4]
+        """
+        if n == 0:
+            return []
+
+        for i in range(start_index, self.prefill_log2_number + 1):
+            power = self.prefill_log2_list[i]
+            if power <= n:
+                return [power] + self.decompose_number(n - power, i)
+        return []
+
     def run(self, input_ids: np.ndarray):
         seq_len = input_ids.shape[-1]
         logits = None
-        for i in range(seq_len):
-            logits = self.run_one(input_ids[:,i])
+        is_dynamic = bool(self.max_prefill_length > 1)
+        # dynamic inference
+        if is_dynamic:
+            seq_list = self.decompose_number(seq_len)
+            start_i = 0
+            for seq in seq_list:
+                end_i = start_i + seq
+                logits = self.run_some(
+                    input_ids[:, start_i: end_i],
+                    seq,
+                    is_dynamic
+                )
+                start_i += seq
+        # static inference
+        else:
+            for i in range(seq_len):
+                logits = self.run_some(input_ids[:,i])
         return [logits]
-    def run_all_logits(self, input_ids: np.ndarray):
-        seq_len, i = input_ids.shape[-1], 0
-        logits = []
-        while i < seq_len:
-            end = i + 16 if i+16 < seq_len else seq_len
-            cache,mask,pos_ids = self.kv_cache.get_inputs(16)
-            self.input_ids[0:end-i] = input_ids[i:end]
-            result:List[np.ndarray] = self.model.inference([self.input_ids, mask, pos_ids, cache])
-            self.kv_cache.update(end-i,result[1])
-            logits.append(result[0][0:end-i].reshape(1,-1))
-        return [np.concatenate(logits).reshape(1,1,-1)]
-
-    def run_one(self, input_ids: np.ndarray):
-        self.run_times += 1
-        cache, mask, pos_ids = self.kv_cache.get_inputs(1)
+    def run_some(self, input_ids: np.ndarray, seq_length: int = 1, is_dynamic: bool = False):
+        self.run_times += seq_length
+        cache, mask, pos_ids = self.kv_cache.get_inputs(seq_length)
         result:List[np.ndarray] = self.model.inference(
-            [input_ids, mask, pos_ids, cache]
+            [input_ids, mask, pos_ids, cache], seq_length, is_dynamic
         )
-        # if self.run_times <= 2:
-        #     print(" == Debug == ")
+        # if self.run_times <= 20:
+        #     print(" === Debug === ")
+        #     print("run times: ", self.run_times)
         #     logits = result[0]
         #     new_kv_cache = result[1]
         #     print("logits shape: ", logits.shape)
@@ -152,5 +181,17 @@
         #     print("new_kv_cache: shape", new_kv_cache.shape)
         #     print("new_kv_cache: mean: ", new_kv_cache.astype(np.float32).mean().item())
         #     print("new_kv_cache: max: ", new_kv_cache.astype(np.float32).max().item())
-        self.kv_cache.update(1,result[1])
-        return result[0].reshape(1,1,-1)
+        self.kv_cache.update(seq_length, result[1])
+        return result[0].reshape(self.max_batch, seq_length, -1)
+
+    def run_all_logits(self, input_ids: np.ndarray):
+        seq_len, i = input_ids.shape[-1], 0
+        logits = []
+        while i < seq_len:
+            end = i + 16 if i+16 < seq_len else seq_len
+            cache,mask,pos_ids = self.kv_cache.get_inputs(16)
+            self.input_ids[0:end-i] = input_ids[i:end]
+            result:List[np.ndarray] = self.model.inference([self.input_ids, mask, pos_ids, cache])
+            self.kv_cache.update(end-i,result[1])
+            logits.append(result[0][0:end-i].reshape(1,-1))
+        return [np.concatenate(logits).reshape(1,1,-1)]
\ No newline at end of file
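
Note (outside the patch): a minimal standalone sketch of the power-of-two prefill split described in the README and implemented by `Session.decompose_number` above. The helper name `decompose_prefill_length` and the iterative form are illustrative assumptions; the behaviour matches the documented example, where an input of 12 tokens with `max_prefill_length=8` is processed as prefill chunks of [8, 4].

```python
# Standalone sketch of the greedy power-of-two split used for dynamic-shape prefill.
# Assumes max_prefill_length is itself a power of two (e.g. 8), as the export/CLI scripts enforce.

def decompose_prefill_length(n: int, max_prefill_length: int = 8) -> list:
    """Split n into descending powers of two, none larger than max_prefill_length."""
    chunks = []
    power = max_prefill_length
    while n > 0 and power >= 1:
        if power <= n:
            # take the largest chunk that still fits
            chunks.append(power)
            n -= power
        else:
            # shrink to the next smaller power of two
            power //= 2
    return chunks

if __name__ == "__main__":
    # 12 input tokens with max_prefill_length=8 -> prefill runs of 8 and then 4 tokens.
    print(decompose_prefill_length(12, 8))  # [8, 4]
```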