
Commit

support pytorch session type
Tlntin committed Oct 24, 2024
1 parent b6b11e6 commit fb8640d
Showing 5 changed files with 31 additions and 23 deletions.
cli_chat.py (12 changes: 7 additions & 5 deletions)
@@ -18,12 +18,12 @@ def parser_args():
parser.add_argument(
"--session_type",
type=str,
default="acl",
default="pytorch",
help="acl or onnx",
choices=["acl", "onnx", "pytorch"],
)
parser.add_argument(
"--dtype" ,
"--dtype",
type=str,
help="support float16/float32, if use CPU, only support fp32",
choices=["float16", "float32"],
@@ -126,8 +126,9 @@ def inference_cli():
" decode_speed: {:.2f} token/s, ".format(decode_speed),
" total_speed(prefill+decode): {:.2f} token/s".format(total_speed),
)

history.append([input_text, response])


if __name__ == '__main__':
args = parser_args()
max_prefill_log2 = int(math.log2(args.max_prefill_length))
@@ -143,7 +144,8 @@ def inference_cli():
max_input_length=args.max_input_length,
kv_cache_length=args.max_output_length,
max_prefill_length=max_prefill_length,
- dtype=args.dtype
+ dtype=args.dtype,
+ torch_dtype=args.torch_dtype
)
# main()
inference_cli()
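
With these cli_chat.py changes, "pytorch" becomes the default --session_type and the parsed torch_dtype argument is forwarded into the inference config alongside dtype. Below is a minimal hypothetical sketch of the configuration this presumably produces; the model paths are placeholders rather than values from the commit, and float32 is used because the --dtype help text notes that CPU only supports fp32.

# Hypothetical sketch, not code from the repository: how the new CLI defaults
# map onto InferenceConfig. Paths below are placeholders.
from config import InferenceConfig

config = InferenceConfig(
    hf_model_dir="./Qwen2-1.5B-Instruct",    # placeholder HuggingFace model directory
    om_model_path="./output/qwen2.om",       # placeholder; presumably used by the acl session type
    onnx_model_path="./output/qwen2.onnx",   # placeholder; presumably used by the onnx session type
    session_type="pytorch",                  # new default set by this commit
    dtype="float32",                         # CPU only supports fp32 per the --dtype help text
    torch_dtype="float32",                   # new field introduced by this commit
    device_str="cpu",
)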
config.py (22 changes: 11 additions & 11 deletions)
@@ -9,22 +9,22 @@ def __init__(
hf_model_dir: str,
om_model_path: str,
onnx_model_path: str,
- cpu_thread: int = 4, # number of CPU threads
- session_type: str = "acl", # supports acl, onnx, and pytorch; acl stands for Ascend C Language
+ cpu_thread: int = 4,  # number of CPU threads
+ session_type: str = "acl",  # supports acl, onnx, and pytorch; acl stands for Ascend C Language
device_id: int = 0,
- sampling_method: str = "top_p", # supports greedy, top_p, top_k
+ sampling_method: str = "top_p",  # supports greedy, top_p, top_k
sampling_value: float = 0.8,
temperature: float = 0.7,
max_batch: int = 1,
- max_input_length: int = 1024, # maximum input length
- max_output_length: int = 2048, # maximum output length
- max_prefill_length: int = 1, # maximum inference length per pass in the prefill stage
- kvcache_method: str = "fixsize", # kv_cache type; supports basic, fixsize, streamllm, H2O; see kvcache.py for details
- kv_cache_length: int = 2048, # maximum kv_cache length
- cache_format: str = 'huggingface-tensor', # kv_cache format
- dtype:str="float16",
+ max_input_length: int = 1024,  # maximum input length
+ max_output_length: int = 2048,  # maximum output length
+ max_prefill_length: int = 1,  # maximum inference length per pass in the prefill stage
+ kvcache_method: str = "fixsize",  # kv_cache type; supports basic, fixsize, streamllm, H2O; see kvcache.py for details
+ kv_cache_length: int = 2048,  # maximum kv_cache length
+ cache_format: str = 'huggingface-tensor',  # kv_cache format
+ dtype: str = "float16",
+ torch_dtype: str = "float16",
- device_str = "cpu",
+ device_str: str = "cpu",
):
self.tokenizer_dir = hf_model_dir
self.session_type = session_type
requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@ onnxruntime==1.18.1
transformers==4.37.0
# onnxruntime-cann==1.18.1
torch==2.1.0
- torch-npu==2.1.0.post6
+ # torch-npu==2.1.0.post6
tqdm
fastapi
uvicorn
utils/inference.py (2 changes: 1 addition & 1 deletion)
@@ -244,7 +244,7 @@ def predict(
messages = [{"role": "system", "content": system_prompt}]
# print("prompt: ", prompt)
with self.lock:
- self.state['isEnd'],self.state['message'] = False,""
+ self.state['isEnd'], self.state['message'] = False,""
if prompt == "":
return
for (use_msg, bot_msg) in history:
utils/session.py (16 changes: 11 additions & 5 deletions)
@@ -1,3 +1,5 @@
+ import torch
+
from config import InferenceConfig
from utils.kvcache import create_kv_cache
import numpy as np
@@ -36,6 +38,7 @@ def reset(self):
def rollback(self,seq_len):
self.kv_cache.rollback(seq_len)


class OnnxSession(Session):
def __init__(self,config:InferenceConfig)->None:
super().__init__(config)
@@ -62,11 +65,12 @@ def run(self, input_ids:np.ndarray, show_progress=False):
"past_key_values": cache,
"position_ids": pos_ids,
})
- self.kv_cache.update(seq_len,result[1])
+ self.kv_cache.update(seq_len, result[1])
return result[0]


class PyTorchSession(Session):
- def __init__(self, config:InferenceConfig)->None:
+ def __init__(self, config:InferenceConfig) -> None:
super().__init__(config)
self.kv_cache = create_kv_cache(config)
from export.modeling_qwen2 import Qwen2ForCausalLM
@@ -76,16 +80,18 @@ def __init__(self, config:InferenceConfig)->None:
torch_dtype=config.torch_dtype
).to(config.device_str)

- def run(self, input_ids:np.ndarray, show_progress=False):
- seq_len=input_ids.shape[-1]
+ def run(self, input_ids: np.ndarray, show_progress=False):
+ if isinstance(input_ids, np.ndarray):
+     input_ids = torch.from_numpy(input_ids).long().to(self.device_str)
+ seq_len = input_ids.shape[-1]
cache, mask, pos_ids = self.kv_cache.get_inputs(seq_len)
# print("input_ids shape/dtype: ", input_ids.shape, input_ids.dtype)
# print("cache shape/dtype: ", cache.shape, cache.dtype)
# print("mask shape/dtype: ", mask.shape, mask.dtype)
# print("pos_ids shape/dtype: ", pos_ids.shape, pos_ids.dtype)
result = self.model(input_ids, mask, pos_ids, cache)
self.kv_cache.update(seq_len, result[1])
- return result[0]
+ return result[0].cpu().detach().numpy()

# onnxruntime-cann is preview, not work now
"""
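
The new PyTorchSession mirrors the numpy-in/numpy-out contract of the other session types: run() accepts a numpy input_ids array, moves it onto the configured device as a torch tensor, feeds it to the model together with the kv-cache inputs, and converts the model output back to a numpy array. A hedged usage sketch follows; the model directory and token ids are placeholders, and all other config fields are assumed to keep their defaults from config.py.

# Hypothetical usage sketch, assuming the defaults in config.py; not taken
# from the repository's own examples.
import numpy as np

from config import InferenceConfig
from utils.session import PyTorchSession

cfg = InferenceConfig(
    hf_model_dir="./Qwen2-1.5B-Instruct",  # placeholder model directory
    om_model_path="",                      # presumably not needed for the pytorch session type
    onnx_model_path="",                    # presumably not needed for the pytorch session type
    session_type="pytorch",
    torch_dtype="float32",                 # float32 is the safer choice on CPU
    device_str="cpu",
)

session = PyTorchSession(cfg)
input_ids = np.array([[1, 2, 3]], dtype=np.int64)  # placeholder token ids, shape (batch, seq_len)
logits = session.run(input_ids)  # after this commit, run() returns a numpy array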
