Skip to content

Commit

Permalink
fixup a bug for acl runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
Tlntin committed Jul 29, 2024
1 parent ea4590a commit 3c2e592
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 16 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@
python3 ./cli_chat.py --hf_model_dir="download/[你下载的模型路径]"
```

- demo展示(演示模型:qwen1.5-0.5b-chat)
![](./image/qwen1.5_0.5b_chat.gif)


### 当前功能
- [x] 导出onnx, om模型
- [x] 模型推理,支持onnx推理。
- [ ] 模型推理,支持acl推理。
- [x] 模型推理,支持onnx推理(仅支持CPU)
- [x] 模型推理,支持acl推理。
- [x] 流式传输
- [ ] 兼容OpenAI的api搭建
- [ ] 支持functional call
Expand Down
1 change: 1 addition & 0 deletions export/onnx2om.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def get_soc_version():
if soc_version is not None:
break
assert soc_version is not None, print("soc_version", soc_version)
print("SoC Version is ", soc_version)
return soc_version


Expand Down
Binary file added image/qwen1.5_0.5b_chat.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions utils/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
import acl
import numpy as np
import os
from functools import reduce
from operator import mul
import ctypes
from config import InferenceConfig
from ctypes import c_void_p, c_int, c_size_t, c_ulong, c_int64,POINTER


ACL_MEM_MALLOC_HUGE_FIRST = 0
ACL_MEMCPY_HOST_TO_DEVICE = 1
ACL_MEMCPY_DEVICE_TO_HOST = 2
Expand Down
33 changes: 19 additions & 14 deletions utils/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def __init__(self,config:InferenceConfig)->None:
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
# options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
self.llm_session = ort.InferenceSession(
config.onnx_model_path,
sess_options=options,
providers=[
(
config.onnx_model_path,
sess_options=options,
providers=[
(
"CANNExecutionProvider",
{
"device_id": 0,
Expand All @@ -80,9 +80,9 @@ def __init__(self,config:InferenceConfig)->None:
"enable_cann_graph": True
},
),
"CPUExecutionProvider",
]
)
"CPUExecutionProvider",
]
)
def run(self, input_ids:np.ndarray):
seq_len=input_ids.shape[-1]
Expand Down Expand Up @@ -131,7 +131,7 @@ def run_all_logits(self, input_ids: np.ndarray):
end = i + 16 if i+16 < seq_len else seq_len
cache,mask,pos_ids = self.kv_cache.get_inputs(16)
self.input_ids[0:end-i] = input_ids[i:end]
result:List[np.ndarray] = self.model.inference([self.input_ids,pos_ids,mask,cache])
result:List[np.ndarray] = self.model.inference([self.input_ids, mask, pos_ids, cache])
self.kv_cache.update(end-i,result[1])
logits.append(result[0][0:end-i].reshape(1,-1))
return [np.concatenate(logits).reshape(1,1,-1)]
Expand All @@ -140,12 +140,17 @@ def run_one(self, input_ids: np.ndarray):
self.run_times += 1
cache, mask, pos_ids = self.kv_cache.get_inputs(1)
result:List[np.ndarray] = self.model.inference(
[input_ids, pos_ids, mask, cache]
[input_ids, mask, pos_ids, cache]
)
# new_kv_cache = result[1]
# print(" == Debug == ")
# print("new_kv_cache: shape", new_kv_cache.shape)
# print("new_kv_cache: mean: ", new_kv_cache.astype(np.float32).mean().item())
# print("new_kv_cache: max: ", new_kv_cache.astype(np.float32).max().item())
# if self.run_times <= 2:
# print(" == Debug == ")
# logits = result[0]
# new_kv_cache = result[1]
# print("logits shape: ", logits.shape)
# print("logits mean: ", logits.astype(np.float32).mean().item())
# print("logits max: ", logits.astype(np.float32).max().item())
# print("new_kv_cache: shape", new_kv_cache.shape)
# print("new_kv_cache: mean: ", new_kv_cache.astype(np.float32).mean().item())
# print("new_kv_cache: max: ", new_kv_cache.astype(np.float32).max().item())
self.kv_cache.update(1,result[1])
return result[0].reshape(1,1,-1)

0 comments on commit 3c2e592

Please sign in to comment.