-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
1,143 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import sys | ||
import argparse | ||
from concurrent.futures import ThreadPoolExecutor | ||
from config import InferenceConfig | ||
from utils.inference import Inference | ||
import os | ||
|
||
# Resolve paths relative to the directory containing this script.
project_dir = os.path.dirname(os.path.abspath(__file__))

# Command-line options for selecting the model files and inference backend.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--hf_model_dir',
    default=os.path.join(project_dir, "download", "Qwen1_5_0_5B_Chat"),
    type=str,
    help="model and tokenizer path, only support huggingface model",
)
parser.add_argument(
    "--session_type",
    choices=["acl", "onnx"],
    default="acl",
    type=str,
    help="acl or onnx",
)
parser.add_argument(
    '--onnx_model_path',
    default=os.path.join(project_dir, "output", "onnx", "qwen1.5_0.5b_chat.onnx"),
    type=str,
    help="onnx_model_path",
)
parser.add_argument(
    "--om_model_path",
    default=os.path.join(project_dir, "output", "model", "qwen1.5_0.5b_chat.om"),
    type=str,
    help="mindspore model path",
)
parser.add_argument(
    "--max_input_length",
    default=512,
    type=int,
    help="max input length",
)
parser.add_argument(
    "--max_output_length",
    default=1024,
    type=int,
    help="max output length (contain input + new token)",
)

args = parser.parse_args()

# Build the inference configuration from the parsed options. The KV cache is
# sized to the full output window (prompt + generated tokens).
config = InferenceConfig(
    hf_model_dir=args.hf_model_dir,
    om_model_path=args.om_model_path,
    onnx_model_path=args.onnx_model_path,
    session_type=args.session_type,
    max_output_length=args.max_output_length,
    max_input_length=args.max_input_length,
    kv_cache_length=args.max_output_length,
)
infer_engine = Inference(config)
|
||
def inference_cli():
    """Run an interactive chat loop against the module-level ``infer_engine``.

    Reads user input from stdin; ``exit``/``quit`` (with or without parentheses)
    terminates the loop, ``clear`` resets the conversation history. Streamed
    tokens are echoed as they arrive, followed by a latency/throughput summary.
    """
    print("\n欢迎使用Qwen聊天机器人,输入exit或者quit退出,输入clear清空历史记录")
    history = []
    while True:
        input_text = input("Input: ")
        if input_text in ["exit", "quit", "exit()", "quit()"]:
            break
        if input_text == 'clear':
            history = []
            print("Output: 已清理历史对话信息。")
            continue
        print("Output: ", end='')
        response = ""
        is_first = True
        # Initialize all three metrics up front: the original left total_speed
        # unbound, so an empty generator would make the summary print below
        # raise NameError.
        first_token_latency, decode_speed, total_speed = 0, 0, 0
        for (
            new_text,
            first_token_latency,
            decode_speed,
            total_speed,
        ) in infer_engine.stream_predict(input_text, history=history):
            if is_first:
                # Suppress leading whitespace-only chunks before the first
                # visible token.
                if len(new_text.strip()) == 0:
                    continue
                is_first = False
            print(new_text, end='', flush=True)
            response += new_text
        print("")
        print(
            "[INFO] first_token_latency: {:.4f}s,".format(first_token_latency),
            " decode_speed: {:.2f} token/s, ".format(decode_speed),
            " total_speed(prefill+decode): {:.2f} token/s".format(total_speed),
        )

        # NOTE(review): only the assistant reply is recorded here — the user's
        # message is never appended. Verify that stream_predict records the
        # user turn internally; otherwise multi-turn history is incomplete.
        history.append({"role": "assistant", "content": response})
if __name__ == "__main__":
    inference_cli()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import os | ||
from transformers.models.qwen2 import Qwen2Config, Qwen2Tokenizer | ||
|
||
|
||
class InferenceConfig:
    """Aggregated settings for running Qwen2 inference on an ACL or ONNX backend.

    Loads the HuggingFace model config from ``hf_model_dir`` and derives the
    KV-cache tensor shape from it.
    """

    def __init__(
        self,
        hf_model_dir: str,
        om_model_path: str,
        onnx_model_path: str,
        session_type: str = "acl",  # "acl" (Ascend Computing Language) or "onnx"
        device_id: int = 0,
        sampling_method: str = "top_k",
        sampling_value: float = 10,
        temperature: float = 0.7,
        max_input_length: int = 512,    # maximum prompt length in tokens
        max_output_length: int = 1024,  # maximum total length (input + new tokens)
        kvcache_method: str = "fixsize",  # kv-cache strategy: basic, fixsize, streamllm, H2O (see kvcache.py)
        kv_cache_length: int = 1024,      # maximum kv-cache length
        cache_format: str = 'huggingface-tensor',  # kv-cache layout format
        dtype: str = "float16",
    ):
        self.tokenizer_dir = hf_model_dir
        self.session_type = session_type
        # Fail fast if the model file for the selected backend is missing.
        # (Was `assert ..., print(...)`: the print ran eagerly, the assert
        # message was always None, and asserts are stripped under `python -O`,
        # which would silently skip this validation.)
        if self.session_type == "acl" and not os.path.exists(om_model_path):
            raise FileNotFoundError(f"{om_model_path} not exists")
        elif self.session_type == "onnx" and not os.path.exists(onnx_model_path):
            raise FileNotFoundError(f"{onnx_model_path} not exists")
        self.om_model_path = om_model_path
        self.onnx_model_path = onnx_model_path
        self.device_id = device_id
        self.sampling_method = sampling_method
        self.sampling_value = sampling_value
        self.temperature = temperature
        self.max_input_length = max_input_length
        self.max_output_length = max_output_length
        self.kvcache_method = kvcache_method
        self.kv_cache_length = kv_cache_length  # max_cache_size
        self.cache_format = cache_format
        self.dtype = dtype
        # Derive per-layer attention geometry from the HuggingFace config.
        self.model_config = Qwen2Config.from_pretrained(hf_model_dir)
        self.num_hidden_layers = self.model_config.num_hidden_layers      # n_layer
        self.num_key_value_heads = self.model_config.num_key_value_heads  # head_num
        self.hidden_size = self.model_config.hidden_size                  # hidden_dim
        self.num_attention_heads = self.model_config.num_attention_heads
        self.per_head_dim = self.hidden_size // self.num_attention_heads  # head_dim
        # Shape: (layers, 2 [key/value], batch=1, kv_heads, cache_len, head_dim)
        self.past_key_value_shape = (
            self.num_hidden_layers,
            2,
            1,
            self.num_key_value_heads,
            self.kv_cache_length,
            self.per_head_dim
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.