diff --git a/README.md b/README.md index 3a08305..5ee85cf 100644 --- a/README.md +++ b/README.md @@ -13,38 +13,46 @@ ### 分步骤运行 ##### 步骤1:编译模型 -1. 进入export文件夹 +1. 进入export文件夹, 导出onnx。 ```bash cd export + python3 export_onnx.py --hf_model_dir="download/[你下载的模型路径]" + cd.. ``` -2. 导出onnx。 + +2. 验证onnx,返回项目根目录,运行cli_chat.py,测试一下onnx对话是否正常。 ```bash - python3 export_onnx.py --hf_model_dir="download/[你下载的模型路径]" + python3 ./cli_chat.py --session_type=onnx ``` -3. 改变onnx结构,目前导出的Trilu算子和Cast算子有些问题,需要改一下结构。 +3. 进入export文件夹,改变onnx结构,目前导出的Trilu算子和Cast算子有些问题,atc命令无法识别,需要改一下结构。 ```bash + cd export python3 change_node.py + cd .. ``` 4. 转onnx为om模型 ```bash + cd export python3 onnx2om.py --hf_model_dir="download/[你下载的模型路径]" - ``` - -5. 返回上层路径 - ```bash cd .. ``` + ##### 步骤2:运行模型 +- 使用下面的命令直接运行模型 + ```bash + python3 ./cli_chat.py --hf_model_dir="download/[你下载的模型路径]" + ``` ### 当前功能 - [x] 导出onnx, om模型 -- [ ] 模型推理 -- [ ] 流式传输 +- [x] 模型推理,支持onnx推理。 +- [ ] 模型推理,支持acl推理。 +- [x] 流式传输 - [ ] 兼容OpenAI的api搭建 - [ ] 支持functional call - [ ] 支持模型量化,如weight only, smooth quant等 diff --git a/cli_chat.py b/cli_chat.py new file mode 100644 index 0000000..aede140 --- /dev/null +++ b/cli_chat.py @@ -0,0 +1,98 @@ +import sys +import argparse +from concurrent.futures import ThreadPoolExecutor +from config import InferenceConfig +from utils.inference import Inference +import os + +project_dir = os.path.dirname(os.path.abspath(__file__)) +parser = argparse.ArgumentParser() +parser.add_argument( + '--hf_model_dir', + type=str, + help="model and tokenizer path, only support huggingface model", + default=os.path.join(project_dir, "download", "Qwen1_5_0_5B_Chat") +) +parser.add_argument( + "--session_type", + type=str, + default="acl", + help="acl or onnx", + choices=["acl", "onnx"], +) +parser.add_argument( + '--onnx_model_path', + type=str, + help="onnx_model_path", + default=os.path.join(project_dir, "output", "onnx", "qwen1.5_0.5b_chat.onnx") +) +parser.add_argument( + "--om_model_path", + help="mindspore model path", + type=str, + default= os.path.join(project_dir, "output", "model", "qwen1.5_0.5b_chat.om") +) +parser.add_argument( + "--max_input_length", + help="max input length", + type=int, + default=512, +) + +parser.add_argument( + "--max_output_length", + help="max output length (contain input + new token)", + type=int, + default=1024, +) + +args = parser.parse_args() +config = InferenceConfig( + hf_model_dir=args.hf_model_dir, + om_model_path=args.om_model_path, + onnx_model_path=args.onnx_model_path, + session_type=args.session_type, + max_output_length=args.max_output_length, + max_input_length=args.max_input_length, + kv_cache_length=args.max_output_length, +) +infer_engine=Inference(config) + +def inference_cli(): + print("\n欢迎使用Qwen聊天机器人,输入exit或者quit退出,输入clear清空历史记录") + history = [] + while True: + input_text = input("Input: ") + if input_text in ["exit", "quit", "exit()", "quit()"]: + break + if input_text == 'clear': + history = [] + print("Output: 已清理历史对话信息。") + continue + print("Output: ", end='') + response = "" + is_first = True + first_token_lantency, decode_speed = 0, 0 + for ( + new_text, + first_token_lantency, + decode_speed, + total_speed + ) in infer_engine.stream_predict(input_text, history=history): + if is_first: + if len(new_text.strip()) == 0: + continue + is_first = False + print(new_text, end='', flush=True) + response += new_text + print("") + print( + "[INFO] first_token_lantency: {:.4f}s,".format(first_token_lantency), + " decode_speed: {:.2f} token/s, ".format(decode_speed), + " 
total_speed(prefill+decode): {:.2f} token/s".format(total_speed), + ) + + history.append({"role": "assistant", "content": response}) +if __name__ == '__main__': + # main() + inference_cli() \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000..5181cbd --- /dev/null +++ b/config.py @@ -0,0 +1,54 @@ +import os +from transformers.models.qwen2 import Qwen2Config, Qwen2Tokenizer + + +class InferenceConfig: + def __init__( + self, + hf_model_dir: str, + om_model_path: str, + onnx_model_path: str, + session_type: str = "acl", # 支持acl和onnx两种,acl即Ascend C Language + device_id: int = 0, + sampling_method: str = "top_k", + sampling_value: float = 10, + temperature: float = 0.7, + max_input_length: int = 512, # 输入长度的最大数值 + max_output_length: int = 1024, # 输出长度的最大值 + kvcache_method: str = "fixsize", # kv_cache类型,支持basic,fixsize,streamllm,H2O四种,具体可以去kvcache.py查看 + kv_cache_length: int = 1024, # kvcache的最大长度 + cache_format: str = 'huggingface-tensor', # kv_cache的格式 + dtype:str="float16", + ): + self.tokenizer_dir = hf_model_dir + self.session_type = session_type + if self.session_type == "acl": + assert os.path.exists(om_model_path), print(om_model_path, "not exists") + elif self.session_type == "onnx": + assert os.path.exists(onnx_model_path), print(onnx_model_path, "not exists") + self.om_model_path = om_model_path + self.onnx_model_path = onnx_model_path + self.device_id = device_id + self.sampling_method = sampling_method + self.sampling_value = sampling_value + self.temperature = temperature + self.max_input_length = max_input_length + self.max_output_length = max_output_length + self.kvcache_method = kvcache_method + self.kv_cache_length = kv_cache_length # max_cache_size + self.cache_format = cache_format + self.dtype = dtype + self.model_config = Qwen2Config.from_pretrained(hf_model_dir) + self.num_hidden_layers = self.model_config.num_hidden_layers # n_layer + self.num_key_value_heads = self.model_config.num_key_value_heads # head_num + self.hidden_size = self.model_config.hidden_size # hidden_dim + self.num_attention_heads = self.model_config.num_attention_heads + self.per_head_dim = self.hidden_size // self.num_attention_heads # head_dim + self.past_key_value_shape = ( + self.num_hidden_layers, + 2, + 1, + self.num_key_value_heads, + self.kv_cache_length, + self.per_head_dim + ) diff --git a/export/change_node.py b/export/change_node.py index 16d505a..02417ec 100644 --- a/export/change_node.py +++ b/export/change_node.py @@ -32,6 +32,7 @@ if node.op_type == "Trilu": new_node = helper.make_node( "Trilu", + name="MY_" + node.name, inputs=[node.input[0]], outputs=node.output, upper=0 diff --git a/export/export_onnx.py b/export/export_onnx.py index 4801025..53132b8 100644 --- a/export/export_onnx.py +++ b/export/export_onnx.py @@ -16,9 +16,6 @@ import io import argparse -device_str = "npu" -if device_str == "npu": - import torch_npu now_dir = os.path.dirname(os.path.abspath(__file__)) project_dir = os.path.dirname(now_dir) @@ -32,6 +29,20 @@ def parser_arguments(): parser = argparse.ArgumentParser() + parser.add_argument( + "--device_str", + type=str, + choices=["npu", "cuda", "cpu"], + help="support npu, cuda, cpu", + default="npu", + ) + parser.add_argument( + "--dtype" , + type=str, + help="support float16/float32, if use CPU, only support fp32", + choices=["float16", "float32"], + default="float16", + ) parser.add_argument( '--hf_model_dir', type=str, @@ -54,17 +65,29 @@ def parser_arguments(): def export_onnx( - base_model: str, - output_path: str, 
- kv_cache_length: int, - num_hidden_layers: int, - num_key_value_heads: int, - per_head_dim: int, + device_str, + dtype: str, + hf_model_dir: str, + onnx_model_path: str, + kv_cache_length: int, + num_hidden_layers: int, + num_key_value_heads: int, + per_head_dim: int, ): + if device_str == "npu": + import torch_npu + if dtype == "float16": + assert device_str.lower() != "cpu", print("cpu not support fp16") + torch_dtype = torch.float16 + elif dtype == "float32": + torch_dtype = torch.float32 + else: + raise Exception("unsupport dtype") + device = torch.device(device_str) model = Qwen2ForCausalLM.from_pretrained( - base_model, - torch_dtype=torch.float16, + hf_model_dir, + torch_dtype=torch_dtype, trust_remote_code=True ).to(device) quantize_cfg = { @@ -95,8 +118,8 @@ def export_onnx( output_names = ["logits", "out_key_values"] dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_length"}, + "attention_mask": {0: "batch_size", 1: "seq_length+kv_len"}, "position_ids": {0: "batch_size", 1: "seq_length"}, - "attention_mask": {0: "batch_size", 1: "all_len"}, "past_key_values": {2: "batch_size", 4: "kv_len"}, } batch_size = 1 @@ -115,7 +138,7 @@ def export_onnx( kv_cache_length, per_head_dim ), - dtype=torch.float16 + dtype=torch_dtype ).to(device) input_args = ( input_ids, @@ -128,7 +151,6 @@ def export_onnx( True, # output_attentions: Optional[bool] = None, None, # output_hidden_states False # return_dict: - ) model.eval() with torch.no_grad(): @@ -137,11 +159,11 @@ def export_onnx( # print(model) torch.onnx.export( model, - f=output_path, + f=onnx_model_path, args=input_args, input_names=input_names, output_names=output_names, - #dynamic_axes=dynamic_axes, + dynamic_axes=dynamic_axes, do_constant_folding=False, opset_version=14, export_params=True @@ -178,9 +200,11 @@ def export_onnx( print("new model config save ok in ", args.hf_model_dir) print("begin export onnx") export_onnx( - args.hf_model_dir, - args.onnx_model_path, - args.kv_cache_length, + device_str=args.device_str, + dtype=args.dtype, + hf_model_dir=args.hf_model_dir, + onnx_model_path=args.onnx_model_path, + kv_cache_length=args.kv_cache_length, num_hidden_layers=num_hidden_layers, num_key_value_heads=num_key_value_heads, per_head_dim=per_head_dim diff --git a/export/modeling_qwen2.py b/export/modeling_qwen2.py index 38c72d2..5264d21 100644 --- a/export/modeling_qwen2.py +++ b/export/modeling_qwen2.py @@ -700,15 +700,16 @@ def forward( query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() - + # copy from chatglm3-6b + # attention_mask = ~attention_mask attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, + # dropout_p=self.attention_dropout if self.training else 0.0, # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, + # is_causal=self.is_causal and attention_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -736,8 +737,8 @@ def __init__(self, config: Qwen2Config, layer_idx: int): f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." 
) - self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - + # self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = Qwen2SdpaAttention(config, layer_idx) self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -950,6 +951,28 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value + @staticmethod + def get_masks(input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, + device=input_ids.device) + full_attention_mask.tril_() + past_length = past_key_values.shape[4] + # if past_length is not None: + full_attention_mask = torch.cat( + (torch.ones(batch_size, seq_length, past_length, + device=input_ids.device), full_attention_mask), + dim=-1 + ) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze( + 1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) def forward( self, @@ -1009,7 +1032,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - + """ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: @@ -1033,6 +1056,7 @@ def forward( ) else: # 4d mask is passed through the layers + # [1, 1, 2, 1026], value=-65504 attention_mask = _prepare_4d_causal_attention_mask( attention_mask, (batch_size, seq_length), @@ -1040,7 +1064,17 @@ def forward( past_key_values_length, sliding_window=self.config.sliding_window, ) - + """ + # copy from chatglm3-6b + full_attention_mask = self.get_masks( + input_ids, + past_key_values, + attention_mask, + ) + dtype = past_key_values.dtype + device = input_ids.device + attention_mask = torch.zeros_like(full_attention_mask, dtype=dtype).to(device) + attention_mask.masked_fill_(full_attention_mask, torch.finfo(dtype).min) hidden_states = inputs_embeds # decoder layers diff --git a/export/onnx2om.py b/export/onnx2om.py index 97e0fc0..6a967f1 100644 --- a/export/onnx2om.py +++ b/export/onnx2om.py @@ -29,8 +29,8 @@ default=os.path.join(onnx_model_dir, "qwen1.5_0.5b_chat.onnx") ) parser.add_argument( - "--ms_model_path", - help="mindspore model path", + "--om_model_path", + help=".om model path", type=str, default= os.path.join(model_dir, "qwen1.5_0.5b_chat") ) @@ -93,7 +93,7 @@ def get_soc_version(): "atc", "--framework=5", '--model="{}"'.format(args.onnx_model_path), - '--output="{}"'.format(args.ms_model_path), + '--output="{}"'.format(args.om_model_path), "--soc_version=Ascend{}".format(get_soc_version()), "--precision_mode=must_keep_origin_dtype", "--input_format=ND", diff --git a/utils/engine.py b/utils/engine.py new file mode 100644 index 0000000..7375834 --- /dev/null +++ b/utils/engine.py @@ -0,0 +1,299 @@ +import time +from typing import Dict, List +import acl +import numpy as np +import os +import ctypes +from config import InferenceConfig +from ctypes import c_void_p, c_int, c_size_t, c_ulong, c_int64,POINTER 
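+# The constants below mirror the pyACL enum values for the memory-allocation
+# policy (aclrtMemMallocPolicy) and the memcpy direction (aclrtMemcpyKind);
+# NPY_FLOAT32 = 11 is numpy's C-level type code (NPY_FLOAT). The values are
+# taken as given here and should match the CANN documentation for the toolkit in use.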
+ACL_MEM_MALLOC_HUGE_FIRST = 0 +ACL_MEMCPY_HOST_TO_DEVICE = 1 +ACL_MEMCPY_DEVICE_TO_HOST = 2 +ACL_MEM_MALLOC_NORMAL_ONLY = 2 +NPY_FLOAT32 = 11 + +libc = ctypes.CDLL("libc.so.6") +# mmap函数原型 +mmap_func = libc.mmap +mmap_func.argtypes = [c_void_p, c_size_t, c_int, c_int, c_int, c_int64] +mmap_func.restype = c_void_p + +# munmap函数原型 +munmap_func = libc.munmap +munmap_func.argtypes = [c_void_p, c_size_t] +munmap_func.restype = c_int + +def mmap_file(file_path): + # 打开文件并获取文件描述符 + file_descriptor = os.open(file_path, os.O_RDONLY) + file_size = os.lseek(file_descriptor, 0, os.SEEK_END) + os.lseek(file_descriptor, 0, os.SEEK_SET) + # 调用mmap映射文件到内存 + # PROT_READ和MAP_PRIVATE的值可能因系统而异,这里假设为1和2 + protection_flags = 1 # PROT_READ + visibility_flags = 2 # MAP_PRIVATE + mapped_memory = mmap_func(None, file_size, protection_flags, visibility_flags, file_descriptor, 0) + if mapped_memory == -1: + raise Exception("Error mapping the file.") + + # 关闭文件描述符,映射区域仍然有效 + os.close(file_descriptor) + + # 返回映射区域的地址 + return mapped_memory,file_size + +def check_ret(str,ret): + if ret != 0: + print(f"return code is {ret}, detail: {str}",flush=True) + +def init_resource(device_id: int): + ret = acl.init() + check_ret("init", ret) + ret = acl.rt.set_device(device_id) + check_ret("set_device", ret) + context,ret = acl.rt.create_context(device_id) + check_ret("create_context", ret) + return context + +def destroy_resource(device_id: int, context): + ret = acl.rt.reset_device(device_id) + ret = acl.finalize() + ret = acl.rt.destroy_context(context) + +dtype2NpType = {0:np.float32,1:np.float16,2:np.int8,3:np.int32,9:np.int64} + +class ACLModel: + def __init__(self, config: InferenceConfig, context=None,callback=None): + self.context = context + self.model_id = None + self.model_desc = None + self.callback_func = callback + self.tid = None + self.stream = None + self.callback_interval = 1 + self.exit_flag = False + self.kv_cache = None + self.input_dataset, self.output_dataset = None, None + self.inputs:List[Dict[str,]] = [] + self.outputs:List[Dict[str,]] = [] + self.config = config + self.load_model(config.om_model_path) + self.allocate_memory() + if not callback: + return + self.stream, ret = acl.rt.create_stream() + self.tid, ret = acl.util.start_thread(self._process_callback, + [self.context, 50]) + check_ret("acl.util.start_thread", ret) + ret = acl.rt.subscribe_report(self.tid, self.stream) + check_ret("acl.rt.subscribe_report", ret) + + def unload(self): + if self.callback_func: + ret = acl.rt.synchronize_stream(self.stream) + # 2.7 取消线程注册,Stream上的回调函数不再由指定线程处理。 + ret = acl.rt.unsubscribe_report(self.tid, self.stream) + self.exit_flag = True + ret = acl.util.stop_thread(self.tid) + ret = acl.rt.destroy_stream(self.stream) + self.free_memory() + self.unload_model() + + + def load_model(self, model_path): + """ + 加载模型 + Args: + model_path (_type_): _description_ + """ + model_add, model_size = mmap_file(model_path) + self.model_id, ret = acl.mdl.load_from_mem(model_add, model_size) + + #self.model_id, ret = acl.mdl.load_from_file(model_path) + check_ret("load model",ret) + munmap_func(model_add, model_size) + self.model_desc = acl.mdl.create_desc() + ret = acl.mdl.get_desc(self.model_desc, self.model_id) + check_ret("get model desc",ret) + + def unload_model(self): + """ + 卸载模型 + """ + ret = acl.mdl.unload(self.model_id) + if self.model_desc: + ret = acl.mdl.destroy_desc(self.model_desc) + self.model_desc = None + + def allocate_memory(self): + """ + 分配内存 + """ + self.input_dataset = acl.mdl.create_dataset() + 
input_size = acl.mdl.get_num_inputs(self.model_desc) + self.inputs = [] + for i in range(input_size): + buffer_size = acl.mdl.get_input_size_by_index(self.model_desc, i) + if i == 3: + buffer, ret = acl.rt.malloc(buffer_size, ACL_MEM_MALLOC_HUGE_FIRST) + self.kv_cache = acl.util.ptr_to_numpy( + buffer, self.config.past_key_value_shape, 23 # 23:NPY_HALF,NPY_FLOAT16 + ) + data = acl.create_data_buffer(buffer, buffer_size) + _, ret = acl.mdl.add_dataset_buffer(self.input_dataset, data) + check_ret("add_dataset_buffer",ret) + self.inputs.append({"buffer": buffer, "size": buffer_size}) + else: + buffer, ret = acl.rt.malloc(buffer_size, ACL_MEM_MALLOC_HUGE_FIRST) + check_ret("alloc input memory",ret) + data = acl.create_data_buffer(buffer, buffer_size) + _, ret = acl.mdl.add_dataset_buffer(self.input_dataset, data) + check_ret("add_dataset_buffer",ret) + self.inputs.append({"buffer": buffer, "size": buffer_size}) + + self.output_dataset = acl.mdl.create_dataset() + output_size = acl.mdl.get_num_outputs(self.model_desc) + self.outputs = [] + for i in range(output_size): + buffer_size = acl.mdl.get_output_size_by_index(self.model_desc, i) + data_type = acl.mdl.get_output_data_type(self.model_desc, i) + buffer, ret = acl.rt.malloc(buffer_size, ACL_MEM_MALLOC_HUGE_FIRST) + check_ret("alloc output memory",ret) + data = acl.create_data_buffer(buffer, buffer_size) + _, ret = acl.mdl.add_dataset_buffer(self.output_dataset, data) + check_ret("add_dataset_buffer",ret) + buffer_host, ret = acl.rt.malloc_host(buffer_size) + check_ret("alloc output host memory",ret) + self.outputs.append( + { + "buffer": buffer, + "size": buffer_size, + "buffer_host":buffer_host, + 'dtype':dtype2NpType[data_type] + } + ) + + def free_memory(self): + """ + 释放内存 + """ + for item in self.input_data: + ret = acl.rt.free(item["buffer"]) + ret = acl.mdl.destroy_dataset(self.input_dataset) + for item in self.output_data: + ret = acl.rt.free(item["buffer"]) + ret = acl.rt.free_host(item["buffer_host"]) + ret = acl.mdl.destroy_dataset(self.output_dataset) + + def inference(self,data) -> List[np.ndarray]: + """ + 执行推理,同步方式 + Args: + data (_type_): _description_ + + Returns: + List[np.ndarray]: _description_ + """ + start = time.time() + acl.rt.set_context(self.context) + for i in range(len(data)): + if i == 3: + pass + else: + bytes_data = data[i].tobytes() + np_ptr = acl.util.bytes_to_ptr(bytes_data) + ret = acl.rt.memcpy( + self.inputs[i]["buffer"], + self.inputs[i]["size"], + np_ptr, + self.inputs[i]["size"], + ACL_MEMCPY_HOST_TO_DEVICE + ) + check_ret("memcpy", ret) + ret = acl.mdl.execute( + self.model_id, + self.input_dataset, + self.output_dataset + ) + inference_result = [] + for out in self.outputs: + ret = acl.rt.memcpy( + out['buffer_host'], + out["size"], + out["buffer"], + out["size"], + ACL_MEMCPY_DEVICE_TO_HOST + ) + bytes_out = acl.util.ptr_to_bytes(out['buffer_host'], out["size"]) + out_data = np.frombuffer(bytes_out, dtype=out['dtype']) + inference_result.append(out_data) + return inference_result + + def inference_async(self, data, other_args) -> List[np.ndarray]: + """ + 执行推理,异步方式 + Args: + data (_type_): _description_ + other_args (_type_): _description_ + + Returns: + List[np.ndarray]: _description_ + """ + acl.rt.set_context(self.context) + # print(f"wait lock {other_args[1]}",flush=True) + # self.lock.acquire() + # print(f"get lock {other_args[1]}",flush=True) + for i in range(len(data)): + bytes_data = data[i].tobytes() + np_ptr = acl.util.bytes_to_ptr(bytes_data) + ret = acl.rt.memcpy( + 
self.inputs[i]["buffer"], + self.inputs[i]["size"], + np_ptr, + self.inputs[i]["size"], + ACL_MEMCPY_HOST_TO_DEVICE + ) + check_ret("memcpy", ret) + ret = acl.mdl.execute_async( + self.model_id, + self.input_dataset, + self.output_dataset, + self.stream + ) + check_ret("exec_async", ret) + print(f"submit exec task {other_args[1]}") + ret = acl.rt.launch_callback( + self.call_post_process, other_args, 1, self.stream + ) + check_ret("launch callback", ret) + + def _process_callback(self, args_list): + context, timeout = args_list + acl.rt.set_context(context) + while self.callback_interval: + acl.rt.process_report(timeout) + if self.exit_flag: + print("[Callback] exit acl.rt.process_report") + break + + def call_post_process(self,other_args): + print("start callback",flush=True) + time1 = time.time() + inference_result = [] + for out in self.outputs: + ret = acl.rt.memcpy( + out['buffer_host'], + out["size"], + out["buffer"], + out["size"], + ACL_MEMCPY_DEVICE_TO_HOST + ) + bytes_out = acl.util.ptr_to_bytes(out['buffer_host'], out["size"]) + data = np.frombuffer(bytes_out, dtype=out['dtype']) + inference_result.append(data) + # self.lock.release() + # print(f"free lock {other_args[1]}",flush=True) + if not self.callback_func: + return + self.callback_func(inference_result,other_args) + print(f"end callback, use time: {time.time()-time1}") diff --git a/utils/inference.py b/utils/inference.py new file mode 100644 index 0000000..7e45732 --- /dev/null +++ b/utils/inference.py @@ -0,0 +1,195 @@ +import numpy as np +import os +import time +import gc +from transformers import AutoTokenizer +from enum import Enum +from threading import Lock +from utils.session import Session +from config import InferenceConfig + + + +class Inference: + def __init__(self, config:InferenceConfig) -> None: + self.max_input_length = config.max_input_length + self.max_output_length = config.max_output_length + # self.tokenizer=Tokenizer(config.tokenizer) + self.tokenizer = AutoTokenizer.from_pretrained( + config.tokenizer_dir, trust_remote_code=True + ) + self.sampling_method = config.sampling_method + self.sampling_value = config.sampling_value + self.temperature = config.temperature + self.session=Session.fromConfig(config) + # self.prompt=config.prompt + self.kv_cache_length = config.kv_cache_length + self.state: dict = {"code":200,"isEnd":False,"message":""} + self.reset() + self.lock = Lock() + self.first = True + # self.stop_mp = {"[|Human|]":6,"[|AI|]":5,"<|assistant|>":6,"<|user|>":5} + print("init success") + + + def generate_cache(self, prompt: str): + """ + 生成kv-cache + Args: + prompt (str): 提示词 + + Returns: + 返回下一个token与logits + """ + if len(prompt) == 0 : + return + self.first = False + input_ids = np.asarray( + self.tokenizer.encode(prompt), dtype=np.int64 + ).reshape(1,-1) + logits = self.session.run(input_ids)[0] + next_token = self.sample_logits( + logits[0][-1:], + self.sampling_method, + self.sampling_value, + self.temperature + ) + return next_token, logits + + def sample_logits( + self, + logits: np.ndarray, + sampling_method: str = "greedy", + sampling_value: float = None, + temperature: float = 1.0, + ) -> np.ndarray: + """ + 对logits做采样,得到下一个token + Args: + logits (np.ndarray): + sampling_method (str, optional): 采样方法,默认是"greedy",支持top_p, top_k + sampling_value (float, optional): _description_. Defaults to None. + temperature (float, optional): _description_. Defaults to 1.0. 
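+            Note: sampling_value is used as k for "top_k" (keep the k most likely
+                tokens) and as the cumulative-probability threshold for "top_p",
+                as implemented in the branches below.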
+ + Raises: + Exception: _description_ + + Returns: + np.ndarray: _description_ + """ + if temperature == 0 or sampling_method == "greedy": + next_token = np.argmax(logits, axis=-1).astype(np.int64) + + elif sampling_method == "top_k" or sampling_method == "top_p": + assert sampling_value is not None + logits = logits.astype(np.float32) + logits /= temperature + probs = np.exp(logits) / np.sum(np.exp(logits)) + sorted_probs = np.sort(probs)[:, ::-1] + sorted_indices = np.argsort(probs)[:, ::-1] + + if sampling_method == "top_k": + index_of_interest = int(sampling_value) + elif sampling_method == "top_p": + p = sampling_value + cumulative_probs = np.cumsum(sorted_probs, axis=-1) + for index_of_interest, cumulative_prob in enumerate( + cumulative_probs[0] + ): + if cumulative_prob > p: + break + + probs_of_interest = sorted_probs[:, : index_of_interest + 1] + indices_of_interest = sorted_indices[:, : index_of_interest + 1] + probs_of_interest /= np.sum(probs_of_interest) + next_token = np.array( + [np.random.choice(indices_of_interest[0], p=probs_of_interest[0])] + ) + else: + raise Exception(f"Unknown sampling method {sampling_method}") + + return next_token + + def stream_predict( + self, + prompt, + history=None, + system_prompt: str="You are a helpful assistant.", + ): + if history is None: + history = [] + if len(history) == 0: + history = [{"role": "system", "content": system_prompt}] + # print("prompt: ", prompt) + with self.lock: + self.state['isEnd'],self.state['message'] = False,"" + if prompt == "": + return + history.append({"role": "user", "content": prompt}) + # print("history: ", history) + text = self.tokenizer.apply_chat_template( + history, + tokenize=False, + add_generation_prompt=True + ) + input_ids = self.tokenizer( + [text], return_tensors="np" + )["input_ids"].astype(np.int64).reshape(1, -1) + self.first = False + ids_list = [] + text_length = 0 + input_length = input_ids.shape[1] + start = time.time() + first_token_latency = 0 + decode_speed = 0 + max_output_len = self.max_output_length - input_length + for i in range(max_output_len): + logits = self.session.run(input_ids)[0] + input_ids = self.sample_logits(logits[0][-1:], self.sampling_method, self.sampling_value, self.temperature) + input_ids = input_ids.reshape(1, -1) + if i == 0: + first_token_latency = time.time() - start + with self.lock: + # early stop + if input_ids[0] == self.tokenizer.eos_token_id: + self.state['message'],self.state['isEnd'] = self.tokenizer.decode(ids_list),True + break + ids_list.append(input_ids[0].item()) + text_out = self.tokenizer.decode(ids_list) + # stop_word = is_stop_word_or_prefix(text_out, ["[|Human|]", "[|AI|]"]) + self.state['message'] = text_out + new_text = text_out[text_length: ] + # decode_speed = + duration = time.time() - start + decode_speed = len(ids_list) / duration + totol_speed = (input_length + len(ids_list)) / duration + if b"\xef\xbf\xbd" in new_text.encode(): + continue + if len(new_text) > 0: + yield new_text, first_token_latency, decode_speed, totol_speed + text_length = len(text_out) + with self.lock: + self.state['isEnd'] = True + + def reset(self): + self.first = True + self.session.run_times = 0 + self.session.reset() + # self.generate_cache(self.prompt) + + + def getState(self): + with self.lock: + return self.state.copy() + +# def preprocess(text:str) -> str: +# # 将输入转换为指定格式 +# return f"<|user|>\n{text}\n<|assistant|>" +# +# +# def is_stop_word_or_prefix(s: str, stop_words: list) -> int: +# for stop_word in stop_words: +# if s.endswith(stop_word): +# 
return stop_word +# return "" +# \ No newline at end of file diff --git a/utils/kvcache.py b/utils/kvcache.py new file mode 100644 index 0000000..ad77f42 --- /dev/null +++ b/utils/kvcache.py @@ -0,0 +1,289 @@ +import numpy as np +from typing import Optional,Tuple,List +from config import InferenceConfig +# 对KV缓存和输出输出格式进行管理 +class KVCacheManger: + def __init__(self, config: InferenceConfig) -> None: + self.num_key_value_heads = config.num_key_value_heads # head len + self.kv_cache_length = config.kv_cache_length # max_size + self.input_pos = 0 + self.past_kv_size = 0 + self.num_hidden_layers = config.num_hidden_layers # n_layer + self.cache_format = config.cache_format + self.num_key_value_heads = config.num_key_value_heads # head_num + self.per_head_dim = config.per_head_dim # head_dim + self.past_key_value_shape = config.past_key_value_shape + self.real_kv_size = 0 # 真正的kv_cache长度 + if config.dtype == "float16": + self.dtype=np.float16 + elif config.dtype=="float32": + self.dtype=np.float32 + else: + raise Exception("only support float16 and float32, not ", np.dtype) + # self.kv_cache = None + self.kv_cache = np.zeros(self.past_key_value_shape, dtype=self.dtype) + + def create_empty_cache(self): + """ + 创建空的kv_cache + """ + if self.cache_format == "huggingface-tensor": + self.kv_cache = np.zeros(self.past_key_value_shape, dtype=self.dtype) + + def update( + self, + seq_len: int, + new_kv_cache: Tuple[List[np.ndarray],List[np.ndarray]], + scores: Optional[np.ndarray]=None + )->None: + """ + 更新kv_cache,暂未实现,等子类实现 + Args: + seq_len (int): _description_ + newKV (Tuple[List[np.ndarray],List[np.ndarray]]): _description_ + scores (Optional[np.ndarray], optional): _description_. Defaults to None. + """ + pass + + def get_inputs(self, seq_len: int) -> List[np.ndarray]: + """ + 获取指定长度的kv_cache, 顺便生成mask和position_id + Args: + seq_len (int): 待获取的kv-cache长度 + + Returns: + List[np.ndarray]: _description_ + """ + + """ + self.kv_cache shape ( + self.num_hidden_layers, + 2, + 1, + self.num_key_value_heads, + self.kv_cache_length, + self.per_head_dim + ) + """ + cache = self.kv_cache[:, :, :, :, :self.past_kv_size] + mask = np.ones((1,self.past_kv_size + seq_len),dtype=np.int64) + mask[:, self.real_kv_size: self.past_kv_size] = 0 + pos_id =np.arange( + self.input_pos, + self.input_pos + seq_len, + dtype=np.int64 + ).reshape(1,-1) + return cache, mask, pos_id + + def reset(self,num=1): + self.input_pos=0 + self.real_kv_size=0 + if num != 0: + self.create_empty_cache() + + def rollback(self,seq_len): + self.real_kv_size -=seq_len + + + +class BasicKVCache(KVCacheManger): + def __init__(self, cfg: InferenceConfig) -> None: + super().__init__(cfg) + + def update( + self, + seq_len: int, + new_kv_cache: Tuple[List[np.ndarray]], + scores: Optional[np.ndarray] = None + ) -> None: + """ + 更新kv_cache + Args: + seq_len (int): 新kv-cache的长度 + new_kv_cache (Tuple[List[np.ndarray]]): 新的kv-cache + scores (Optional[np.ndarray], optional): _description_. Defaults to None. 
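+            Note: new_kv_cache is reshaped to past_key_value_shape (with the length
+                axis set to seq_len) and copied into the cache starting at past_kv_size;
+                past_kv_size, input_pos and real_kv_size then advance by seq_len.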
+ + Raises: + RuntimeError: _description_ + """ + """ + self.kv_cache shape ( + self.num_hidden_layers, + 2, + 1, + self.num_key_value_heads, + self.kv_cache_length, + self.per_head_dim + ) + """ + if seq_len + self.past_kv_size > self.kv_cache_length: + raise RuntimeError("超出KV缓存长度限制") + if self.format=="huggingface-tensor": + temp_shape = list(self.past_key_value_shape) + temp_shape[-2] = -1 + new_kv_cache = new_kv_cache.reshape(temp_shape) + self.kv_cache[:, :, :, :, self.past_kv_size:self.past_kv_size + seq_len] = \ + new_kv_cache[:, :, :, :, 0:seq_len] + self.past_kv_size += seq_len + self.input_pos += seq_len + self.real_kv_size += seq_len + + def reset(self): + self.past_kv_size=0 + return super().reset() + +class FixSizeKVCache(KVCacheManger): + def __init__(self, cfg: InferenceConfig) -> None: + super().__init__(cfg) + # kv_cache的长度和max_output_length的长度一样 + self.past_kv_size=self.kv_cache_length + + def update( + self, + seq_len: int, + new_kv_cache: Tuple[List[np.ndarray]], + scores: Optional[np.ndarray] = None + ) -> None: + """ + self.kv_cache shape ( + self.num_hidden_layers, + 2, + 1, + self.num_key_value_heads, + self.kv_cache_length, + self.per_head_dim + ) + """ + self.input_pos = self.real_kv_size + seq_len + if seq_len + self.real_kv_size > self.kv_cache_length: + seq_len = self.kv_cache_length - self.real_kv_size + if seq_len <= 0: + return + if self.cache_format=="huggingface-tensor": + temp_shape = list(self.past_key_value_shape) + temp_shape[-2] = -1 + new_kv_cache = new_kv_cache.reshape(temp_shape) + self.kv_cache[:, :, :, :, self.real_kv_size: self.real_kv_size + seq_len] = \ + new_kv_cache[:, :, :, :, 0: seq_len] + self.real_kv_size += seq_len + +class FixSizeStreamLLM(KVCacheManger): + def __init__(self, cfg:InferenceConfig) -> None: + super().__init__(cfg) + self.past_len = 0 + self.past_kv_size=self.kv_cache_length + + def update( + self, + seq_len:int, + new_kv_cache: Tuple[List[np.ndarray],List[np.ndarray]], + score:Optional[np.ndarray] = None + ): + self.input_pos+=seq_len + while self.past_len+ seq_len > self.kv_cache_length: + self.update_part(new_kv_cache, self.past_len, self.kv_cache_length - self.past_len) + seq_len -= (self.kv_cache_length-self.past_len) + self.past_len= self.head_len + self.update_part(new_kv_cache, self.past_len, seq_len) + self.past_len+= seq_len + self.real_kv_size = max(self.past_len, self.real_kv_size) + + def update_part( + self, + new_kv_cache:Tuple[List[np.ndarray],List[np.ndarray]], + begin:int, + len:int + ): + """ + 局部更新kv-cache + Args: + new_kv_cache (Tuple[List[np.ndarray],List[np.ndarray]]): 待更新的新的kv-chace + begin (int): 更新起始位置 + len (int): 更新长度 + """ + + if self.cache_format == 'huggingface-tensor': #[n_layer,2,batch_size,head_num,len,head_dim] + self.kv_cache[:, :, :, :, self.past_len: self.past_len + len] = \ + new_kv_cache[:, :, :, :, begin: begin + len] + if self.cache_format =='seq_nhead_headdim': # [batch, n_layers, seq_len, n_heads, head_dim] + self.kv_cache[0][:, :, self.past_len: self.past_len + len] = \ + new_kv_cache[0][:, :, begin : begin+len] + self.kv_cache[1][:, :, self.past_len: self.past_len + len] = \ + new_kv_cache[1][:, :, begin: begin + len] + elif self.cache_format == 'nhead_seq_headdim': # [batch, n_layers, n_heads, seq_len, head_dim] + self.kv_cache[0][:, :, :, self.past_len: self.past_len + len] = \ + new_kv_cache[0][:, :, :, begin :begin + len] + self.kv_cache[1][:,:,:,self.past_len: self.past_len + len] = \ + new_kv_cache[1][:, :, :, begin: begin + len] + elif self.format=='huggingface-list': 
# (n_layer,2) * [batch_size,head_num,len,head_dim] + for i in range(self.num_hidden_layers): + self.kv_cache[i][0][:, :, self.past_len: self.past_len + len,:] = \ + new_kv_cache[i][0][:, :, begin: begin + len,:] + self.kv_cache[i][1][:, :, self.past_len: self.past_len + len,:] = \ + new_kv_cache[i][1][:, :, begin:begin + len,:] + + def reset(self): + self.past_len = 0 + self.real_kv_size = 0 + return super().reset() + +# 未完成 +# TODO: +class FixSizeH2O(KVCacheManger): + def __init__(self,cfg:InferenceConfig) -> None: + super().__init__(cfg) + self.scores = np.zeros((self.n_layer,1,self.head_num,self.past_kv_size),dtype=self.dtype) + + def update( + self, + new_kv_cache: Tuple[List[np.ndarray],List[np.ndarray]], + score: Optional[np.ndarray] = None + ): + """ + self.kv_cache shape ( + self.num_hidden_layers, + 2, + 1, + self.num_key_value_heads, + self.kv_cache_length, + self.per_head_dim + ) + """ + # score [n_layer,batch,nheader,input_len,all_len] + seq_len = new_kv_cache[0][0].shape[-2] + if self.real_kv_size + seq_len < self.past_kv_size: + self.kv_cache[:, :, :, :, self.real_kv_size: self.real_kv_size + seq_len,:] = new_kv_cache + self.real_kv_size += seq_len + self.scores[:, :, :, :self.real_kv_size] = \ + self.scores[:,:,:,:self.real_kv_size] * 0.5 + score[:, :, :, :self.real_kv_size] + score = score.sum(-1) + if self.format == 'huggingface-tensor': #[n_layer,2,batch_size,head_num,len,head_dim] + # self.kv_cache[:,:,:,:,self.p:self.p+len,:] = new_kv_cache[:,:,:,:,begin:begin+len,:] + for i in range(self.n_layer): + idx = np.argpartition(score[i],-seq_len) + self.kv_cache[i,:,idx,:] = new_kv_cache[i,:,idx,:] + self.scores[i,idx] = score[i,idx] + + def update_one( + self, + new_kv_cache:Tuple[List[np.ndarray],List[np.ndarray]], + score:Optional[np.ndarray], + ): + if self.real_kv_size < self.past_kv_size: + self.kv_cache[:, :, :, :, self.real_kv_size, :] = new_kv_cache + self.real_kv_size += 1 + self.scores[:, :, :, :self.real_kv_size] = \ + self.scores[:, :, :, :self.real_kv_size] * 0.5 + score[:, :, :, :self.real_kv_size] + + +def create_kv_cache(config: InferenceConfig) -> KVCacheManger: + if config.kvcache_method == "basic": + return BasicKVCache(config) + elif config.kvcache_method == "fixsize": + return FixSizeKVCache(config) + elif config.kvcache_method == 'streamllm': + return FixSizeStreamLLM(config) + elif config.kvcache_method == 'H2O': + return FixSizeH2O(config) + else: + return None diff --git a/utils/session.py b/utils/session.py new file mode 100644 index 0000000..5b10c6b --- /dev/null +++ b/utils/session.py @@ -0,0 +1,102 @@ +from config import InferenceConfig +from utils.kvcache import create_kv_cache +import numpy as np +from typing import List +import time +import sys +from utils.engine import ACLModel, init_resource, destroy_resource + +class Session: + def __init__(self, config: InferenceConfig) -> None: + self.kv_cache = create_kv_cache(config) + self.run_times = 0 + def run(self,input_ids:np.ndarray): + pass + + @staticmethod + def fromConfig(config:InferenceConfig) -> 'Session': + if config.session_type == "onnx": + return OnnxSession(config) + elif config.session_type=='acl': + return AclSession(config) + else: + return None + + def reset(self): + if self.run_times == 0: + self.kv_cache.reset(0) + else: + self.kv_cache.reset() + + def rollback(self,seq_len): + self.kv_cache.rollback(seq_len) + +class OnnxSession(Session): + def __init__(self,config:InferenceConfig)->None: + super().__init__(config) + import onnxruntime + options = onnxruntime.SessionOptions() + 
self.llm_session = onnxruntime.InferenceSession( + config.onnx_model_path, + sess_options=options, + providers=[ + "CPUExecutionProvider", + ], + ) + + def run(self, input_ids:np.ndarray): + seq_len=input_ids.shape[-1] + cache, mask, pos_ids = self.kv_cache.get_inputs(seq_len) + result = self.llm_session.run(None,{ + "input_ids": input_ids, + "attention_mask":mask, + "past_key_values": cache, + "position_ids": pos_ids, + }) + self.kv_cache.update(seq_len,result[1]) + return result + +class AclSession(Session): + context = None + def __init__(self, config:InferenceConfig): + super().__init__(config) + self.device_id = config.device_id + self.context = init_resource(self.device_id) + self.model = ACLModel(config, self.context) + self.input_ids = np.zeros((1,16),dtype=np.int64) + self.kv_cache.kv_cache = self.model.kv_cache + + def __del__(self): + destroy_resource(self.device_id, self.context) + def run(self, input_ids: np.ndarray): + seq_len = input_ids.shape[-1] + logits = None + for i in range(seq_len): + logits = self.run_one(input_ids[:,i]) + return [logits] + + def run_all_logits(self, input_ids: np.ndarray): + seq_len, i = input_ids.shape[-1], 0 + logits = [] + while i < seq_len: + end = i + 16 if i+16 < seq_len else seq_len + cache,mask,pos_ids = self.kv_cache.get_inputs(16) + self.input_ids[0:end-i] = input_ids[i:end] + result:List[np.ndarray] = self.model.inference([self.input_ids,pos_ids,mask,cache]) + self.kv_cache.update(end-i,result[1]) + logits.append(result[0][0:end-i].reshape(1,-1)) + return [np.concatenate(logits).reshape(1,1,-1)] + + def run_one(self, input_ids: np.ndarray): + self.run_times += 1 + cache, mask, pos_ids = self.kv_cache.get_inputs(1) + result:List[np.ndarray] = self.model.inference( + [input_ids, pos_ids, mask, cache] + ) + # new_kv_cache = result[1] + # print(" == Debug == ") + # print("new_kv_cache: shape", new_kv_cache.shape) + # print("new_kv_cache: mean: ", new_kv_cache.astype(np.float32).mean().item()) + # print("new_kv_cache: max: ", new_kv_cache.astype(np.float32).max().item()) + self.kv_cache.update(1,result[1]) + return result[0].reshape(1,1,-1)
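+
+# Minimal usage sketch (illustration only, mirroring what utils/inference.py does):
+#   config = InferenceConfig(hf_model_dir=..., om_model_path=..., onnx_model_path=..., session_type="onnx")
+#   session = Session.fromConfig(config)   # OnnxSession or AclSession, chosen by session_type
+#   logits = session.run(input_ids)[0]     # run() also writes the returned key/values into the kv-cache
+#   session.reset()                        # clears the kv-cache between conversations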