diff --git a/README.md b/README.md index 324549f..be8e6e1 100644 --- a/README.md +++ b/README.md @@ -61,22 +61,31 @@ cd qwen-ascend-llm pip install -r ./requirements.txt ``` -2. 导出onnx,默认kv-cache长度为1024,可以根据自己的内存、显存来设置更大参数。 +2. 导出onnx,当前我设置的kv-cache长度为2048,可以根据自己的内存、显存来设置更大参数,最大则不建议超过`max_position_embeddings`这个数,可以去模型里面的config.json文件里面看,qwen2-1.5B-Instruct里面,这个数值为`32768` ```bash python3 export/export_onnx.py \ --device_str=npu \ --dtype=float16 \ --hf_model_dir="./download/Qwen2-1.5B-Instruct" \ --onnx_model_path="./output/onnx/qwen2_1.5b_chat.onnx" \ - --kv_cache_length=1024 + --kv_cache_length=2048 ``` 3. 验证onnx,返回项目根目录,运行cli_chat.py,测试一下onnx对话是否正常(注意:由于是cpu运行,所以速度较慢,请耐心等待)。 + - `--max_input_length`为单次最大可以输入是数据量,该数值必须小于编译onnx的时候指定的`--kv_cache_length` + - `--max_output_length`则必须和之前转onnx的时候指定的`--kv_cache_length`保持一致,否则onnx输出将会异常。 + - 注:最大可以生成token数=`max_output_length`-min(max_input_length, 实际输入的token数) + - npu转出的onnx,dtype取float16,cpu转出来的onnx,dtype取float32 + - `--cpu_thread`根据你的cpu线程数设置,默认取4 ```bash python3 ./cli_chat.py \ --session_type=onnx \ --hf_model_dir="./download/Qwen2-1.5B-Instruct" \ - --onnx_model_path="./output/onnx/qwen2_1.5b_chat.onnx" + --onnx_model_path="./output/onnx/qwen2_1.5b_chat.onnx" \ + --dtype="float16" \ + --cpu_thread=4 \ + --max_input_length=1024 \ + --max_output_length=2048 ``` 4. 改变onnx结构,目前导出的Trilu算子和Cast算子有些问题,atc命令无法识别,需要改一下结构。 @@ -86,24 +95,38 @@ --output_model_path="./output/onnx2/qwen2_1.5b_chat.onnx" ``` -5. 转onnx为om模型, 将修改后的onnx利用atc命令导出到onnx,**注意此处的om_model_path不带`.om`后缀**。运行过程可能会有一些警告,或者子图融合报错,只要结果是提示`success`就说明没啥问题。kv_cache_length长度和第一步导出onnx时的长度保持一致。`--max_prefill_length`为prefill阶段,单次能处理的最大长度,该数值越长则越能降低首字延迟,但是相应的onnx转om的时间也会变长。设置该数值时,一般为2的指数,例如2、4、8、16等等,推理时会利用递归自动匹配合适的prefill长度,例如输入12,会匹配[8, 4]。当前默认数值为16,如果设置为1,则不会开启动态shape推理功能。该脚本会自动检测你的NPU类型,如果你想手动指定,可以加上`--soc_version=xxxx`来指定,例如`--soc_version=Ascend310B1` +5. 转onnx为om模型, 将修改后的onnx利用atc命令导出到onnx,**注意此处的om_model_path不带`.om`后缀**。 + - 运行过程可能会有一些警告,或者子图融合报错,只要结果是提示`success`就说明没啥问题。 + - kv_cache_length长度和第一步导出onnx时的长度保持一致。 + - `--max_prefill_length`为prefill阶段,单次能处理的最大长度,该数值越长则越能降低首字延迟,但是相应的onnx转om的时间也会变长。设置该数值时,一般为2的指数,例如2、4、8、16等等,推理时会利用递归自动匹配合适的prefill长度,例如输入12,会匹配[8, 4]。当前默认数值为4,如果设置为1,则不会开启动态shape推理功能。 + - 该脚本会自动检测你的NPU类型,如果你想手动指定,可以加上`--soc_version=xxxx`来指定,例如`--soc_version=Ascend310B1` + - `--kv_cache_length`的数值必须前面转onnx的时候指定的`--kv_cache_length`保持一致,否则大概率会转换失败。 + - `--cpu_thread`为转onnx为om时,开启的cpu线程数,默认为1个线程并行编译,如果内存很多(每个线程单独占用一份内存,所以很费内存),可以调高一些。 ```bash python3 export/onnx2om.py \ --hf_model_dir="./download/Qwen2-1.5B-Instruct" \ --onnx_model_path="./output/onnx2/qwen2_1.5b_chat.onnx" \ --om_model_path="./output/model/qwen2_1.5b_chat" \ - --kv_cache_length=1024 \ - --max_prefill_length=16 + --kv_cache_length=2048 \ + --cpu_thread=1 \ + --max_prefill_length=4 ``` ##### 步骤2:在终端运行模型进行对话 -- 使用下面的命令直接运行模型,`--max_prefill_length`需要和上面编译的时候使用的数值相同。 +- 使用下面的命令直接运行模型 + - `--max_prefill_length`需要和上面编译om模型时使用的数值相同。 + - `--max_input_length`为单次最大可以输入是数据量,该数值必须小于编译onnx的时候指定的`--kv_cache_length` + - `--max_output_length`则必须和之前转onnx的时候指定的`--kv_cache_length`保持一致,否则onnx输出将会异常。 + - 注:最大可以生成token数=`max_output_length`-min(max_input_length, 实际输入的token数) ```bash python3 ./cli_chat.py \ + --session_type="acl" \ --hf_model_dir="./download/Qwen2-1.5B-Instruct" \ --om_model_path="./output/model/qwen2_1.5b_chat.om" \ - --max_prefill_length=16 + --max_input_length=1024 \ + --max_output_length=2048 \ + --max_prefill_length=4 ``` - demo展示1(演示模型,qwen1.5-0.5b-chat,未开启动态shape推理) @@ -119,7 +142,9 @@ python3 ./api.py \ --hf_model_dir="./download/Qwen2-1.5B-Instruct" \ --om_model_path="./output/model/qwen2_1.5b_chat.om" \ - --max_prefill_length=16 + --max_input_length=1024 \ + --max_output_length=2048 \ + --max_prefill_length=4 ``` - 进入client目录,可以运行里面的文件请求服务端。 diff --git a/cli_chat.py b/cli_chat.py index 1c5b4ac..e2f7318 100644 --- a/cli_chat.py +++ b/cli_chat.py @@ -20,7 +20,34 @@ def parser_args(): type=str, default="acl", help="acl or onnx", - choices=["acl", "onnx"], + choices=["acl", "onnx", "pytorch"], + ) + parser.add_argument( + "--dtype" , + type=str, + help="support float16/float32, if use CPU, only support fp32", + choices=["float16", "float32"], + default="float32", + ) + parser.add_argument( + "--torch_dtype", + type=str, + help="support float16/float32, if use CPU, only support fp32", + choices=["float16", "float32"], + default="float32", + ) + parser.add_argument( + "--device_str", + type=str, + help="support cpu, cuda, npu, only activate when sesstion_type is pytorch", + choices=["cpu", "cuda", "npu"], + default="cpu", + ) + parser.add_argument( + "--cpu_thread" , + type=int, + help="num of cpu thread when run onnx sesstion", + default=4, ) parser.add_argument( '--onnx_model_path', @@ -44,7 +71,7 @@ def parser_args(): "--max_input_length", help="max input length", type=int, - default=512, + default=1024, ) parser.add_argument( "--max_prefill_length", @@ -53,14 +80,13 @@ def parser_args(): "the number must by 2^xx, like 1, 2, 4, 8, 16, 32, 64, 128, 256... " "Note! The higher this number, the longer it will take to compile.", type=int, - default=16, + default=4, ) - parser.add_argument( "--max_output_length", help="max output length (contain input + new token)", type=int, - default=1024, + default=2048, ) return parser.parse_args() @@ -110,12 +136,14 @@ def inference_cli(): hf_model_dir=args.hf_model_dir, om_model_path=args.om_model_path, onnx_model_path=args.onnx_model_path, + cpu_thread=args.cpu_thread, session_type=args.session_type, max_batch=args.max_batch, max_output_length=args.max_output_length, max_input_length=args.max_input_length, kv_cache_length=args.max_output_length, max_prefill_length=max_prefill_length, + dtype=args.dtype ) # main() inference_cli() \ No newline at end of file diff --git a/config.py b/config.py index 694024c..e5f0626 100644 --- a/config.py +++ b/config.py @@ -1,4 +1,5 @@ import os +import torch from transformers.models.qwen2 import Qwen2Config, Qwen2Tokenizer @@ -9,19 +10,21 @@ def __init__( om_model_path: str, onnx_model_path: str, cpu_thread: int = 4, # CPU线程数 - session_type: str = "acl", # 支持acl和onnx两种,acl即Ascend C Language + session_type: str = "acl", # 支持acl和onnx, pytorch三种,acl即Ascend C Language device_id: int = 0, sampling_method: str = "top_p", # 支持 greedy, top_p, top_k sampling_value: float = 0.8, temperature: float = 0.7, max_batch: int = 1, - max_input_length: int = 512, # 输入长度的最大数值 - max_output_length: int = 1024, # 输出长度的最大值 + max_input_length: int = 1024, # 输入长度的最大数值 + max_output_length: int = 2048, # 输出长度的最大值 max_prefill_length: int = 1, # prefile阶段,单次最大推理长度 kvcache_method: str = "fixsize", # kv_cache类型,支持basic,fixsize,streamllm,H2O四种,具体可以去kvcache.py查看 - kv_cache_length: int = 1024, # kvcache的最大长度 + kv_cache_length: int = 2048, # kvcache的最大长度 cache_format: str = 'huggingface-tensor', # kv_cache的格式 dtype:str="float16", + torch_dtype: str = "float16", + device_str = "cpu", ): self.tokenizer_dir = hf_model_dir self.session_type = session_type @@ -29,8 +32,11 @@ def __init__( assert os.path.exists(om_model_path), print(om_model_path, "not exists") elif self.session_type == "onnx": assert os.path.exists(onnx_model_path), print(onnx_model_path, "not exists") + elif self.session_type == "pytorch": + assert os.path.exists(hf_model_dir), print(hf_model_dir, "not exists") self.om_model_path = om_model_path self.onnx_model_path = onnx_model_path + self.hf_model_dir = hf_model_dir self.cpu_thread = cpu_thread self.device_id = device_id self.sampling_method = sampling_method @@ -43,6 +49,13 @@ def __init__( self.kv_cache_length = kv_cache_length # max_cache_size self.cache_format = cache_format self.dtype = dtype + if torch_dtype == "float16": + self.torch_dtype = torch.float16 + elif torch_dtype == "float32": + self.torch_dtype = torch.float32 + else: + self.torch_type = "auto" + self.device_str = device_str self.model_config = Qwen2Config.from_pretrained(hf_model_dir) self.num_hidden_layers = self.model_config.num_hidden_layers # n_layer self.num_key_value_heads = self.model_config.num_key_value_heads # head_num diff --git a/export/change_node.py b/export/change_node.py index b9122a3..183f03f 100644 --- a/export/change_node.py +++ b/export/change_node.py @@ -17,6 +17,9 @@ new_onnx_dir = os.path.join(output_dir, "onnx2") if not os.path.exists(new_onnx_dir): os.mkdir(new_onnx_dir) +else: + for file in os.listdir(new_onnx_dir): + os.remove(os.path.join(new_onnx_dir, file)) now_dir = os.path.dirname(os.path.abspath(__file__)) project_dir = os.path.dirname(now_dir) diff --git a/export/export_onnx.py b/export/export_onnx.py index 3d6f7e4..fc3fedb 100644 --- a/export/export_onnx.py +++ b/export/export_onnx.py @@ -64,7 +64,7 @@ def parser_arguments(): "--kv_cache_length", help="kv-cache length", type=int, - default=1024, + default=2048, ) return parser.parse_args() @@ -123,7 +123,7 @@ def export_onnx( output_names = ["logits", "out_key_values"] dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_length"}, - "attention_mask": {0: "batch_size", 1: "seq_length+kv_len"}, + "attention_mask": {0: "batch_size", 1: "seq_length + kv_len"}, "position_ids": {0: "batch_size", 1: "seq_length"}, "past_key_values": {0: "batch_size", 1: "kv_len"}, } diff --git a/export/onnx2om.py b/export/onnx2om.py index fd9ff40..b151e66 100644 --- a/export/onnx2om.py +++ b/export/onnx2om.py @@ -48,6 +48,12 @@ type=int, default=1, ) +parser.add_argument( + "--cpu_thread" , + type=int, + help="num of cpu thread when convert onnx to om", + default=1, +) parser.add_argument( "--max_prefill_length", help="max prefill length in first inference. " @@ -55,13 +61,13 @@ "the number must by 2^xx, like 1, 2, 4, 8, 16, 32, 64, 128, 256... " "Note! The higher this number, the longer it will take to compile.", type=int, - default=16 + default=8 ) parser.add_argument( "--kv_cache_length", help="kv-cache length", type=int, - default=1024, + default=2048, ) @@ -114,10 +120,10 @@ def get_soc_version(): assert (max_prefill_length < kv_cache_length), \ print("max_input_length max be smaller than kv_cache_length, because max_input_length + max_output_length <= kv_cache") input_ids_length_range = prefill_length_range -attention_length_range = [ - length + kv_cache_length - for length in prefill_length_range -] +# attention_length_range = [ +# length + kv_cache_length +# for length in prefill_length_range +# ] position_length_range = prefill_length_range input_ids_shape = [ f"1~{max_batch}" if max_batch > 1 else "1", @@ -132,14 +138,21 @@ def get_soc_version(): "-1" if max_prefill_length > 1 else "1" ] dynamic_dims = [] -for dynamic_dim in zip( - input_ids_length_range, attention_length_range, position_length_range -): - dynamic_dim = [str(dim) for dim in dynamic_dim] - dynamic_dims.append(",".join(dynamic_dim)) +for dynamic_kv_cache_length in [ + kv_cache_length // 2, + kv_cache_length +]: + for dynamic_dim in zip(input_ids_length_range, position_length_range): + new_dynamic_dim = [ + str(dynamic_dim[0]), # input_ids + str(dynamic_dim[0] + dynamic_kv_cache_length), # attention_mask_shape + str(dynamic_dim[1]), # position_ids + str(dynamic_kv_cache_length), # past_key_values + ] + dynamic_dims.append(",".join(new_dynamic_dim)) past_key_values_shape = [ f"1~{max_batch}" if max_batch > 1 else "1", - kv_cache_length, + "-1" if max_prefill_length > 1 else kv_cache_length, num_hidden_layers * 2 * num_key_value_heads, per_head_dim ] @@ -152,8 +165,14 @@ def get_soc_version(): else: soc_version = args.soc_version command_lines = [ + # reduce memory useage + "export MS_DEV_FORCE_ACL=1 && ", + "export MS_ENABLE_GE=1 && ", + "export TE_PARALLEL_COMPILER={} &&".format(args.cpu_thread), + "export MAX_COMPILE_CORE_NUMBER={} &&".format(args.cpu_thread), "atc", "--framework=5", + "--host_env_cpu=aarch64", '--model="{}"'.format(args.onnx_model_path), '--output="{}"'.format(args.om_model_path), "--soc_version={}".format(soc_version), diff --git a/export/test_onnx_run.py b/export/test_onnx_run.py index 065d080..04a37c0 100644 --- a/export/test_onnx_run.py +++ b/export/test_onnx_run.py @@ -1,4 +1,5 @@ import os +import time import numpy as np import onnxruntime import argparse @@ -28,6 +29,12 @@ type=str, default=os.path.join(project_dir, "output", "onnx", "qwen2_1.5b_chat.onnx") ) +parser.add_argument( + "--kv_cache_length", + help="kv-cache length", + type=int, + default=2048, +) args = parser.parse_args() if args.dtype == "float16": @@ -38,7 +45,7 @@ raise Exception("not support dtype, only support float16/float32") -def create_kv_cache(config: Qwen2Config, kv_cache_length=1024): +def create_kv_cache(config: Qwen2Config, kv_cache_length=args.kv_cache_length): return np.zeros( [ 1, @@ -50,7 +57,7 @@ def create_kv_cache(config: Qwen2Config, kv_cache_length=1024): ) -def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size: int = 1024): +def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size: int = args.kv_cache_length): """ 获取指定长度的kv_cache, 顺便生成mask和position_id Args: @@ -122,12 +129,15 @@ def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size print("now_kv_cache shape: ", now_kv_cache.shape) print("attention_mask shape: ", attn_mask.shape) print("position_ids shape: ", position_ids.shape) +st = time.time() outputs = llm_session.run(None, { "input_ids": input_ids[:, :1], "attention_mask": attn_mask, "position_ids": position_ids, "past_key_values": now_kv_cache, }) +et = time.time() +print("duration: ", et - st) print("==== onnx runtime ====") print("output length: ", len(outputs)) logits = outputs[0] diff --git a/export/test_pytorch_run.py b/export/test_pytorch_run.py index 943e4d8..129149a 100644 --- a/export/test_pytorch_run.py +++ b/export/test_pytorch_run.py @@ -1,4 +1,5 @@ import os +import time import torch import argparse from modeling_qwen2 import Qwen2ForCausalLM @@ -28,6 +29,12 @@ help="model and tokenizer path, only support huggingface model", default=os.path.join(project_dir, "download", "Qwen2-1.5B-Instruct") ) +parser.add_argument( + "--kv_cache_length", + help="kv-cache length", + type=int, + default=2048, +) args = parser.parse_args() @@ -42,7 +49,7 @@ raise Exception("not support dtype, only support float16/float32") -def create_kv_cache(config: Qwen2Config, kv_cache_length=1024): +def create_kv_cache(config: Qwen2Config, kv_cache_length: int = args.kv_cache_length): return torch.zeros( [ 1, @@ -54,7 +61,7 @@ def create_kv_cache(config: Qwen2Config, kv_cache_length=1024): ).to(device_str) -def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size: int = 1024): +def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size: int = args.kv_cache_length): """ 获取指定长度的kv_cache, 顺便生成mask和position_id Args: @@ -115,6 +122,7 @@ def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size print("now_kv_cache shape: ", now_kv_cache.shape) print("attention_mask shape: ", attn_mask.shape) print("position_ids shape: ", position_ids.shape) +st = time.time() outputs = model.forward( input_ids[:, :1], attn_mask, @@ -123,6 +131,8 @@ def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size # use_cache=True, # output_attentions=True, ) +et = time.time() +print("duration: ", et - st) print("==== pytorch runtime ====") print("output length: ", len(outputs)) logits = outputs[0] # 1: -0.10800 diff --git a/utils/engine.py b/utils/engine.py index b4f44af..ebfb697 100644 --- a/utils/engine.py +++ b/utils/engine.py @@ -3,6 +3,7 @@ import acl import numpy as np import os +import gc from functools import reduce from operator import mul import ctypes @@ -43,7 +44,6 @@ def mmap_file(file_path): # 关闭文件描述符,映射区域仍然有效 os.close(file_descriptor) - # 返回映射区域的地址 return mapped_memory,file_size @@ -52,18 +52,27 @@ def check_ret(str,ret): print(f"return code is {ret}, detail: {str}",flush=True) def init_resource(device_id: int): + print("[INFO] acl init") ret = acl.init() check_ret("init", ret) + print(f"[INFO] acl set device, device_id: {device_id}") ret = acl.rt.set_device(device_id) check_ret("set_device", ret) - context,ret = acl.rt.create_context(device_id) + print(f"[INFO] acl create context") + context, ret = acl.rt.create_context(device_id) check_ret("create_context", ret) return context def destroy_resource(device_id: int, context): + print("[INFO] acl reset device") ret = acl.rt.reset_device(device_id) + check_ret("reset device", ret) + print("[INFO] acl finalize") ret = acl.finalize() + check_ret("finalize", ret) + print("[INFO] destory context") ret = acl.rt.destroy_context(context) + check_ret("destory context", ret) dtype2NpType = {0:np.float32,1:np.float16,2:np.int8,3:np.int32,9:np.int64} @@ -88,13 +97,18 @@ def __init__(self, config: InferenceConfig, context=None,callback=None): self.inputs:List[Dict[str,]] = [] self.outputs:List[Dict[str,]] = [] self.config = config + self.past_key_value_shape = config.past_key_value_shape + self.half_past_key_value_shape = list(config.past_key_value_shape) + self.half_past_key_value_shape[1] = self.half_past_key_value_shape[1] // 2 self.load_model(config.om_model_path) self.allocate_memory() if not callback: return self.stream, ret = acl.rt.create_stream() + check_ret("create stream", ret) self.tid, ret = acl.util.start_thread(self._process_callback, [self.context, 50]) + check_ret("start thread", ret) check_ret("acl.util.start_thread", ret) ret = acl.rt.subscribe_report(self.tid, self.stream) check_ret("acl.rt.subscribe_report", ret) @@ -117,8 +131,14 @@ def get_inputs(self, seq_len: int) -> List[np.ndarray]: self.per_head_dim ) """ - mask = np.ones((1,self.past_kv_size + seq_len), dtype=np.int64) - mask[:, self.real_kv_size: self.past_kv_size] = 0 + temp_seq_len = self.real_kv_size + seq_len + if temp_seq_len <= self.kv_cache_length // 2: + temp_kv_size = self.kv_cache_length // 2 + else: + temp_kv_size = self.kv_cache_length + + mask = np.ones((1, temp_kv_size + seq_len), dtype=np.int64) + mask[:, self.real_kv_size: temp_kv_size] = 0 pos_id =np.arange( self.input_pos, self.input_pos + seq_len, @@ -186,12 +206,61 @@ def load_model(self, model_path): Args: model_path (_type_): _description_ """ - model_add, model_size = mmap_file(model_path) - self.model_id, ret = acl.mdl.load_from_mem(model_add, model_size) - - #self.model_id, ret = acl.mdl.load_from_file(model_path) + # 方法1:通过map的方式加载,大概24秒 + # model_add, model_size = mmap_file(model_path) + # self.model_id, ret = acl.mdl.load_from_mem(model_add, model_size) + # check_ret("load model",ret) + # munmap_func(model_add, model_size) + # 方法2:直接加载model,用时34秒 + # self.model_id, ret = acl.mdl.load_from_file(model_path) + # check_ret("load model",ret) + # 方法3:将模型加载到device内存中 + # 先获取模型大小 + model_buffer_size = os.path.getsize(model_path) + # 分配模型buffer到device内存中 + model_buffer, ret = acl.rt.malloc(model_buffer_size, ACL_MEM_MALLOC_HUGE_FIRST) + p_model_buffer = model_buffer + check_ret("alloc model buffer",ret) + # 分块读取模型文件,然后将其拷贝到device model中 + # 块大小(例如 50MB) + chunk_size = 50 * 1024 * 1024 + have_load_size = 0 + with open(model_path, 'rb') as file: + while True: + # 读取一块数据 + chunk = file.read(chunk_size) + chunk_bytes = len(chunk) + # 如果读取的数据为空,说明已经读取完毕 + if not chunk: + break + # 获取这块数据的内存地址 + writable_buffer = ctypes.create_string_buffer(chunk) + chunk_address = ctypes.addressof(writable_buffer) + ret = acl.rt.memcpy( + p_model_buffer, + model_buffer_size - have_load_size, + chunk_address, + chunk_bytes, + ACL_MEMCPY_HOST_TO_DEVICE + ) + del writable_buffer + check_ret("memcpy input", ret) + progress = have_load_size * 100 / model_buffer_size + print(f"\r[INFO] load model buffer {progress:.2f}%", end="") + have_load_size += chunk_bytes + p_model_buffer += chunk_bytes + print("\r[INFO] load model buffer 100.00%") + gc.collect() + st = time.time() + print("[INFO] load model from memory, please wait a monment...") + self.model_id, ret = acl.mdl.load_from_mem(model_buffer, model_buffer_size) check_ret("load model",ret) - munmap_func(model_add, model_size) + et = time.time() + # 模型加载完后,model_buffer实测可以清理掉了,节省大量空间 + ret = acl.rt.free(model_buffer) + check_ret(f"free model buffer device memory", ret) + print("[INFO] load model duration: ", et - st) + print("[INFO] get model desc") self.model_desc = acl.mdl.create_desc() ret = acl.mdl.get_desc(self.model_desc, self.model_id) check_ret("get model desc",ret) @@ -255,6 +324,7 @@ def free_memory(self): """ 释放内存 """ + print("[INFO] free input and output buffer") for i, item in enumerate(self.input_data): ret = acl.rt.free(item["buffer"]) check_ret(f"free input[{i}] device memory",ret) @@ -307,18 +377,34 @@ def inference(self, input_data_list: List[np.ndarray], seq_length=1, is_dynamic= self.model_desc, "ascend_mbatch_shape_data" ) check_ret("get_input_index_by_name", ret) - dynamic_dims = [ - # input_ids - self.max_batch, - seq_length, - # attention_mask - self.max_batch, - seq_length + self.kv_cache_length, - # position_ids - self.max_batch, - seq_length - ] - dynamic_dims += self.config.past_key_value_shape + # 新逻辑,将kv_cache_length切成两片 + if (self.real_kv_size + seq_length) > self.kv_cache_length // 2: + dynamic_dims = [ + # input_ids + self.max_batch, + seq_length, + # attention_mask + self.max_batch, + seq_length + self.kv_cache_length, + # position_ids + self.max_batch, + seq_length + ] + dynamic_dims += self.past_key_value_shape + else: + dynamic_dims = [ + # input_ids + self.max_batch, + seq_length, + # attention_mask + self.max_batch, + seq_length + self.kv_cache_length // 2, + # position_ids + self.max_batch, + seq_length + ] + dynamic_dims += self.half_past_key_value_shape + # will set dynamic input shape ret = acl.mdl.set_input_dynamic_dims( self.model_id, diff --git a/utils/inference.py b/utils/inference.py index b5e05cb..e7ddb3d 100644 --- a/utils/inference.py +++ b/utils/inference.py @@ -8,11 +8,12 @@ from utils.session import Session from config import InferenceConfig from tqdm import trange, tqdm +import torch class Inference: - def __init__(self, config:InferenceConfig) -> None: + def __init__(self, config: InferenceConfig) -> None: self.max_input_length = config.max_input_length self.max_output_length = config.max_output_length # self.tokenizer=Tokenizer(config.tokenizer) @@ -22,7 +23,16 @@ def __init__(self, config:InferenceConfig) -> None: self.sampling_method = config.sampling_method self.sampling_value = config.sampling_value self.temperature = config.temperature - self.session=Session.fromConfig(config) + self.session = Session.fromConfig(config) + self.session_type = config.session_type + if config.device_str == "cpu": + self.torch_device = torch.device("cpu") + elif config.device_str == "cuda": + self.torch_device = torch.device("cuda") + elif config.device_str == "npu": + self.torch_device = torch.device("npu") + else: + raise Exception(f"unsport device {config.device_str}") # self.prompt=config.prompt self.kv_cache_length = config.kv_cache_length self.state: dict = {"code":200,"isEnd":False,"message":""} @@ -30,7 +40,7 @@ def __init__(self, config:InferenceConfig) -> None: self.lock = Lock() self.first = True # self.stop_mp = {"[|Human|]":6,"[|AI|]":5,"<|assistant|>":6,"<|user|>":5} - print("init success") + print("[INFO] init success") def generate_cache(self, prompt: str): @@ -141,9 +151,16 @@ def stream_predict( tokenize=False, add_generation_prompt=True ) - input_ids = self.tokenizer( - [text], return_tensors="np" - )["input_ids"].astype(np.int64).reshape(1, -1) + if self.session_type in ["onnx", "acl"]: + input_ids = self.tokenizer( + [text], return_tensors="np" + )["input_ids"].astype(np.int64).reshape(1, -1) + elif self.session_type == "pytorch": + input_ids = self.tokenizer( + [text], return_tensors="pt" + )["input_ids"].to(torch.long).reshape(1, -1).to(self.torch_device) + else: + raise Exception(f"unknown session_type {self.session_type}") input_ids = input_ids[:, -self.max_input_length:] # print("input_ids shape: ", input_ids.shape) self.first = False @@ -240,9 +257,16 @@ def predict( tokenize=False, add_generation_prompt=True ) - input_ids = self.tokenizer( - [text], return_tensors="np" - )["input_ids"].astype(np.int64).reshape(1, -1) + if self.session_type in ["onnx", "acl"]: + input_ids = self.tokenizer( + [text], return_tensors="np" + )["input_ids"].astype(np.int64).reshape(1, -1) + elif self.session_type == "pytorch": + input_ids = self.tokenizer( + [text], return_tensors="pt" + )["input_ids"].to(torch.long).reshape(1, -1).to(self.torch_device) + else: + raise Exception(f"unknown session_type {self.session_type}") input_ids = input_ids[:, -self.max_input_length:] self.first = False ids_list = [] diff --git a/utils/kvcache.py b/utils/kvcache.py index ec78590..8f29eac 100644 --- a/utils/kvcache.py +++ b/utils/kvcache.py @@ -1,4 +1,5 @@ import numpy as np +import torch from typing import Optional,Tuple,List from config import InferenceConfig # 对KV缓存和输出输出格式进行管理 @@ -6,6 +7,7 @@ class KVCacheManger: def __init__(self, config: InferenceConfig) -> None: self.num_key_value_heads = config.num_key_value_heads # head len self.kv_cache_length = config.kv_cache_length # max_size + self.session_type = config.session_type self.input_pos = 0 self.past_kv_size = 0 self.num_hidden_layers = config.num_hidden_layers # n_layer @@ -14,21 +16,35 @@ def __init__(self, config: InferenceConfig) -> None: self.per_head_dim = config.per_head_dim # head_dim self.past_key_value_shape = config.past_key_value_shape self.real_kv_size = 0 # 真正的kv_cache长度 - if config.dtype == "float16": - self.dtype=np.float16 - elif config.dtype=="float32": - self.dtype=np.float32 + self.device_str = config.device_str + if self.session_type == "onnx": + if config.dtype == "float16": + self.dtype=np.float16 + elif config.dtype=="float32": + self.dtype=np.float32 + else: + raise Exception("only support float16 and float32, not ", config.dtype) + elif self.session_type == "pytorch": + if config.dtype == "float16": + self.dtype=torch.float16 + elif config.dtype=="float32": + self.dtype=torch.float32 + else: + raise Exception("only support float16 and float32, not ", config.dtype) + if self.session_type == "onnx": + self.kv_cache = np.zeros(self.past_key_value_shape, dtype=self.dtype) + elif self.session_type == "pytorch": + self.kv_cache = torch.zeros(self.past_key_value_shape, dtype=self.dtype, device=self.device_str) else: - raise Exception("only support float16 and float32, not ", np.dtype) - # self.kv_cache = None - self.kv_cache = np.zeros(self.past_key_value_shape, dtype=self.dtype) + self.kv_cache = None - def create_empty_cache(self): - """ - 创建空的kv_cache - """ - if self.cache_format == "huggingface-tensor": - self.kv_cache = np.zeros(self.past_key_value_shape, dtype=self.dtype) + + # def create_empty_cache(self): + # """ + # 创建空的kv_cache + # """ + # if self.cache_format == "huggingface-tensor": + # self.kv_cache = np.zeros(self.past_key_value_shape, dtype=self.dtype) def update( self, @@ -47,7 +63,7 @@ def update( def get_inputs(self, seq_len: int) -> List[np.ndarray]: """ - 获取指定长度的kv_cache, 顺便生成mask和position_id + 获取指定长度的kv_cache, 顺便生成mask和position_id,仅用于onnx Args: seq_len (int): 待获取的kv-cache长度 @@ -65,14 +81,26 @@ def get_inputs(self, seq_len: int) -> List[np.ndarray]: self.per_head_dim ) """ - cache = self.kv_cache[:, :self.past_kv_size] - mask = np.ones((1,self.past_kv_size + seq_len), dtype=np.int64) - mask[:, self.real_kv_size: self.past_kv_size] = 0 - pos_id =np.arange( - self.input_pos, - self.input_pos + seq_len, - dtype=np.int64 - ).reshape(1,-1) + # 因为onnx支持动态库,所以取实际大小的kv-cache, +1是防止cache为空 + temp_kv_size = self.real_kv_size + 1 + cache = self.kv_cache[:, :temp_kv_size] + if self.session_type == "onnx": + mask = np.ones((1, temp_kv_size + seq_len), dtype=np.int64) + mask[:, self.real_kv_size: temp_kv_size] = 0 + pos_id =np.arange( + self.input_pos, + self.input_pos + seq_len, + dtype=np.int64 + ).reshape(1,-1) + elif self.session_type == "pytorch": + mask = torch.ones((1, temp_kv_size + seq_len), dtype=torch.long, device=self.device_str) + mask[:, self.real_kv_size: temp_kv_size] = 0 + pos_id =torch.arange( + self.input_pos, + self.input_pos + seq_len, + dtype=torch.long, + device=self.device_str + ).reshape(1,-1) return cache, mask, pos_id def reset(self,num=1): diff --git a/utils/session.py b/utils/session.py index 1ded186..cd8e088 100644 --- a/utils/session.py +++ b/utils/session.py @@ -5,7 +5,6 @@ import math import time import sys -from utils.engine import ACLModel, init_resource, destroy_resource import onnxruntime as ort from tqdm import tqdm, trange @@ -23,6 +22,8 @@ def fromConfig(config:InferenceConfig) -> 'Session': return OnnxSession(config) elif config.session_type=='acl': return AclSession(config) + elif config.session_type == "pytorch": + return PyTorchSession(config) else: return None @@ -63,6 +64,28 @@ def run(self, input_ids:np.ndarray, show_progress=False): }) self.kv_cache.update(seq_len,result[1]) return result[0] + +class PyTorchSession(Session): + def __init__(self, config:InferenceConfig)->None: + super().__init__(config) + self.kv_cache = create_kv_cache(config) + from export.modeling_qwen2 import Qwen2ForCausalLM + self.device_str = config.device_str + self.model = Qwen2ForCausalLM.from_pretrained( + config.hf_model_dir, + torch_dtype=config.torch_dtype + ).to(config.device_str) + + def run(self, input_ids:np.ndarray, show_progress=False): + seq_len=input_ids.shape[-1] + cache, mask, pos_ids = self.kv_cache.get_inputs(seq_len) + # print("input_ids shape/dtype: ", input_ids.shape, input_ids.dtype) + # print("cache shape/dtype: ", cache.shape, cache.dtype) + # print("mask shape/dtype: ", mask.shape, mask.dtype) + # print("pos_ids shape/dtype: ", pos_ids.shape, pos_ids.dtype) + result = self.model(input_ids, mask, pos_ids, cache) + self.kv_cache.update(seq_len, result[1]) + return result[0] # onnxruntime-cann is preview, not work now """ @@ -116,6 +139,7 @@ class AclSession(Session): context = None def __init__(self, config:InferenceConfig): super().__init__(config) + from utils.engine import ACLModel, init_resource self.device_id = config.device_id self.context = init_resource(self.device_id) self.model = ACLModel(config, self.context) @@ -132,6 +156,7 @@ def reset(self): self.model.reset(); def __del__(self): + from utils.engine import destroy_resource destroy_resource(self.device_id, self.context) def decompose_number(self, n, start_index=0): @@ -196,6 +221,11 @@ def run_some( ): self.run_times += seq_length mask, pos_ids = self.model.get_inputs(seq_length) + # print("=========================") + # print("input_ids: ", input_ids) + # print("attention_mask: ", mask) + # print("position_ids: ", pos_ids) + # print("=========================") logits = self.model.inference( [input_ids, mask, pos_ids], seq_length, is_dynamic, is_prefill=is_prefill )