From fb8640dda58f13e2302ad97e6835527a1b7a934b Mon Sep 17 00:00:00 2001
From: Tlntin
Date: Thu, 24 Oct 2024 22:39:07 +0800
Subject: [PATCH] support pytorch session type

---
 cli_chat.py        | 12 +++++++-----
 config.py          | 22 +++++++++++-----------
 requirements.txt   |  2 +-
 utils/inference.py |  2 +-
 utils/session.py   | 16 +++++++++++-----
 5 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/cli_chat.py b/cli_chat.py
index e2f7318..45e3721 100644
--- a/cli_chat.py
+++ b/cli_chat.py
@@ -18,12 +18,12 @@ def parser_args():
     parser.add_argument(
         "--session_type",
         type=str,
-        default="acl",
+        default="pytorch",
         help="acl or onnx",
         choices=["acl", "onnx", "pytorch"],
     )
     parser.add_argument(
-        "--dtype" ,
+        "--dtype",
         type=str,
         help="support float16/float32, if use CPU, only support fp32",
         choices=["float16", "float32"],
@@ -126,8 +126,9 @@ def inference_cli():
             " decode_speed: {:.2f} token/s, ".format(decode_speed),
             " total_speed(prefill+decode): {:.2f} token/s".format(total_speed),
         )
-
         history.append([input_text, response])
+
+
 if __name__ == '__main__':
     args = parser_args()
     max_prefill_log2 = int(math.log2(args.max_prefill_length))
@@ -143,7 +144,8 @@ def inference_cli():
         max_input_length=args.max_input_length,
         kv_cache_length=args.max_output_length,
         max_prefill_length=max_prefill_length,
-        dtype=args.dtype
+        dtype=args.dtype,
+        torch_dtype=args.torch_dtype
     )
     # main()
-    inference_cli()
\ No newline at end of file
+    inference_cli()
diff --git a/config.py b/config.py
index e5f0626..6eae3a1 100644
--- a/config.py
+++ b/config.py
@@ -9,22 +9,22 @@ def __init__(
         hf_model_dir: str,
         om_model_path: str,
         onnx_model_path: str,
-        cpu_thread: int = 4, # number of CPU threads
-        session_type: str = "acl", # supports acl, onnx and pytorch; acl is short for Ascend C Language
+        cpu_thread: int = 4,  # number of CPU threads
+        session_type: str = "acl",  # supports acl, onnx and pytorch; acl is short for Ascend C Language
         device_id: int = 0,
-        sampling_method: str = "top_p", # supports greedy, top_p, top_k
+        sampling_method: str = "top_p",  # supports greedy, top_p, top_k
         sampling_value: float = 0.8,
         temperature: float = 0.7,
         max_batch: int = 1,
-        max_input_length: int = 1024, # maximum input length
-        max_output_length: int = 2048, # maximum output length
-        max_prefill_length: int = 1, # maximum length of a single inference step in the prefill stage
-        kvcache_method: str = "fixsize", # kv_cache type; supports basic, fixsize, streamllm and H2O, see kvcache.py for details
-        kv_cache_length: int = 2048, # maximum kv_cache length
-        cache_format: str = 'huggingface-tensor', # kv_cache format
-        dtype:str="float16",
+        max_input_length: int = 1024,  # maximum input length
+        max_output_length: int = 2048,  # maximum output length
+        max_prefill_length: int = 1,  # maximum length of a single inference step in the prefill stage
+        kvcache_method: str = "fixsize",  # kv_cache type; supports basic, fixsize, streamllm and H2O, see kvcache.py for details
+        kv_cache_length: int = 2048,  # maximum kv_cache length
+        cache_format: str = 'huggingface-tensor',  # kv_cache format
+        dtype: str = "float16",
         torch_dtype: str = "float16",
-        device_str = "cpu",
+        device_str: str = "cpu",
     ):
         self.tokenizer_dir = hf_model_dir
         self.session_type = session_type
diff --git a/requirements.txt b/requirements.txt
index 1bfedc0..14d506f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ onnxruntime==1.18.1
 transformers==4.37.0
 # onnxruntime-cann==1.18.1
 torch==2.1.0
-torch-npu==2.1.0.post6
+# torch-npu==2.1.0.post6
 tqdm
 fastapi
 uvicorn
diff --git a/utils/inference.py b/utils/inference.py
index e7ddb3d..7f4ec30 100644
--- a/utils/inference.py
+++ b/utils/inference.py
@@ -244,7 +244,7 @@ def predict(
         messages = [{"role": "system", "content": system_prompt}]
         # print("prompt: ", prompt)
         with self.lock:
-            self.state['isEnd'],self.state['message'] = False,""
+            self.state['isEnd'], self.state['message'] = False,""
         if prompt == "":
             return
         for (use_msg, bot_msg) in history:
diff --git a/utils/session.py b/utils/session.py
index cd8e088..b762f79 100644
--- a/utils/session.py
+++ b/utils/session.py
@@ -1,3 +1,5 @@
+import torch
+
 from config import InferenceConfig
 from utils.kvcache import create_kv_cache
 import numpy as np
@@ -36,6 +38,7 @@ def reset(self):
     def rollback(self,seq_len):
         self.kv_cache.rollback(seq_len)
 
+
 class OnnxSession(Session):
     def __init__(self,config:InferenceConfig)->None:
         super().__init__(config)
@@ -62,11 +65,12 @@ def run(self, input_ids:np.ndarray, show_progress=False):
             "past_key_values": cache,
             "position_ids": pos_ids,
         })
-        self.kv_cache.update(seq_len,result[1])
+        self.kv_cache.update(seq_len, result[1])
         return result[0]
 
+
 class PyTorchSession(Session):
-    def __init__(self, config:InferenceConfig)->None:
+    def __init__(self, config:InferenceConfig) -> None:
         super().__init__(config)
         self.kv_cache = create_kv_cache(config)
         from export.modeling_qwen2 import Qwen2ForCausalLM
@@ -76,8 +80,10 @@ def __init__(self, config:InferenceConfig)->None:
             torch_dtype=config.torch_dtype
         ).to(config.device_str)
 
-    def run(self, input_ids:np.ndarray, show_progress=False):
-        seq_len=input_ids.shape[-1]
+    def run(self, input_ids: np.ndarray, show_progress=False):
+        if isinstance(input_ids, np.ndarray):
+            input_ids = torch.from_numpy(input_ids).long().to(self.device_str)
+        seq_len = input_ids.shape[-1]
         cache, mask, pos_ids = self.kv_cache.get_inputs(seq_len)
         # print("input_ids shape/dtype: ", input_ids.shape, input_ids.dtype)
         # print("cache shape/dtype: ", cache.shape, cache.dtype)
@@ -85,7 +91,7 @@ def run(self, input_ids:np.ndarray, show_progress=False):
         # print("pos_ids shape/dtype: ", pos_ids.shape, pos_ids.dtype)
         result = self.model(input_ids, mask, pos_ids, cache)
         self.kv_cache.update(seq_len, result[1])
-        return result[0]
+        return result[0].cpu().detach().numpy()
 
 # onnxruntime-cann is preview, not work now
 """