From 12139610ceb848798e607a6ad0a7bd11b173c7b7 Mon Sep 17 00:00:00 2001
From: Tlntin
Date: Thu, 31 Oct 2024 15:59:51 +0800
Subject: [PATCH] add compare function, use 'mixed_float16' to replace
 'must_keep_origin_dtype'

---
 .gitignore               |   6 +-
 README.md                |  27 ++++-
 export/compare.py        | 233 +++++++++++++++++++++++++++++++++++++++
 export/modeling_qwen2.py |  33 +++---
 export/onnx2om.py        |  21 ++--
 ops_info.json            |  10 ++
 utils/inference.py       |   6 +-
 utils/session.py         |   2 +-
 8 files changed, 306 insertions(+), 32 deletions(-)
 create mode 100644 export/compare.py
 create mode 100644 ops_info.json

diff --git a/.gitignore b/.gitignore
index d3b082c..33d03e5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+.idea
+.vscode
 download/
 output/
 inference/
@@ -5,6 +7,6 @@ kernel_meta/
 */__pycache__/
 __pycache__/
 ./*/__pycache__/
-.idea
 export/*.json
-*.json
\ No newline at end of file
+fusion_result.json
+result/
\ No newline at end of file
diff --git a/README.md b/README.md
index e8e8339..62506bc 100644
--- a/README.md
+++ b/README.md
@@ -102,7 +102,7 @@
     --max_output_length=2048
   ```
 
-4. Modify the onnx graph. The exported Trilu and Cast operators currently have issues that the atc command cannot parse, so the graph structure needs to be adjusted.
+4. Modify the onnx graph. The exported Trilu operator currently has issues that the atc command cannot parse, so the graph structure needs to be adjusted.
   ```bash
   python3 export/change_node.py \
     --input_model_path="./output/onnx/qwen2_1.5b_chat.onnx" \
@@ -112,7 +112,7 @@
 5. Convert onnx to an om model: use the atc command to convert the modified onnx into om. **Note that om_model_path here has no `.om` suffix.**
   - The conversion may print some warnings or sub-graph fusion errors; as long as the final result reports `success`, it is fine.
   - kv_cache_length must stay consistent with the length used when exporting onnx in step 1.
-  - `--max_prefill_length` is the maximum length processed in one pass during the prefill stage. A larger value lowers the first-token latency but also increases the onnx-to-om conversion time. It should normally be a power of two, e.g. 2, 4, 8, 16, and at inference time a suitable prefill length is matched recursively, e.g. an input of 12 is matched as [8, 4]. The current default is 4; setting it to 1 disables dynamic-shape inference.
+  - `--max_prefill_length` is the maximum length processed in one pass during the prefill stage. A larger value lowers the first-token latency but also increases the onnx-to-om conversion time. It should normally be a power of two, e.g. 2, 4, 8, 16, and at inference time a suitable prefill length is matched recursively, e.g. an input of 12 is matched as [8, 4]. The current default is 4; setting it to 1 disables dynamic-shape inference. **Note: with dynamic shape enabled, the model file grows by 50%-100% and inference memory usage grows accordingly; if memory is a concern, it is recommended to disable dynamic shape.**
   - The script auto-detects your NPU type; to set it manually, add `--soc_version=xxxx`, e.g. `--soc_version=Ascend310B1`.
   - `--kv_cache_length` must match the `--kv_cache_length` specified earlier when exporting onnx, otherwise the conversion will most likely fail.
   - `--cpu_thread` is the number of CPU threads used when converting onnx to om. The default is 1 compilation thread; if you have plenty of memory (each thread holds its own copy, so compilation is memory hungry), you can raise it.
@@ -173,6 +173,29 @@
 - functional_call demo (using qwen2-1.5b-instruct) ![](./image/qwen2-1.5b-instruct-functional-call.jpg)
 
 
+### (Optional) Compare onnx and om layer-level outputs
+- If the compiled om file produces abnormal inference output (for example origin or fp32 precision is fine but fp16 is not) while the onnx output is normal, we need to locate the faulty layer. A tool is used to dump the per-layer inputs and outputs of both the onnx and the om model and see at which layer values start to overflow or diverge significantly.
+- We can use the official Ascend msit tool; its open-source homepage is here: [link](https://gitee.com/ascend/msit)
+- Install the msit tool following the official guide: [link](https://gitee.com/ascend/msit/tree/master/msit/docs/install)
+  ```bash
+  pip install msit
+  msit install compare
+  ```
+- After installation, run the comparison. It is recommended to compare against a static om graph, i.e. set max_prefill_length=1 when converting onnx to om.
+- The smaller the model, the better for comparison; the Qwen2-0.5B-Instruct model is recommended to save time and simplify analysis.
+- The comparison procedure is documented on the official site: [link](https://gitee.com/ascend/msit/tree/master/msit/docs/debug/compare#/ascend/msit/blob/master/msit/docs/install/README.md). It has been wrapped into a Python script; here is an example:
+  ```bash
+  python3 export/compare.py \
+    --hf_model_dir="./download/Qwen2-0.5B-Instruct" \
+    --onnx_model_path="./output/onnx2/qwen2_0.5b_chat.onnx" \
+    --om_model_path="./output/model/qwen2_0.5b_chat.om" \
+    --kv_cache_length=2048 \
+    --cpu_thread=1 \
+    --dtype="float16" \
+    --max_prefill_length=1
+  ```
+- For interpreting the comparison results, see the official documentation: [link](https://gitee.com/ascend/msit/blob/master/msit/examples/cli/debug/compare/result_analyse/README.md)
+
 ### Current features
 - [x] Export onnx and om models
 - [x] Model inference; onnx inference is supported (CPU only).
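The recursive prefill matching mentioned in the README change above (an input of 12 matched as [8, 4]) amounts to greedily splitting the input length into descending powers of two. Below is a minimal sketch of that idea for illustration only; the function name is hypothetical and this is not the repository's actual implementation (the real logic lives around `prefill_log2_list` in `utils/session.py`).

```python
# Hypothetical sketch of the greedy power-of-two prefill split described above.
def split_prefill(length: int, max_prefill_length: int = 8):
    """Split `length` into descending powers of two, each <= max_prefill_length."""
    chunks = []
    remaining = length
    while remaining > 0:
        # largest power of two that fits both the remaining tokens and the cap
        step = 1 << min(remaining.bit_length() - 1, max_prefill_length.bit_length() - 1)
        chunks.append(step)
        remaining -= step
    return chunks


print(split_prefill(12))  # -> [8, 4], matching the README example above
```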
diff --git a/export/compare.py b/export/compare.py
new file mode 100644
index 0000000..ee95a49
--- /dev/null
+++ b/export/compare.py
@@ -0,0 +1,233 @@
+import os
+import time
+import subprocess
+import numpy as np
+import onnxruntime
+import argparse
+from transformers.models.qwen2 import Qwen2Tokenizer, Qwen2Config
+
+
+now_dir = os.path.dirname(os.path.abspath(__file__))
+project_dir = os.path.dirname(now_dir)
+result_output_dir = os.path.join(project_dir, "result")
+input_data_dir = os.path.join(project_dir, "output", "input_data")
+if not os.path.exists(result_output_dir):
+    os.mkdir(result_output_dir)
+if not os.path.exists(input_data_dir):
+    os.mkdir(input_data_dir)
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--dtype",
+    type=str,
+    help="float16 or float32",
+    choices=["float16", "float32"],
+    default="float32",
+)
+parser.add_argument(
+    '--hf_model_dir',
+    type=str,
+    help="model and tokenizer path, only support huggingface model",
+    default=os.path.join(project_dir, "download", "Qwen2-1.5B-Instruct")
+)
+parser.add_argument(
+    "--onnx_model_path",
+    help="output onnx path",
+    type=str,
+    default=os.path.join(project_dir, "output", "onnx", "qwen2_1.5b_chat.onnx")
+)
+parser.add_argument(
+    "--om_model_path",
+    help="om model path",
+    type=str,
+    default=os.path.join(project_dir, "output", "model", "qwen2_1.5b_chat.om")
+)
+parser.add_argument(
+    "--kv_cache_length",
+    help="kv-cache length",
+    type=int,
+    default=2048,
+)
+parser.add_argument(
+    "--max_batch",
+    help="max batch",
+    type=int,
+    default=1,
+)
+parser.add_argument(
+    "--cpu_thread",
+    type=int,
+    help="num of cpu threads used when converting onnx to om",
+    default=1,
+)
+parser.add_argument(
+    "--max_prefill_length",
+    help="max prefill length in first inference. "
+         "Note that max_prefill_length + max_output_length <= kv_cache_length. "
+         "The number must be a power of two, like 1, 2, 4, 8, 16, 32, 64, 128, 256... "
+         "Note! The higher this number, the longer it will take to compile.",
+    type=int,
+    default=8
+)
+args = parser.parse_args()
+
+if args.dtype == "float16":
+    np_dtype = np.float16
+elif args.dtype == "float32":
+    np_dtype = np.float32
+else:
+    raise Exception("unsupported dtype, only float16/float32 are supported")
+
+
+def create_kv_cache(config: Qwen2Config, kv_cache_length=args.kv_cache_length):
+    return np.zeros(
+        [
+            1,
+            kv_cache_length,
+            config.num_hidden_layers * 2 * config.num_key_value_heads,
+            config.hidden_size // config.num_attention_heads
+        ],
+        dtype=np_dtype
+    )
+
+
+def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size: int = args.kv_cache_length):
+    """
+    Slice the kv-cache to the requested length and build the matching attention mask and position ids.
+    Args:
+        kv_cache: full kv-cache array
+        seq_len (int): length of the new input sequence
+        real_kv_size: number of valid entries already in the kv-cache
+        input_pos: position of the current token
+        past_kv_size: total kv-cache length
+
+    Returns:
+        Tuple[np.ndarray, np.ndarray, np.ndarray]: kv-cache slice, attention mask, position ids
+    """
+    # kv_cache shape:
+    # (1, kv_cache_length, num_hidden_layers * 2 * num_key_value_heads, per_head_dim)
+    cache = kv_cache[:, :past_kv_size]
+    mask = np.ones((1, past_kv_size + seq_len), dtype=np.int64)
+    mask[:, real_kv_size: past_kv_size] = 0
+    pos_id = np.arange(
+        input_pos,
+        input_pos + seq_len,
+        dtype=np.int64
+    ).reshape(1, -1)
+    return cache, mask, pos_id
+
+
+tokenizer = Qwen2Tokenizer.from_pretrained(args.hf_model_dir)
+model_config = Qwen2Config.from_pretrained(args.hf_model_dir)
+prompt = "你好"
+system_prompt: str = "You are a helpful assistant."
+history = []
+if len(history) == 0:
+    history = [{"role": "system", "content": system_prompt}]
+history.append({"role": "user", "content": prompt})
+print("history: ", history)
+text = tokenizer.apply_chat_template(
+    history,
+    tokenize=False,
+    add_generation_prompt=True
+)
+print("raw_text", text)
+# keep only the first token so the inputs match a static graph built with max_prefill_length=1
+input_ids = tokenizer(
+    [text], return_tensors="np"
+)["input_ids"].astype(np.int64)[:, :1]
+print("input_ids", input_ids)
+
+# options = onnxruntime.SessionOptions()
+# options.intra_op_num_threads = 4
+# options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
+# options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+#
+# llm_session = onnxruntime.InferenceSession(
+#     args.onnx_model_path,
+#     sess_options=options,
+#     providers=[
+#         "CPUExecutionProvider",
+#     ],
+# )
+
+seq_len = input_ids.shape[-1]
+kv_cache1 = create_kv_cache(model_config)
+now_kv_cache, attn_mask, position_ids = get_inputs(kv_cache1, 1)
+print("now_kv_cache shape: ", now_kv_cache.shape)
+print("attention_mask shape: ", attn_mask.shape)
+print("position_ids shape: ", position_ids.shape)
+# save input data
+# input_ids
+input_ids_path = os.path.join(input_data_dir, "input_ids.npy")
+np.save(input_ids_path, input_ids)
+# attention_mask
+attention_mask_path = os.path.join(input_data_dir, "attention_mask.npy")
+np.save(attention_mask_path, attn_mask)
+# position_ids
+position_ids_path = os.path.join(input_data_dir, "position_ids.npy")
+np.save(position_ids_path, position_ids)
+# past_key_values
+past_key_values_path = os.path.join(input_data_dir, "past_key_values.npy")
+np.save(past_key_values_path, now_kv_cache)
+input_path_list = [input_ids_path, attention_mask_path, position_ids_path, past_key_values_path]
+
+max_batch = args.max_batch
+max_prefill_length = args.max_prefill_length
+kv_cache_length = args.kv_cache_length
+model_config = Qwen2Config.from_pretrained(args.hf_model_dir)
+num_hidden_layers = model_config.num_hidden_layers
+num_key_value_heads = model_config.num_key_value_heads
+hidden_size = model_config.hidden_size
+num_attention_heads = model_config.num_attention_heads
+per_head_dim = hidden_size // num_attention_heads
+
+input_ids_shape = [
+    str(max_batch),
+    str(max_prefill_length)
+]
+attention_mask_shape = [
+    str(max_batch),
+    str(max_prefill_length + kv_cache_length)
+]
+position_ids_shape = [
+    str(max_batch),
+    str(max_prefill_length)
+]
+past_key_values_shape = [
+    str(max_batch),
+    str(kv_cache_length),
+    str(num_hidden_layers * 2 * num_key_value_heads),
+    str(per_head_dim)
+]
+
+
+command_lines = [
+    "msit debug compare",
+    "-gm {}".format(args.onnx_model_path),
+    "-om {}".format(args.om_model_path),
+    "-c /usr/local/Ascend/ascend-toolkit/latest",
+    # '--input \"{}\"'.format(",".join(input_path_list)),
+    '--input-shape \"input_ids:{};attention_mask:{};position_ids:{};past_key_values:{}\"'.format(
+        ",".join(input_ids_shape),
+        ",".join(attention_mask_shape),
+        ",".join(position_ids_shape),
+        ",".join(past_key_values_shape)
+    ),
+    "-o {}".format(result_output_dir),
+    "--advisor"
+]
+print("============ run command ==============")
+print(" \\\r\n ".join(command_lines))
+print("=======================================")
+subprocess.run(
+    " \\\n ".join(command_lines),
+    shell=True,
+    check=True,
+)
\ No newline at end of file
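Because `export/compare.py` above saves its four inputs as `.npy` files, the same data can be used to sanity-check the onnx model with onnxruntime before launching `msit debug compare`. A minimal sketch, assuming the default paths written by the script and the input names used in its `--input-shape` string; it is not part of the patch:

```python
import os
import numpy as np
import onnxruntime

# Assumed defaults, matching what export/compare.py writes out.
input_data_dir = os.path.join("output", "input_data")
onnx_model_path = os.path.join("output", "onnx", "qwen2_1.5b_chat.onnx")

session = onnxruntime.InferenceSession(
    onnx_model_path, providers=["CPUExecutionProvider"]
)
feeds = {
    name: np.load(os.path.join(input_data_dir, name + ".npy"))
    for name in ["input_ids", "attention_mask", "position_ids", "past_key_values"]
}
# Passing None fetches every graph output; handy for spotting NaN/inf before the om comparison.
for output, meta in zip(session.run(None, feeds), session.get_outputs()):
    print(meta.name, output.shape, np.isnan(output).any(), np.isinf(output).any())
```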
diff --git a/export/modeling_qwen2.py b/export/modeling_qwen2.py
index fcb6de9..c3a6d3e 100644
--- a/export/modeling_qwen2.py
+++ b/export/modeling_qwen2.py
@@ -265,16 +265,16 @@ def forward(
         key_states = key_states.view(bsz, q_len,
                                      self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-            kv_seq_len += past_key_value.shape[1]
+        kv_seq_len = key_states.shape[-2] + past_key_value.shape[2]
+        # if past_key_value is not None:
+        #     if self.layer_idx is None:
+        #         raise ValueError(
+        #             f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+        #             "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+        #             "with a layer index."
+        #         )
+        #     # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        #     kv_seq_len += past_key_value.shape[2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
         output_cache = (key_states, value_states)
@@ -756,8 +756,8 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
                 "unexpected results may be encountered."
             )
         # self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
-        self.self_attn = Qwen2SdpaAttention(config, layer_idx)
-        # self.self_attn = Qwen2Attention(config, layer_idx)
+        # self.self_attn = Qwen2SdpaAttention(config, layer_idx)
+        self.self_attn = Qwen2Attention(config, layer_idx)
         self.mlp = Qwen2MLP(config)
         self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1109,14 +1109,15 @@ def forward(
                 attention_mask,
             )
         # === if use Qwen2Attention ===
-        # dtype = past_key_values.dtype
-        # device = input_ids.device
-        # attention_mask = torch.zeros_like(full_attention_mask, dtype=dtype).to(device)
+        dtype = past_key_values.dtype
+        device = input_ids.device
+        attention_mask = torch.zeros_like(full_attention_mask, dtype=dtype).to(device)
         # attention_mask.masked_fill_(full_attention_mask, torch.finfo(dtype).min)
+        attention_mask.masked_fill_(full_attention_mask, -10000.0)
 
         # == if use Qwen2SdpaAttention ===
         # copy from chatglm3-6b
-        attention_mask = ~full_attention_mask
+        # attention_mask = ~full_attention_mask
 
         hidden_states = inputs_embeds
 
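The change above switches the decoder layers from `Qwen2SdpaAttention` (boolean mask, `~full_attention_mask`) to `Qwen2Attention` with an additive float mask, and fills blocked positions with `-10000.0` rather than `torch.finfo(dtype).min`, presumably to avoid overflow once the om graph runs in float16. The following is a standalone illustration of the two mask conventions with arbitrary shapes; it is not the repository's code:

```python
import torch

# True means "blocked", as full_attention_mask does in the modeling code above.
full_attention_mask = torch.tensor([[[False, True,  True],
                                     [False, False, True],
                                     [False, False, False]]])

# Boolean convention (SDPA): True means "may attend".
sdpa_mask = ~full_attention_mask

# Additive convention (Qwen2Attention): the mask is added to the attention scores.
dtype = torch.float16
additive_mask = torch.zeros_like(full_attention_mask, dtype=dtype)
# -10000.0 is comfortably representable in float16, while torch.finfo(torch.float16).min
# (-65504) can overflow to -inf after further additions and yield NaNs in the softmax.
additive_mask.masked_fill_(full_attention_mask, -10000.0)

scores = torch.randn(1, 3, 3, dtype=dtype)
# upcast for the softmax in this CPU demo
probs = torch.softmax((scores + additive_mask).float(), dim=-1)
print(probs)
```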
diff --git a/export/onnx2om.py b/export/onnx2om.py
index b151e66..6bb2893 100644
--- a/export/onnx2om.py
+++ b/export/onnx2om.py
@@ -137,6 +137,12 @@ def get_soc_version():
     f"1~{max_batch}" if max_batch > 1 else "1",
     "-1" if max_prefill_length > 1 else "1"
 ]
+past_key_values_shape = [
+    f"1~{max_batch}" if max_batch > 1 else "1",
+    "-1" if max_prefill_length > 1 else kv_cache_length,
+    num_hidden_layers * 2 * num_key_value_heads,
+    per_head_dim
+]
 dynamic_dims = []
 for dynamic_kv_cache_length in [
     kv_cache_length // 2,
@@ -150,12 +156,7 @@ def get_soc_version():
         str(dynamic_kv_cache_length),  # past_key_values
     ]
     dynamic_dims.append(",".join(new_dynamic_dim))
-past_key_values_shape = [
-    f"1~{max_batch}" if max_batch > 1 else "1",
-    "-1" if max_prefill_length > 1 else kv_cache_length,
-    num_hidden_layers * 2 * num_key_value_heads,
-    per_head_dim
-]
+
 past_key_values_shape = [str(x) for x in past_key_values_shape]
 if args.soc_version == "auto":
     print("[INFO] soc_version is `auto`, will auto detect soc version")
@@ -172,11 +173,15 @@ def get_soc_version():
     "export MAX_COMPILE_CORE_NUMBER={} &&".format(args.cpu_thread),
     "atc",
     "--framework=5",
-    "--host_env_cpu=aarch64",
+    # "--host_env_cpu=aarch64",
     '--model="{}"'.format(args.onnx_model_path),
     '--output="{}"'.format(args.om_model_path),
     "--soc_version={}".format(soc_version),
-    "--precision_mode=must_keep_origin_dtype",
+    # "--precision_mode=must_keep_origin_dtype",
+    # "--precision_mode_v2=origin",
+    "--precision_mode_v2=mixed_float16",
+    "--modify_mixlist={}".format(os.path.join(project_dir, "ops_info.json")),
+    # "--op_select_implmode=high_precision_for_all",
     "--input_format=ND",
     '--input_shape="input_ids:{};attention_mask:{};position_ids:{};past_key_values:{}"'.format(
         ",".join(input_ids_shape),
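The new `--modify_mixlist` flag points atc at `ops_info.json` (added below); under `mixed_float16`, operator types on the black-list are kept out of the float16 down-conversion while the rest of the graph may be cast down. Below is a small hypothetical helper for maintaining that file; it is not part of the patch:

```python
import json
import os


def add_to_blacklist(mixlist_path: str, op_types):
    """Add operator types to the mixlist black-list so atc keeps them in higher precision."""
    if os.path.exists(mixlist_path):
        with open(mixlist_path, "r", encoding="utf-8") as f:
            mixlist = json.load(f)
    else:
        mixlist = {
            "black-list": {"to-remove": [], "to-add": []},
            "white-list": {"to-remove": [], "to-add": []},
        }
    for op_type in op_types:
        if op_type not in mixlist["black-list"]["to-add"]:
            mixlist["black-list"]["to-add"].append(op_type)
    with open(mixlist_path, "w", encoding="utf-8") as f:
        json.dump(mixlist, f, indent=4)


# "Square" mirrors the entry shipped in ops_info.json below.
add_to_blacklist("ops_info.json", ["Square"])
```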
diff --git a/ops_info.json b/ops_info.json
new file mode 100644
index 0000000..f63f8b6
--- /dev/null
+++ b/ops_info.json
@@ -0,0 +1,10 @@
+{
+    "black-list": {
+        "to-remove": [],
+        "to-add": ["Square"]
+    },
+    "white-list": {
+        "to-remove": [],
+        "to-add": []
+    }
+}
\ No newline at end of file
diff --git a/utils/inference.py b/utils/inference.py
index 7f4ec30..2fb0cb8 100644
--- a/utils/inference.py
+++ b/utils/inference.py
@@ -127,7 +127,7 @@ def stream_predict(
         history=None,
         sampling_config: dict = {},
         system_prompt: str = "You are a helpful assistant.",
-        max_new_tokens: int = 512,
+        max_new_tokens: int = 1024,
         do_speed_test: bool = False,
         show_progress: bool = False,
     ):
@@ -234,7 +234,7 @@ def predict(
         history=None,
        sampling_config: dict = {},
         system_prompt: str="You are a helpful assistant.",
-        max_new_tokens: int = 512,
+        max_new_tokens: int = 1024,
         show_progress: bool = False,
     ):
         if history is None:
@@ -322,7 +322,7 @@ def generate(
         self,
         input_ids,
         sampling_config: dict = {},
-        max_new_tokens: int = 512,
+        max_new_tokens: int = 1024,
         show_progress: bool = False,
     ):
         sampling_value = sampling_config.get("sampling_value", self.sampling_value)
diff --git a/utils/session.py b/utils/session.py
index d5bb0b3..928e2e3 100644
--- a/utils/session.py
+++ b/utils/session.py
@@ -155,7 +155,7 @@ def __init__(self, config:InferenceConfig):
         self.prefill_log2_list = [2**index for index in self.prefill_log2_list]
 
     def reset(self):
-        self.model.reset();
+        self.model.reset()
 
     def __del__(self):
         from utils.engine import destroy_resource