diff --git a/README.md b/README.md
index 8b24931..5f2f3fe 100644
--- a/README.md
+++ b/README.md
@@ -78,7 +78,7 @@
         --output_model_path="./output/onnx2/qwen2_1.5b_chat.onnx"
     ```
 
-5. Convert the ONNX model to an OM model: use the `atc` command to convert the modified ONNX model, **noting that `om_model_path` here does not carry the `.om` suffix**. The run may print some warnings or subgraph-fusion errors; as long as the final result reports `success`, everything is fine. `kv_cache_length` must match the length used when exporting the ONNX model in step 1. `--max_prefill_length` is the maximum length that can be processed in a single pass during the prefill stage: the larger the value, the lower the first-token latency, but the longer the ONNX-to-OM conversion takes. It is normally set to a power of 2, e.g. 2, 4, 8, 16, and at inference time a suitable prefill length is matched recursively, e.g. an input of 12 is matched as [8, 4]. The current default is 8; if set to 1, dynamic-shape inference is disabled.
+5. Convert the ONNX model to an OM model: use the `atc` command to convert the modified ONNX model, **noting that `om_model_path` here does not carry the `.om` suffix**. The run may print some warnings or subgraph-fusion errors; as long as the final result reports `success`, everything is fine. `kv_cache_length` must match the length used when exporting the ONNX model in step 1. `--max_prefill_length` is the maximum length that can be processed in a single pass during the prefill stage: the larger the value, the lower the first-token latency, but the longer the ONNX-to-OM conversion takes. It is normally set to a power of 2, e.g. 2, 4, 8, 16, and at inference time a suitable prefill length is matched recursively, e.g. an input of 12 is matched as [8, 4]. The current default is 8; if set to 1, dynamic-shape inference is disabled. The script auto-detects your NPU type; to specify it manually, add `--soc_version=xxxx`, e.g. `--soc_version=Ascend310B1`.
     ```bash
     python3 export/onnx2om.py \
         --hf_model_dir="./download/Qwen2-1.5B-Instruct" \
diff --git a/export/onnx2om.py b/export/onnx2om.py
index cd388fe..2fe17a6 100644
--- a/export/onnx2om.py
+++ b/export/onnx2om.py
@@ -1,4 +1,5 @@
 import os
+import ctypes
 import subprocess
 import argparse
 import math
@@ -17,6 +18,12 @@
     os.mkdir(model_dir)
 
 parser = argparse.ArgumentParser()
+parser.add_argument(
+    '--soc_version',
+    type=str,
+    default="auto",
+    help="NPU full name, e.g. Ascend310B1, Ascend310B4, Ascend310P1, Ascend910A, Ascend910B...; default is `auto`, which auto-detects the SoC version.",
+)
 parser.add_argument(
     '--hf_model_dir',
     type=str,
@@ -60,6 +67,7 @@
 
 args = parser.parse_args()
 
+
 def get_soc_version():
     """
     _summary_
@@ -67,22 +75,29 @@ def get_soc_version():
     Returns:
         _type_: _description_
     """
-    # launch a new process and capture its output
-    result = subprocess.run(["npu-smi", "info"], capture_output=True, text=True)
-    # print(result.stdout)
-    line_list = result.stdout.split("\n")
-    soc_version = None
-    for line in line_list:
-        for data in line.split():
-            data = data.strip()
-            if data.startswith("310B") or data.startswith("310P") or data.startswith("910B"):
-                soc_version = data
-                break
-        if soc_version is not None:
-            break
-    assert soc_version is not None, print("soc_version", soc_version)
-    print("SoC Version is ", soc_version)
-    return soc_version
+    max_len = 512
+    rtsdll = ctypes.CDLL("libruntime.so")
+    c_char_t = ctypes.create_string_buffer(b"\xff" * max_len, max_len)
+    rtsdll.rtGetSocVersion.restype = ctypes.c_uint64
+    rt_error = rtsdll.rtGetSocVersion(c_char_t, ctypes.c_uint32(max_len))
+    if rt_error:
+        print("rt_error:", rt_error)
+        return ""
+    soc_full_name = c_char_t.value.decode("utf-8")
+    find_str = "Short_SoC_version="
+    ascend_home_dir = os.environ.get("ASCEND_HOME_PATH")
+    assert ascend_home_dir is not None, \
+        print("ASCEND_HOME_PATH is None, you need to run `source /usr/local/Ascend/ascend-toolkit/set_env.sh`")
+    with open(f"{ascend_home_dir}/compiler/data/platform_config/{soc_full_name}.ini", "r") as f:
+        for line in f:
+            if find_str in line:
+                start_index = line.find(find_str)
+                soc_short_name = line[start_index + len(find_str):].strip()
+                return {
+                    "soc_full_name": soc_full_name,
+                    "soc_short_name": soc_short_name
+                }
+    raise Exception("can't get your SoC version")
 
 max_batch = args.max_batch
 model_config = Qwen2Config.from_pretrained(args.hf_model_dir)
@@ -131,13 +146,19 @@ def get_soc_version():
     per_head_dim
 ]
 past_key_values_shape = [str(x) for x in past_key_values_shape]
-
+if args.soc_version == "auto":
+    print("[INFO] soc_version is `auto`, will auto detect soc version")
+    soc_dict = get_soc_version()
+    print("[INFO] {}".format(soc_dict))
+    soc_version = soc_dict["soc_full_name"]
+else:
+    soc_version = args.soc_version
 command_lines = [
     "atc",
     "--framework=5",
     '--model="{}"'.format(args.onnx_model_path),
     '--output="{}"'.format(args.om_model_path),
-    "--soc_version=Ascend{}".format(get_soc_version()),
+    "--soc_version={}".format(soc_version),
     "--precision_mode=must_keep_origin_dtype",
     "--input_format=ND",
     '--input_shape="input_ids:{};attention_mask:{};position_ids:{};past_key_values:{}"'.format(
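
Editor's note on the `--max_prefill_length` behaviour described in the README hunk above: the snippet below is a minimal, illustrative sketch of how an input longer than one prefill chunk could be split into descending power-of-two pieces, reproducing the README example where an input of 12 is matched as [8, 4]. The function name `split_prefill_lengths` and the greedy loop are assumptions for illustration; they are not code from this repository.

```python
# Illustrative sketch only (not repository code): greedily split a prompt of
# seq_len tokens into power-of-two prefill chunks no larger than
# max_prefill_length, matching the README example where 12 -> [8, 4].
def split_prefill_lengths(seq_len: int, max_prefill_length: int = 8) -> list:
    chunks = []
    remaining = seq_len
    chunk = max_prefill_length
    while remaining > 0:
        # halve the candidate chunk until it fits what is left
        while chunk > remaining:
            chunk //= 2
        chunks.append(chunk)
        remaining -= chunk
    return chunks


if __name__ == "__main__":
    print(split_prefill_lengths(12))  # [8, 4]
    print(split_prefill_lengths(13))  # [8, 4, 1]
    print(split_prefill_lengths(5, max_prefill_length=1))  # [1, 1, 1, 1, 1]
```

With `--max_prefill_length=1` every position falls back to a chunk of 1, which is consistent with the README's statement that a value of 1 effectively disables dynamic-shape prefill inference.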