update readme and onnx inference demo
Tlntin committed Oct 25, 2024
1 parent 6489df4 commit 9e24de8
Showing 3 changed files with 4 additions and 5 deletions.
README.md: 3 additions & 0 deletions
@@ -149,6 +149,9 @@
- Demo 2 (model: qwen2-1.5b-instruct, dynamic-shape inference enabled, max_prefill_length=8)
![](./image/qwen2-1.5b-instruct.gif)

- Demo 3 (model: qwen2-1.5b-instruct, ONNX CPU inference, CPU: i9-10900K, 10 cores / 20 threads)
![](./image/qwen2_1.5b_onnx_chat_cpu.png)


##### Step 3: Deploy the OpenAI-compatible API
- Run the API directly with the command below; `--max_prefill_length` must match the value used when compiling above (a client-side call sketch is shown below).
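For reference, a minimal sketch of querying the deployed OpenAI-compatible endpoint with the official `openai` Python client; the base URL, port, API key, and model name below are illustrative assumptions, not values taken from this repository.

```python
from openai import OpenAI

# Hypothetical endpoint and credentials; adjust to match how the api server is started.
client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="qwen2-1.5b-instruct",  # assumed model name
    messages=[{"role": "user", "content": "Hello, please introduce yourself."}],
    max_tokens=128,
)
print(response.choices[0].message.content)
```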
Binary file added image/qwen2_1.5b_onnx_chat_cpu.png
utils/session.py: 1 addition & 5 deletions
@@ -85,14 +85,10 @@ def run(self, input_ids: np.ndarray, show_progress=False):
input_ids = torch.from_numpy(input_ids).long().to(self.device_str)
seq_len = input_ids.shape[-1]
cache, mask, pos_ids = self.kv_cache.get_inputs(seq_len)
# print("input_ids shape/dtype: ", input_ids.shape, input_ids.dtype)
# print("cache shape/dtype: ", cache.shape, cache.dtype)
# print("mask shape/dtype: ", mask.shape, mask.dtype)
# print("pos_ids shape/dtype: ", pos_ids.shape, pos_ids.dtype)
result = self.model(input_ids, mask, pos_ids, cache)
self.kv_cache.update(seq_len, result[1])
return result[0].cpu().detach().numpy()

# onnxruntime-cann is still a preview and does not work yet
"""
class CANNOnnxSession(Session):
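For context, a rough sketch of how a `run`-style session like the one above is typically driven in a greedy decoding loop; the `session` and `tokenizer` objects, the prefill-then-single-token feeding pattern, and the `eos_token_id` check are illustrative assumptions, not code from this commit.

```python
import numpy as np

def greedy_decode(session, tokenizer, prompt: str, max_new_tokens: int = 64):
    # Encode the prompt and feed the whole sequence once for the prefill step.
    input_ids = np.array([tokenizer.encode(prompt)], dtype=np.int64)
    generated = []
    for _ in range(max_new_tokens):
        # run() returns logits and updates the session's KV cache internally,
        # so each subsequent step only needs to feed the newly generated token.
        logits = session.run(input_ids)
        next_token = int(np.argmax(logits[0, -1]))
        if next_token == tokenizer.eos_token_id:
            break
        generated.append(next_token)
        input_ids = np.array([[next_token]], dtype=np.int64)
    return tokenizer.decode(generated)
```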
