Skip to content

Commit

Permalink
support auto detect soc version
Browse files Browse the repository at this point in the history
  • Loading branch information
Tlntin committed Aug 22, 2024
1 parent 62aefaa commit c13f415
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
--output_model_path="./output/onnx2/qwen2_1.5b_chat.onnx"
```

5. 转onnx为om模型, 将修改后的onnx利用atc命令导出到onnx,**注意此处的om_model_path不带`.om`后缀**。运行过程可能会有一些警告,或者子图融合报错,只要结果是提示`success`就说明没啥问题。kv_cache_length长度和第一步导出onnx时的长度保持一致。`--max_prefill_length`为prefill阶段,单次能处理的最大长度,该数值越长则越能降低首字延迟,但是相应的onnx转om的时间也会变长。设置该数值时,一般为2的指数,例如2、4、8、16等等,推理时会利用递归自动匹配合适的prefill长度,例如输入12,会匹配[8, 4]。当前默认数值为8,如果设置为1,则不会开启动态shape推理功能。
5. 转onnx为om模型, 将修改后的onnx利用atc命令导出到onnx,**注意此处的om_model_path不带`.om`后缀**。运行过程可能会有一些警告,或者子图融合报错,只要结果是提示`success`就说明没啥问题。kv_cache_length长度和第一步导出onnx时的长度保持一致。`--max_prefill_length`为prefill阶段,单次能处理的最大长度,该数值越长则越能降低首字延迟,但是相应的onnx转om的时间也会变长。设置该数值时,一般为2的指数,例如2、4、8、16等等,推理时会利用递归自动匹配合适的prefill长度,例如输入12,会匹配[8, 4]。当前默认数值为8,如果设置为1,则不会开启动态shape推理功能。该脚本会自动检测你的NPU类型,如果你想手动指定,可以加上`--soc_version=xxxx`来指定,例如`--soc_version=Ascend310B1`
```bash
python3 export/onnx2om.py \
--hf_model_dir="./download/Qwen2-1.5B-Instruct" \
Expand Down
57 changes: 39 additions & 18 deletions export/onnx2om.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import ctypes
import subprocess
import argparse
import math
Expand All @@ -17,6 +18,12 @@
os.mkdir(model_dir)

parser = argparse.ArgumentParser()
parser.add_argument(
'--soc_version',
type=str,
default="auto",
help="NPU full name, like Ascend310B1、Ascend310B4、Ascend310P1、Ascend910A、Ascend910B..., default is `auto`, will auto detect soc version.",
)
parser.add_argument(
'--hf_model_dir',
type=str,
Expand Down Expand Up @@ -60,29 +67,37 @@

args = parser.parse_args()


def get_soc_version():
"""
_summary_
获取芯片信息,返回具体的芯片型号
Returns:
_type_: _description_
"""
# 启动一个新的进程,并获取输出
result = subprocess.run(["npu-smi", "info"], capture_output=True, text=True)
# print(result.stdout)
line_list = result.stdout.split("\n")
soc_version = None
for line in line_list:
for data in line.split():
data = data.strip()
if data.startswith("310B") or data.startswith("310P") or data.startswith("910B"):
soc_version = data
break
if soc_version is not None:
break
assert soc_version is not None, print("soc_version", soc_version)
print("SoC Version is ", soc_version)
return soc_version
max_len = 512
rtsdll = ctypes.CDLL(f"libruntime.so")
c_char_t = ctypes.create_string_buffer(b"\xff" * max_len, max_len)
rtsdll.rtGetSocVersion.restype = ctypes.c_uint64
rt_error = rtsdll.rtGetSocVersion(c_char_t, ctypes.c_uint32(max_len))
if rt_error:
print("rt_error:", rt_error)
return ""
soc_full_name = c_char_t.value.decode("utf-8")
find_str = "Short_SoC_version="
ascend_home_dir = os.environ.get("ASCEND_HOME_PATH")
assert ascend_home_dir is not None, \
print("ASCEND_HOME_PATH is None, you need run `source /usr/local/Ascend/ascend-toolkit/set_env.sh`")
with open(f"{ascend_home_dir}/compiler/data/platform_config/{soc_full_name}.ini", "r") as f:
for line in f:
if find_str in line:
start_index = line.find(find_str)
soc_short_name = line[start_index + len(find_str):].strip()
return {
"soc_full_name": soc_full_name,
"soc_short_name": soc_short_name
}
raise Exception("can't get you soc version")

max_batch = args.max_batch
model_config = Qwen2Config.from_pretrained(args.hf_model_dir)
Expand Down Expand Up @@ -131,13 +146,19 @@ def get_soc_version():
per_head_dim
]
past_key_values_shape = [str(x) for x in past_key_values_shape]

if args.soc_version == "auto":
print("[INFO] soc_version is `auto`, will auto detect soc version")
soc_dict = get_soc_version()
print("[INFO] {}".format(soc_dict))
soc_version = soc_dict["soc_full_name"]
else:
soc_version = args.soc_version
command_lines = [
"atc",
"--framework=5",
'--model="{}"'.format(args.onnx_model_path),
'--output="{}"'.format(args.om_model_path),
"--soc_version=Ascend{}".format(get_soc_version()),
"--soc_version={}".format(soc_version),
"--precision_mode=must_keep_origin_dtype",
"--input_format=ND",
'--input_shape="input_ids:{};attention_mask:{};position_ids:{};past_key_values:{}"'.format(
Expand Down

0 comments on commit c13f415

Please sign in to comment.