Skip to content

Commit

Permalink
code optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
Tlntin committed Oct 21, 2024
1 parent d2dd231 commit d824498
Show file tree
Hide file tree
Showing 13 changed files with 2,167 additions and 2,091 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
2. 导出onnx,默认kv-cache长度为1024,可以根据自己的内存、显存来设置更大参数。
```bash
python3 export/export_onnx.py \
--device_str=npu \
--dtype=float16 \
--hf_model_dir="./download/Qwen2-1.5B-Instruct" \
--onnx_model_path="./output/onnx/qwen2_1.5b_chat.onnx" \
--kv_cache_length=1024
Expand Down
3 changes: 2 additions & 1 deletion cli_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,13 @@ def inference_cli():
break
if input_text == 'clear':
history = []
infer_engine.session.reset()
print("Output: 已清理历史对话信息。")
continue
print("Output: ", end='')
response = ""
is_first = True
first_token_lantency, decode_speed = 0, 0
first_token_lantency, decode_speed, total_speed = 0, 0, 0.0
for (
new_text,
first_token_lantency,
Expand Down
6 changes: 3 additions & 3 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def __init__(
hf_model_dir: str,
om_model_path: str,
onnx_model_path: str,
cpu_thread: int = 4, # CPU线程数
session_type: str = "acl", # 支持acl和onnx两种,acl即Ascend C Language
device_id: int = 0,
sampling_method: str = "top_p", # 支持 greedy, top_p, top_k
Expand All @@ -30,6 +31,7 @@ def __init__(
assert os.path.exists(onnx_model_path), print(onnx_model_path, "not exists")
self.om_model_path = om_model_path
self.onnx_model_path = onnx_model_path
self.cpu_thread = cpu_thread
self.device_id = device_id
self.sampling_method = sampling_method
self.sampling_value = sampling_value
Expand All @@ -48,11 +50,9 @@ def __init__(
self.num_attention_heads = self.model_config.num_attention_heads
self.per_head_dim = self.hidden_size // self.num_attention_heads # head_dim
self.past_key_value_shape = (
self.num_hidden_layers,
2,
self.max_batch,
self.num_key_value_heads,
self.kv_cache_length,
self.num_hidden_layers * 2 * self.num_key_value_heads,
self.per_head_dim
)
self.max_prefill_length = max_prefill_length
Expand Down
164 changes: 82 additions & 82 deletions export/change_node.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,82 @@
import os
import onnx
import onnx.helper as helper
from onnx import TensorProto
from tqdm import tqdm
import argparse


now_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.dirname(now_dir)
output_dir = os.path.join(project_dir, "output")
if not os.path.exists(output_dir):
os.mkdir(output_dir)
old_onnx_dir = os.path.join(output_dir, "onnx")
if not os.path.exists(old_onnx_dir):
os.mkdir(old_onnx_dir)
new_onnx_dir = os.path.join(output_dir, "onnx2")
if not os.path.exists(new_onnx_dir):
os.mkdir(new_onnx_dir)

now_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.dirname(now_dir)
model_name = "qwen2_1.5b_chat.onnx"

parser = argparse.ArgumentParser()
parser.add_argument(
'--input_model_path',
type=str,
help="raw onnx model convert by pytroch",
default=os.path.join(old_onnx_dir, model_name)
)
parser.add_argument(
"--output_model_path",
help="output onnx model path",
type=str,
default=os.path.join(new_onnx_dir, model_name)
)

args = parser.parse_args()

model = onnx.load(args.input_model_path)
new_nodes = []

for node in tqdm(model.graph.node, desc="replace node..."):
# 判断节点类型
new_node = node
if node.op_type == "Trilu":
new_node = helper.make_node(
"Trilu",
name="MY_" + node.name,
inputs=[node.input[0]],
outputs=node.output,
upper=0
)
if node.op_type == "Cast":
# 替换为新的算子类型
to_attribute = next(attr for attr in node.attribute if attr.name == "to")
if to_attribute.i == TensorProto.INT8:
new_node = helper.make_node(
"AscendQuant",
inputs=node.input,
outputs=node.output,
offset=0.,
scale=1.,
)
new_nodes.append(new_node)
print("make new graph")
new_graph = helper.make_graph(
new_nodes,
"new_graph",
inputs=model.graph.input,
outputs=model.graph.output,
value_info=model.graph.value_info,
initializer=model.graph.initializer
)
print("make new model")
new_model = helper.make_model(new_graph, producer_name=model.producer_name,opset_imports=model.opset_import,ir_version = model.ir_version)
# new_model.ir_version = model.ir_version
# new_model.opset_import = model.opset_import
# new_model.metadata_props = model.metadata_props
print("will save model in ", args.output_model_path)
onnx.save(new_model, args.output_model_path, save_as_external_data=True)
import os
import onnx
import onnx.helper as helper
from onnx import TensorProto
from tqdm import tqdm
import argparse


now_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.dirname(now_dir)
output_dir = os.path.join(project_dir, "output")
if not os.path.exists(output_dir):
os.mkdir(output_dir)
old_onnx_dir = os.path.join(output_dir, "onnx")
if not os.path.exists(old_onnx_dir):
os.mkdir(old_onnx_dir)
new_onnx_dir = os.path.join(output_dir, "onnx2")
if not os.path.exists(new_onnx_dir):
os.mkdir(new_onnx_dir)

now_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.dirname(now_dir)
model_name = "qwen2_1.5b_chat.onnx"

parser = argparse.ArgumentParser()
parser.add_argument(
'--input_model_path',
type=str,
help="raw onnx model convert by pytroch",
default=os.path.join(old_onnx_dir, model_name)
)
parser.add_argument(
"--output_model_path",
help="output onnx model path",
type=str,
default=os.path.join(new_onnx_dir, model_name)
)

args = parser.parse_args()

model = onnx.load(args.input_model_path)
new_nodes = []

for node in tqdm(model.graph.node, desc="replace node..."):
# 判断节点类型
new_node = node
if node.op_type == "Trilu":
new_node = helper.make_node(
"Trilu",
name="MY_" + node.name,
inputs=[node.input[0]],
outputs=node.output,
upper=0
)
if node.op_type == "Cast":
# 替换为新的算子类型
to_attribute = next(attr for attr in node.attribute if attr.name == "to")
if to_attribute.i == TensorProto.INT8:
new_node = helper.make_node(
"AscendQuant",
inputs=node.input,
outputs=node.output,
offset=0.,
scale=1.,
)
new_nodes.append(new_node)
print("make new graph")
new_graph = helper.make_graph(
new_nodes,
"new_graph",
inputs=model.graph.input,
outputs=model.graph.output,
value_info=model.graph.value_info,
initializer=model.graph.initializer
)
print("make new model")
new_model = helper.make_model(new_graph, producer_name=model.producer_name,opset_imports=model.opset_import,ir_version = model.ir_version)
# new_model.ir_version = model.ir_version
# new_model.opset_import = model.opset_import
# new_model.metadata_props = model.metadata_props
print("will save model in ", args.output_model_path)
onnx.save(new_model, args.output_model_path, save_as_external_data=True)
Loading

0 comments on commit d824498

Please sign in to comment.