
Commit

add requirements and test code with CANN onnxruntime
Tlntin committed Jul 28, 2024
1 parent d760724 commit ea4590a
Showing 2 changed files with 55 additions and 1 deletion.
5 changes: 5 additions & 0 deletions requirements.txt
@@ -0,0 +1,5 @@
onnx==1.16.1
onnxruntime==1.18.1
# onnxruntime-cann==1.18.1
torch==2.1.0
torch-npu==2.1.0.post6
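
A quick sanity check for the pinned environment above, sketched below; it is illustrative and not part of this commit. The check_env.py name is made up, and the torch.npu.is_available() call assumes torch-npu registers the torch.npu backend once imported; the commented version numbers simply mirror the pins in requirements.txt.

# check_env.py -- illustrative sketch, assuming the pinned packages are installed
import onnx
import onnxruntime as ort
import torch

print("onnx:", onnx.__version__)              # pinned to 1.16.1
print("onnxruntime:", ort.__version__)        # pinned to 1.18.1
print("torch:", torch.__version__)            # pinned to 2.1.0
# "CANNExecutionProvider" only shows up with the onnxruntime-cann build
print("providers:", ort.get_available_providers())

try:
    import torch_npu  # noqa: F401 -- assumed to patch in the torch.npu backend
    print("npu available:", torch.npu.is_available())
except ImportError:
    print("torch-npu not installed")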
51 changes: 50 additions & 1 deletion utils/session.py
@@ -5,6 +5,7 @@
import time
import sys
from utils.engine import ACLModel, init_resource, destroy_resource
import onnxruntime as ort

class Session:
    def __init__(self, config: InferenceConfig) -> None:
@@ -30,7 +31,7 @@ def reset(self):

    def rollback(self, seq_len):
        self.kv_cache.rollback(seq_len)

class OnnxSession(Session):
    def __init__(self, config: InferenceConfig) -> None:
        super().__init__(config)
@@ -55,6 +56,54 @@ def run(self, input_ids: np.ndarray):
        })
        self.kv_cache.update(seq_len, result[1])
        return result

# onnxruntime-cann is still in preview and does not work yet
"""
class CANNOnnxSession(Session):
def __init__(self,config:InferenceConfig)->None:
super().__init__(config)
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
# options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
self.llm_session = ort.InferenceSession(
config.onnx_model_path,
sess_options=options,
providers=[
(
"CANNExecutionProvider",
{
"device_id": 0,
"arena_extend_strategy": "kNextPowerOfTwo",
"npu_mem_limit": 20 * 1024 * 1024 * 1024,
"op_select_impl_mode": "high_performance",
"optypelist_for_implmode": "Gelu",
"enable_cann_graph": True
},
),
"CPUExecutionProvider",
]
)
def run(self, input_ids:np.ndarray):
seq_len=input_ids.shape[-1]
past_key_values, attention_mask, position_ids = self.kv_cache.get_inputs(seq_len)
input_ids_cann = ort.OrtValue.ortvalue_from_numpy(input_ids, device_type="cann", device_id=0)
attention_mask_cann = ort.OrtValue.ortvalue_from_numpy(attention_mask, device_type="cann", device_id=0)
position_ids_cann = ort.OrtValue.ortvalue_from_numpy(position_ids, device_type="cann", device_id=0)
past_key_values_cann = ort.OrtValue.ortvalue_from_numpy(past_key_values, device_type="cann", device_id=0)
io_binding = self.llm_session.io_binding()
io_binding.bind_ortvalue_input(name="input_ids", ortvalue=input_ids_cann)
io_binding.bind_ortvalue_input(name="attention_mask", ortvalue=attention_mask_cann)
io_binding.bind_ortvalue_input(name="position_ids", ortvalue=position_ids_cann)
io_binding.bind_ortvalue_input(name="past_key_values", ortvalue=past_key_values_cann)
io_binding.bind_output("logits", device_type="cann", device_id=0)
io_binding.bind_output("out_key_values", device_type="cann", device_id=0)
self.llm_session.run_with_iobinding(io_binding)
logitsts = io_binding.get_outputs()[0].numpy()
new_kv_cache = io_binding.get_outputs()[1].numpy()
self.kv_cache.update(seq_len, new_kv_cache)
return (logitsts, new_kv_cache)
"""

class AclSession(Session):
    context = None
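
Since the CANN execution provider path stays commented out for now, session selection can be gated on provider availability. A minimal sketch, not part of this commit: the create_session helper below is a hypothetical wrapper, while ort.get_available_providers() is a real onnxruntime call and OnnxSession comes from utils/session.py.

# illustrative helper, assuming it lives outside utils/session.py
import onnxruntime as ort
from utils.session import OnnxSession  # CANNOnnxSession remains disabled upstream

def create_session(config):
    # prefer the CANN-backed session once onnxruntime-cann works, otherwise fall back to CPU
    if "CANNExecutionProvider" in ort.get_available_providers():
        # return CANNOnnxSession(config)  # re-enable when the preview provider is usable
        pass
    return OnnxSession(config)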
