
Commit

add compare function, use 'mixed_float16' to replace 'must_keep_origin_dtype'
Tlntin committed Oct 31, 2024
1 parent 776b2e0 commit 1213961
Showing 8 changed files with 306 additions and 32 deletions.
6 changes: 4 additions & 2 deletions .gitignore
@@ -1,10 +1,12 @@
.idea
.vscode
download/
output/
inference/
kernel_meta/
*/__pycache__/
__pycache__/
./*/__pycache__/
.idea
export/*.json
*.json
fusion_result.json
result/
27 changes: 25 additions & 2 deletions README.md
@@ -102,7 +102,7 @@
--max_output_length=2048
```

4. Modify the ONNX graph: the exported Trilu and Cast operators currently have issues that the atc command cannot recognize, so the graph structure needs to be adjusted.
4. Modify the ONNX graph: the exported Trilu operator currently has an issue that the atc command cannot recognize, so the graph structure needs to be adjusted.
```bash
python3 export/change_node.py \
--input_model_path="./output/onnx/qwen2_1.5b_chat.onnx" \
@@ -112,7 +112,7 @@
5. Convert the ONNX model to an OM model: compile the modified ONNX into OM with the atc command. **Note that om_model_path here does not include the `.om` suffix.**
- Some warnings, or even subgraph-fusion errors, may appear during the run; as long as the final result reports `success`, everything is fine.
- kv_cache_length must match the length used when exporting ONNX in the first step.
- `--max_prefill_length` is the maximum length that can be processed in a single pass during the prefill stage. A larger value reduces first-token latency, but also increases the ONNX-to-OM conversion time. It should normally be a power of two, e.g. 2, 4, 8, 16, and so on; at inference time a suitable prefill length is matched recursively, e.g. an input of 12 is matched as [8, 4]. The current default is 4; if set to 1, dynamic-shape inference is disabled.
- `--max_prefill_length` is the maximum length that can be processed in a single pass during the prefill stage. A larger value reduces first-token latency, but also increases the ONNX-to-OM conversion time. It should normally be a power of two, e.g. 2, 4, 8, 16, and so on; at inference time a suitable prefill length is matched recursively, e.g. an input of 12 is matched as [8, 4] (see the sketch after this list). The current default is 4; if set to 1, dynamic-shape inference is disabled. **Note: with dynamic shape enabled, the model size grows by 50%-100% and inference memory usage grows accordingly; if you are memory-constrained, it is recommended to disable dynamic shape.**
- The script auto-detects your NPU type; to specify it manually, add `--soc_version=xxxx`, e.g. `--soc_version=Ascend310B1`.
- The value of `--kv_cache_length` must match the `--kv_cache_length` specified earlier when exporting ONNX; otherwise the conversion will most likely fail.
- `--cpu_thread` is the number of CPU threads used when converting ONNX to OM. The default is 1; since each thread holds its own copy in memory, compilation is memory-hungry, so raise this only if you have plenty of RAM.
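- The recursive prefill-length matching mentioned above can be illustrated with a short sketch. This is only an illustration of the greedy power-of-two decomposition described in the `--max_prefill_length` bullet, not the repository's actual implementation, and the function name `split_prefill` is hypothetical:

```python
def split_prefill(seq_len: int, max_prefill_length: int) -> list:
    """Greedily split seq_len into power-of-two chunks no larger than max_prefill_length."""
    chunks = []
    remaining = seq_len
    while remaining > 0:
        chunk = 1
        while chunk * 2 <= min(remaining, max_prefill_length):
            chunk *= 2
        chunks.append(chunk)
        remaining -= chunk
    return chunks


print(split_prefill(12, 8))  # [8, 4]
print(split_prefill(12, 4))  # [4, 4, 4]
```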
@@ -173,6 +173,29 @@

- functional_call demo (using qwen2-1.5b-instruct) ![](./image/qwen2-1.5b-instruct-functional-call.jpg)

### (Optional) Compare the layer-wise outputs of the ONNX and OM models
- If the compiled OM model produces abnormal inference output (for example, results are normal with origin or fp32 precision but abnormal with fp16) while the ONNX output is normal, we need to locate the problematic layers. To do that, we use a tool to dump the inputs and outputs of every layer of both the ONNX and OM models and find the first layer where values overflow or the results diverge noticeably.
- Here we can use the msit tool provided by Ascend; its open-source homepage is here: [link](https://gitee.com/ascend/msit)
- Install the msit tool as described on the official site: [link](https://gitee.com/ascend/msit/tree/master/msit/docs/install)
```bash
pip install msit
msit install compare
```
- Once installation is complete, run the comparison. It is recommended to compare against a static-graph OM, i.e. set max_prefill_length=1 when converting ONNX to OM.
- The smaller the model, the better for comparison; the Qwen2-0.5B-Instruct model is recommended, which saves time and makes analysis easier.
- The comparison procedure is documented on the official site: [link](https://gitee.com/ascend/msit/tree/master/msit/docs/debug/compare#/ascend/msit/blob/master/msit/docs/install/README.md). I have wrapped it into a Python script; here is an example:
```bash
python3 export/compare.py \
--hf_model_dir="./download/Qwen2-0.5B-Instruct" \
--onnx_model_path="./output/onnx2/qwen2_0.5b_chat.onnx" \
--om_model_path="./output/model/qwen2_0.5b_chat.om" \
--kv_cache_length=2048 \
--cpu_thread=1 \
--dtype="float16" \
--max_prefill_length=1
```
- For interpreting the comparison results, see the official documentation: [link](https://gitee.com/ascend/msit/blob/master/msit/examples/cli/debug/compare/result_analyse/README.md)
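- As a rough sketch (not part of the repository), the comparison summary CSV that msit writes under the output directory can be loaded and sorted so that the most divergent layers appear first. The file location and column names depend on the msit version, so treat the names below, including `CosineSimilarity`, as assumptions:

```python
import glob
import os

import pandas as pd

# Locate the comparison summary written by "msit debug compare"
# (the exact path and file name may vary between msit versions).
csv_files = glob.glob(os.path.join("result", "**", "*.csv"), recursive=True)
df = pd.read_csv(csv_files[0])

# Sort ascending by the per-operator cosine similarity (column name assumed),
# so the layers that diverge the most from the ONNX reference come first.
print(df.sort_values("CosineSimilarity").head(20))
```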

### Current features
- [x] Export ONNX and OM models
- [x] Model inference, with ONNX inference supported (CPU only).
233 changes: 233 additions & 0 deletions export/compare.py
@@ -0,0 +1,233 @@
import os
import time
import subprocess
import numpy as np
import onnxruntime
import argparse
from transformers.models.qwen2 import Qwen2Tokenizer, Qwen2Config


now_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.dirname(now_dir)
result_output_dir = os.path.join(project_dir, "result")
input_data_dir = os.path.join(project_dir, "output", "input_data")
if not os.path.exists(result_output_dir):
os.mkdir(result_output_dir)
if not os.path.exists(input_data_dir):
os.mkdir(input_data_dir)

parser = argparse.ArgumentParser()
parser.add_argument(
"--dtype",
type=str,
help="float16 or float32",
choices=["float16", "float32"],
default="float32",
)
parser.add_argument(
'--hf_model_dir',
type=str,
help="model and tokenizer path, only support huggingface model",
default=os.path.join(project_dir, "download", "Qwen2-1.5B-Instruct")
)
parser.add_argument(
"--onnx_model_path",
help="output onnx path",
type=str,
default=os.path.join(project_dir, "output", "onnx", "qwen2_1.5b_chat.onnx")
)
parser.add_argument(
"--om_model_path",
    help="om model path",
    type=str,
    default=os.path.join(project_dir, "output", "model", "qwen2_1.5b_chat.om")
)
parser.add_argument(
"--kv_cache_length",
help="kv-cache length",
type=int,
default=2048,
)
parser.add_argument(
"--max_batch",
help="max batch",
type=int,
default=1,
)
parser.add_argument(
    "--cpu_thread",
    type=int,
    help="number of CPU threads to use when converting ONNX to OM",
    default=1,
)
parser.add_argument(
"--max_prefill_length",
    help="max prefill length for the first inference. "
         "Note that max_prefill_length + max_output_length <= kv_cache_length. "
         "The value must be a power of two, e.g. 1, 2, 4, 8, 16, 32, 64, 128, 256... "
         "Note: the larger this value, the longer the ONNX-to-OM compilation takes.",
type=int,
default=8
)
args = parser.parse_args()

if args.dtype == "float16":
np_dtype = np.float16
elif args.dtype == "float32":
np_dtype = np.float32
else:
    raise Exception("unsupported dtype, only float16/float32 are supported")
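
# Note on data layout: the kv-cache for all layers is packed into one tensor of shape
# [batch, kv_cache_length, num_hidden_layers * 2 * num_key_value_heads, head_dim],
# which also matches the "past_key_values" shape passed to msit further below.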


def create_kv_cache(config: Qwen2Config, kv_cache_length=args.kv_cache_length):
return np.zeros(
[
1,
kv_cache_length,
config.num_hidden_layers * 2 * config.num_key_value_heads,
config.hidden_size // config.num_attention_heads
],
dtype=np_dtype
)


def get_inputs(kv_cache, seq_len: int, real_kv_size=0, input_pos=0, past_kv_size: int = args.kv_cache_length):
    """
    Slice the kv-cache to the requested length and build the attention mask and position_ids.
    Args:
        kv_cache: packed kv-cache tensor
        seq_len (int): sequence length of the current input (number of new tokens)
        real_kv_size: number of valid entries already stored in the kv-cache
        input_pos: position of the current real token
        past_kv_size: kv-cache length to keep (defaults to args.kv_cache_length)
    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: kv-cache slice, attention_mask, position_ids
    """

"""
self.kv_cache shape (
1,
self.kv_cache_length,
self.num_hidden_layers * 2 * self.num_key_value_heads,
self.per_head_dim
)
"""
cache = kv_cache[:, :past_kv_size]
mask = np.ones((1, past_kv_size + seq_len), dtype=np.int64)
mask[:, real_kv_size: past_kv_size] = 0
pos_id = np.arange(
input_pos,
input_pos + seq_len,
dtype=np.int64
).reshape(1, -1)
return cache, mask, pos_id


tokenizer = Qwen2Tokenizer.from_pretrained(args.hf_model_dir)
model_config = Qwen2Config.from_pretrained(args.hf_model_dir)
prompt = "你好"
system_prompt: str = "You are a helpful assistant."
history = []
if len(history) == 0:
history = [{"role": "system", "content": system_prompt}]
history.append({"role": "user", "content": prompt})
print("history: ", history)
text = tokenizer.apply_chat_template(
history,
tokenize=False,
add_generation_prompt=True
)
print("raw_text", text)
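# Only the first token of the encoded prompt is kept below; the inputs built later
# use seq_len = 1, matching the static-graph (max_prefill_length=1) comparison
# recommended in the README.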
input_ids = tokenizer(
[text], return_tensors="np"
)["input_ids"].astype(np.int64)[:, :1]
print("input_ids", input_ids)

# options = onnxruntime.SessionOptions()
# options.intra_op_num_threads = 4
# options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
# options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
#
# llm_session = onnxruntime.InferenceSession(
# args.onnx_model_path,
# sess_options=options,
# providers=[
# "CPUExecutionProvider",
# ],
# )

seq_len = input_ids.shape[-1]
kv_cache1 = create_kv_cache(model_config)
now_kv_cache, attn_mask, position_ids = get_inputs(kv_cache1, 1)
print("now_kv_cache shape: ", now_kv_cache.shape)
print("attention_mask shape: ", attn_mask.shape)
print("position_ids shape: ", position_ids.shape)
# save input data
# input_ids
input_ids_path = os.path.join(input_data_dir, "input_ids.npy")
np.save(input_ids_path, input_ids)
# attention_mask
attention_mask_path = os.path.join(input_data_dir, "attention_mask.npy")
np.save(attention_mask_path, attn_mask)
# position_ids
position_ids_path = os.path.join(input_data_dir, "position_ids.npy")
np.save(position_ids_path, position_ids)
# past_key_values
past_key_values_path = os.path.join(input_data_dir, "past_key_values.npy")
np.save(past_key_values_path, now_kv_cache)
input_path_list = [input_ids_path, attention_mask_path, position_ids_path, past_key_values_path]

max_batch = args.max_batch
max_prefill_length = args.max_prefill_length
kv_cache_length = args.kv_cache_length
model_config = Qwen2Config.from_pretrained(args.hf_model_dir)
num_hidden_layers = model_config.num_hidden_layers
num_key_value_heads = model_config.num_key_value_heads
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
per_head_dim = hidden_size // num_attention_heads

input_ids_shape = [
str(max_batch),
str(max_prefill_length)
]
attention_mask_shape = [
str(max_batch),
str(max_prefill_length + kv_cache_length)
]
position_ids_shape = [
str(max_batch),
str(max_prefill_length)
]
past_key_values_shape = [
str(max_batch),
str(kv_cache_length),
str(num_hidden_layers * 2 * num_key_value_heads),
str(per_head_dim)
]
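
# Build the "msit debug compare" command line. The input names and shapes given via
# --input-shape must match the inputs of the exported ONNX/OM graph; --advisor asks
# msit to additionally run its built-in analysis of the comparison results.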


command_lines = [
"msit debug compare",
"-gm {}".format(args.onnx_model_path),
"-om {}".format(args.om_model_path),
"-c /usr/local/Ascend/ascend-toolkit/latest",
# '--input \"{}\"'.format(",".join(input_path_list)),
'--input-shape \"input_ids:{};attention_mask:{};position_ids:{};past_key_values:{}\"'.format(
",".join(input_ids_shape),
",".join(attention_mask_shape),
",".join(position_ids_shape),
",".join(past_key_values_shape)
),
"-o {}".format(result_output_dir),
"--advisor"
]
print("============ run command ==============")
print(" \\\r\n ".join(command_lines))
print("=======================================")
subprocess.run(
" \\\n ".join(command_lines),
shell=True,
check=True,
)
33 changes: 17 additions & 16 deletions export/modeling_qwen2.py
@@ -265,16 +265,16 @@ def forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
# kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
kv_seq_len += past_key_value.shape[1]
kv_seq_len = key_states.shape[-2] + past_key_value.shape[2]
# if past_key_value is not None:
# if self.layer_idx is None:
# raise ValueError(
# f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
# "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
# "with a layer index."
# )
# # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# kv_seq_len += past_key_value.shape[2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
output_cache = (key_states, value_states)
@@ -756,8 +756,8 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
"unexpected results may be encountered."
)
# self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.self_attn = Qwen2SdpaAttention(config, layer_idx)
# self.self_attn = Qwen2Attention(config, layer_idx)
# self.self_attn = Qwen2SdpaAttention(config, layer_idx)
self.self_attn = Qwen2Attention(config, layer_idx)
self.mlp = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -1109,14 +1109,15 @@ def forward(
attention_mask,
)
# === if use Qwen2Attention ===
# dtype = past_key_values.dtype
# device = input_ids.device
# attention_mask = torch.zeros_like(full_attention_mask, dtype=dtype).to(device)
dtype = past_key_values.dtype
device = input_ids.device
attention_mask = torch.zeros_like(full_attention_mask, dtype=dtype).to(device)
# attention_mask.masked_fill_(full_attention_mask, torch.finfo(dtype).min)
attention_mask.masked_fill_(full_attention_mask, -10000.0)

# == if use Qwen2SdpaAttention ===
# copy from chatglm3-6b
attention_mask = ~full_attention_mask
# attention_mask = ~full_attention_mask

hidden_states = inputs_embeds

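A note on the attention-mask change in the diff above: the Qwen2SdpaAttention path is fed a boolean mask (True means "may attend"), while the Qwen2Attention path adds a float mask to the attention scores, so blocked positions need a large negative value. The commit uses -10000.0 rather than `torch.finfo(dtype).min`, presumably to avoid fp16 overflow once the mask is added to the scores. Below is a minimal sketch of the two conventions, for illustration only (not the repository's code):

```python
import torch

# full_attention_mask: True where attention must be blocked (as in the diff above).
full_attention_mask = torch.tensor([[False, False, True]])

# Boolean convention (used with the SDPA path): True means "may attend".
sdpa_mask = ~full_attention_mask

# Additive convention (used with Qwen2Attention): blocked positions get a large
# negative value that stays representable in fp16 before the softmax.
additive_mask = torch.zeros_like(full_attention_mask, dtype=torch.float16)
additive_mask.masked_fill_(full_attention_mask, -10000.0)

print(sdpa_mask)
print(additive_mask)
```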
