diff --git a/README.md b/README.md
index d58de6a..78b3479 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@
   ```bash
   git clone https://github.com/Tlntin/qwen-ascend-llm.git
   ```
+
 2. 下载qwen1.5/qwen2的模型,选择chat模型或者instruct模型,将其放到download文件夹,仅支持huggingface下载的模型,网络不好的可以用镜像站:https://hf-mirror.com/Qwen
@@ -18,7 +19,12 @@
 ### 分步骤运行
 ##### 步骤1:编译模型(以Qwen2-1.5B-Instruct)为例。
-1. 导出onnx,默认kv-cache长度为1024,可以根据自己的内存、显存来设置更大参数。
+1. 除了上面说的CANN环境安装外,还需额外安装一些python模块。
+  ```bash
+  cd qwen-ascend-llm
+  pip install -r ./requirements.txt
+  ```
+2. 导出onnx,默认kv-cache长度为1024,可以根据自己的内存、显存来设置更大参数。
   ```bash
   python3 export/export_onnx.py \
     --hf_model_dir="./download/Qwen2-1.5B-Instruct" \
@@ -26,7 +32,7 @@
     --kv_cache_length=1024
   ```

-2. 验证onnx,返回项目根目录,运行cli_chat.py,测试一下onnx对话是否正常(注意:由于是cpu运行,所以速度较慢,请耐心等待)。
+3. 验证onnx,返回项目根目录,运行cli_chat.py,测试一下onnx对话是否正常(注意:由于是cpu运行,所以速度较慢,请耐心等待)。
   ```bash
   python3 ./cli_chat.py \
     --session_type=onnx \
@@ -34,14 +40,14 @@
     --onnx_model_path="./output/onnx/qwen2_1.5b_chat.onnx"
   ```

-3. 改变onnx结构,目前导出的Trilu算子和Cast算子有些问题,atc命令无法识别,需要改一下结构。
+4. 改变onnx结构,目前导出的Trilu算子和Cast算子有些问题,atc命令无法识别,需要改一下结构。
   ```bash
   python3 export/change_node.py \
     --input_model_path="./output/onnx/qwen2_1.5b_chat.onnx" \
     --output_model_path="./output/onnx2/qwen2_1.5b_chat.onnx"
   ```

-4. 转onnx为om模型, 将修改后的onnx利用atc命令导出到onnx,**注意此处的om_model_path不带`.om`后缀**。运行过程可能会有一些警告,或者子图融合报错,只要结果是提示`success`就说明没啥问题。kv_cache_length长度和第一步导出onnx时的长度保持一致。`--max_prefill_length`为prefill阶段,单次能处理的最大长度,该数值越长则越能降低首字延迟,但是相应的onnx转om的时间也会变长。设置该数值时,一般为2的指数,例如2、4、8、16等等,推理时会利用递归自动匹配合适的prefill长度,例如输入12,会匹配[8, 4]。当前默认数值为8,如果设置为1,则不会开启动态shape推理功能。
+5. 转onnx为om模型, 将修改后的onnx利用atc命令导出为om,**注意此处的om_model_path不带`.om`后缀**。运行过程可能会有一些警告,或者子图融合报错,只要结果是提示`success`就说明没啥问题。kv_cache_length长度和第二步导出onnx时的长度保持一致。`--max_prefill_length`为prefill阶段,单次能处理的最大长度,该数值越长则越能降低首字延迟,但是相应的onnx转om的时间也会变长。设置该数值时,一般为2的指数,例如2、4、8、16等等,推理时会利用递归自动匹配合适的prefill长度,例如输入12,会匹配[8, 4]。当前默认数值为8,如果设置为1,则不会开启动态shape推理功能。
   ```bash
   python3 export/onnx2om.py \
     --hf_model_dir="./download/Qwen2-1.5B-Instruct" \
@@ -81,9 +87,10 @@
   ```bash
   # openai_stream_client.py 流式请求,类似打字机效果
   # openai_normal_client.py 非流式请求,需要等模型推理完再返回
-  # openai_function_call.py 测试function_call
+  # openai_function_call.py 测试function_call,该功能启用时建议增加max_input_length和kv_cache_length的长度。
   ```
+- function_call demo展示(使用qwen2-1.5b-instruct)![](./image/qwen2-1.5b-instruct-functional-call.jpg)

 ### 当前功能
 - [x] 导出onnx, om模型
diff --git a/api.py b/api.py
index 2b5faf2..ddb8278 100644
--- a/api.py
+++ b/api.py
@@ -350,7 +350,7 @@ def parse_response(response):


 # completion mode, not chat mode
-def text_complete_last_message(history, stop_words_ids, max_new_tokens): # sampling_config, :
+def text_complete_last_message(history, stop_words_ids, sampling_config, max_new_tokens):
     im_start = "<|im_start|>"
     im_end = "<|im_end|>"
     prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
diff --git a/client/openai_function_call.py b/client/openai_function_call.py
index 574cadd..fda27d4 100644
--- a/client/openai_function_call.py
+++ b/client/openai_function_call.py
@@ -5,6 +5,7 @@
 import time
 import random
 import json
+import asyncio

 urllib3.disable_warnings()

@@ -100,7 +101,8 @@ def get_current_weather(location: str):
     if len(location_data) > 0:
         location_dict = location_data[0]
         city_id = location_dict["id"]
-        weather_res = weather.get_weather_from_api(city_id)
+        # 由于输入长度限制,这里只取一天的天气,有条件的可以自己更改。
+        weather_res = weather.get_weather_from_api(city_id)[:1]
         n_day = len(weather_res)
         return f"查询到最近{n_day}天的天气。" + json.dumps(weather_res, ensure_ascii=False)
     else:
@@ -109,16 +111,18 @@ def get_current_weather(location: str):
 def call_qwen(messages, functions=None):
     # print(messages)
     if functions:
-        response = client.chat.completions.create(
-            model="Qwen", messages=messages, functions=functions
+        return client.chat.completions.create(
+            model="Qwen",
+            messages=messages,
+            functions=functions,
+            temperature=0,
         )
     else:
-        response = client.chat.completions.create(
-            model="Qwen", messages=messages
+        return client.chat.completions.create(
+            model="Qwen",
+            messages=messages,
+            temperature=0
         )
-    # print(response)
-    # print(response.choices[0].message.content)
-    return response


 def chat(query: str):
@@ -148,6 +152,7 @@ def chat(query: str):
             "content": query,
         }
     ]
+    print("[INFO] Invoke AI and ask if it needs to call the plugin.")
     response = call_qwen(messages, functions)
     res = response.choices[0].message
     message_dict = {
@@ -178,19 +183,23 @@ def chat(query: str):
         had_params = list(function_params.keys())
         if len(had_params) != len(require_params):
             raise Exception("ERROR, need to do other fill params")
-
-
-        response = eval(function_name)(**function_params)
+        print("[INFO] will call function {} with params {}".format(
+            function_name, function_params
+        ))
+        fun_response = eval(function_name)(**function_params)
+        print("[INFO] call function response is: ", fun_response)
         message = {
             "role": "function",
             "name": function_name,
         }
-        if len(response) > 0:
-            message["content"] = response
+        if len(fun_response) > 0:
+            message["content"] = fun_response
         else:
             message["content"] = "未找到任何信息"
         messages.append(message)
+        print("[INFO] send function response to AI")
         response = call_qwen(messages, functions)
+        return response
     return response


@@ -200,6 +209,7 @@ def chat(query: str):
 print("目前已支持天气查询插件")
 print("=" * 20)
 query = "北京天气如何？穿短袖会不会冷？"
-print("用户输入:", query) -res = chat(query) -print("回答结果:", res.choices[0].message.content) +print("User:", query) +response = chat(query) +res = response.choices[0].message.content +print("ChatBot: ", res) \ No newline at end of file diff --git a/client/openai_normal_client.py b/client/openai_normal_client.py index f11c745..3ff73b0 100644 --- a/client/openai_normal_client.py +++ b/client/openai_normal_client.py @@ -13,6 +13,7 @@ break if prompt == 'clear': messages = messages[:1] + print("ChatBot: 已清理历史对话信息。") continue messages.append({"role": "user", "content": prompt}) completion = client.chat.completions.create( diff --git a/client/openai_stream_client.py b/client/openai_stream_client.py index 191be54..eaa0ebb 100644 --- a/client/openai_stream_client.py +++ b/client/openai_stream_client.py @@ -14,6 +14,7 @@ break if prompt == 'clear': messages = messages[:1] + print("ChatBot: 已清理历史对话信息。") continue messages.append({"role": "user", "content": prompt}) response = client.chat.completions.create( diff --git a/config.py b/config.py index 4e9348e..ff31a45 100644 --- a/config.py +++ b/config.py @@ -10,8 +10,8 @@ def __init__( onnx_model_path: str, session_type: str = "acl", # 支持acl和onnx两种,acl即Ascend C Language device_id: int = 0, - sampling_method: str = "top_k", - sampling_value: float = 10, + sampling_method: str = "top_p", # 支持 greedy, top_p, top_k + sampling_value: float = 0.8, temperature: float = 0.7, max_batch: int = 1, max_input_length: int = 512, # 输入长度的最大数值 diff --git a/image/qwen2-1.5b-instruct-functional-call.jpg b/image/qwen2-1.5b-instruct-functional-call.jpg new file mode 100644 index 0000000..6dc5388 Binary files /dev/null and b/image/qwen2-1.5b-instruct-functional-call.jpg differ diff --git a/requirements.txt b/requirements.txt index 9ef3435..378afc1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,11 @@ onnx==1.16.1 onnxruntime==1.18.1 +transformers==4.37.0 # onnxruntime-cann==1.18.1 torch==2.1.0 torch-npu==2.1.0.post6 tqdm fastapi uvicorn -sse_starlette +sse_starlette==1.6.5 openai \ No newline at end of file diff --git a/utils/inference.py b/utils/inference.py index 469cae0..d77dc25 100644 --- a/utils/inference.py +++ b/utils/inference.py @@ -124,7 +124,7 @@ def stream_predict( if history is None: history = [] sampling_value = sampling_config.get("sampling_value", self.sampling_value) - temperature = sampling_config.get("sampling_value", self.temperature) + temperature = sampling_config.get("temperature", self.temperature) messages = [{"role": "system", "content": system_prompt}] # print("prompt: ", prompt) with self.lock: @@ -144,6 +144,8 @@ def stream_predict( input_ids = self.tokenizer( [text], return_tensors="np" )["input_ids"].astype(np.int64).reshape(1, -1) + input_ids = input_ids[:, -self.max_input_length:] + print("input_ids shape: ", input_ids.shape) self.first = False ids_list = [] text_length = 0 @@ -158,9 +160,20 @@ def stream_predict( temp_list = trange(max_output_len, desc="decode") else: temp_list = range(max_output_len) + prefill_show_progress = False for i in temp_list: - prefill_show_progress = (i == 0) - logits = self.session.run(input_ids, show_progress=prefill_show_progress)[0] + if i == 0: + if show_progress: + prefill_show_progress = True + # reset counter + self.session.run_times = 0 + self.session.kv_cache.real_kv_size = 0 + else: + prefill_show_progress = False + logits = self.session.run( + input_ids, + show_progress=prefill_show_progress + )[0] input_ids = self.sample_logits( logits[0][-1:], self.sampling_method, @@ -207,7 +220,7 @@ 
         if history is None:
             history = []
         sampling_value = sampling_config.get("sampling_value", self.sampling_value)
-        temperature = sampling_config.get("sampling_value", self.temperature)
+        temperature = sampling_config.get("temperature", self.temperature)
         messages = [{"role": "system", "content": system_prompt}]
         # print("prompt: ", prompt)
         with self.lock:
@@ -227,6 +240,7 @@ def predict(
         input_ids = self.tokenizer(
             [text], return_tensors="np"
         )["input_ids"].astype(np.int64).reshape(1, -1)
+        input_ids = input_ids[:, -self.max_input_length:]
         self.first = False
         ids_list = []
         # text_length = 0
@@ -240,9 +254,20 @@ def predict(
             temp_list = trange(max_output_len, desc="decode")
         else:
             temp_list = range(max_output_len)
+        prefill_show_progress = False
         for i in temp_list:
-            prefill_show_progress = (i == 0)
-            logits = self.session.run(input_ids, show_progress=prefill_show_progress)[0]
+            if i == 0:
+                if show_progress:
+                    prefill_show_progress = True
+                # reset counter
+                self.session.run_times = 0
+                self.session.kv_cache.real_kv_size = 0
+            else:
+                prefill_show_progress = False
+            logits = self.session.run(
+                input_ids,
+                show_progress=prefill_show_progress
+            )[0]
             input_ids = self.sample_logits(
                 logits[0][-1:],
                 self.sampling_method,
@@ -275,9 +300,10 @@ def generate(
         show_progress: bool = False,
     ):
         sampling_value = sampling_config.get("sampling_value", self.sampling_value)
-        temperature = sampling_config.get("sampling_value", self.temperature)
+        temperature = sampling_config.get("temperature", self.temperature)
         self.first = False
         ids_list = []
+        input_ids = input_ids[:, -self.max_input_length:]
         input_length = input_ids.shape[1]
         max_output_len = self.max_output_length - input_length
         max_output_len = min(max_output_len, max_new_tokens)
@@ -285,9 +311,20 @@ def generate(
             temp_list = trange(max_output_len, desc="decode")
         else:
             temp_list = range(max_output_len)
+        prefill_show_progress = False
         for i in temp_list:
-            prefill_show_progress = (i == 0)
-            logits = self.session.run(input_ids, show_progress=prefill_show_progress)[0]
+            if i == 0:
+                if show_progress:
+                    prefill_show_progress = True
+                # reset counter
+                self.session.run_times = 0
+                self.session.kv_cache.real_kv_size = 0
+            else:
+                prefill_show_progress = False
+            logits = self.session.run(
+                input_ids,
+                show_progress=prefill_show_progress
+            )[0]
             input_ids = self.sample_logits(
                 logits[0][-1:],
                 self.sampling_method,
@@ -302,7 +339,7 @@ def generate(
                 break
             ids_list.append(input_ids[0].item())
             text_out = self.tokenizer.decode(ids_list)
-            print("Debug: ", text_out)
+            # print("Debug: ", text_out)
             # stop_word = is_stop_word_or_prefix(text_out, ["[|Human|]", "[|AI|]"])
             self.state['message'] = text_out
         with self.lock:
diff --git a/utils/session.py b/utils/session.py
index e76b361..2b4f81f 100644
--- a/utils/session.py
+++ b/utils/session.py
@@ -181,6 +181,10 @@ def run_some(
         seq_length: int = 1,
         is_dynamic: bool = False
     ):
+        # print(
+        #     "self.run_times: ", self.run_times,
+        #     "real kv size: ", self.kv_cache.real_kv_size
+        # )
        self.run_times += seq_length
         cache, mask, pos_ids = self.kv_cache.get_inputs(seq_length)
         result:List[np.ndarray] = self.model.inference(
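A note on the change_node.py step in the README hunk above: it rewrites the exported Trilu and Cast nodes that the atc command cannot parse. If you want to confirm which operators are present in the graph before and after that rewrite, a small helper like the sketch below is enough; `count_ops` is illustrative only and is not part of this repository.

```python
import onnx
from collections import Counter


def count_ops(onnx_path: str) -> Counter:
    """Count operator types in an ONNX graph, e.g. to spot Trilu/Cast nodes."""
    model = onnx.load(onnx_path)
    return Counter(node.op_type for node in model.graph.node)


if __name__ == "__main__":
    # Paths follow the README example; adjust them to your own output directory.
    print(count_ops("./output/onnx/qwen2_1.5b_chat.onnx"))   # before change_node.py
    print(count_ops("./output/onnx2/qwen2_1.5b_chat.onnx"))  # after change_node.py
```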
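The onnx2om step above explains that `--max_prefill_length` should be a power of two and that, at inference time, an input of 12 tokens is recursively matched to prefill chunks of [8, 4]. A minimal sketch of such a greedy power-of-two decomposition is shown below; the function name and the iterative formulation are mine, not the project's actual matching code.

```python
from typing import List


def split_prefill_lengths(input_length: int, max_prefill_length: int = 8) -> List[int]:
    """Cover input_length with descending power-of-two chunks, e.g. 12 -> [8, 4]."""
    chunks: List[int] = []
    remaining = input_length
    size = max_prefill_length
    while remaining > 0 and size >= 1:
        if remaining >= size:
            chunks.append(size)
            remaining -= size
        else:
            size //= 2  # fall back to the next smaller power of two
    return chunks


assert split_prefill_lengths(12, 8) == [8, 4]     # the README's example
assert split_prefill_lengths(13, 8) == [8, 4, 1]  # a leftover token is handled with size 1
```

With `--max_prefill_length=1` every chunk degenerates to a single token, which matches the README's note that a value of 1 disables the dynamic-shape prefill path.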
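The api.py hunk shows `text_complete_last_message` opening its prompt with the `<|im_start|>system ... <|im_end|>` header. For readers new to that ChatML-style format, here is a rough sketch of how such a prompt is typically assembled; the `history` layout (a list of user/assistant pairs) and the helper name are assumptions, and the project's real function may differ in its details.

```python
IM_START, IM_END = "<|im_start|>", "<|im_end|>"


def build_chatml_prompt(history):
    """history: list of (user_text, assistant_text) pairs; leave the last
    assistant text empty so the model completes it."""
    prompt = f"{IM_START}system\nYou are a helpful assistant.{IM_END}"
    for user_text, assistant_text in history:
        prompt += f"\n{IM_START}user\n{user_text}{IM_END}"
        prompt += f"\n{IM_START}assistant\n{assistant_text}"
        if assistant_text:
            prompt += IM_END
    return prompt


print(build_chatml_prompt([("北京天气如何？", "")]))
```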
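config.py now defaults to `sampling_method="top_p"` with `sampling_value=0.8` and `temperature=0.7`, and utils/inference.py forwards these values to `self.sample_logits`. As a reference for the three supported options (greedy, top_k, top_p), below is a minimal NumPy sketch of how such sampling over a next-token logits vector usually works; it is an assumption about the general technique, not a copy of the project's `sample_logits` implementation.

```python
import numpy as np


def sample_logits_sketch(logits: np.ndarray, method: str = "top_p",
                         value: float = 0.8, temperature: float = 0.7) -> int:
    """Pick a next-token id from a 1-D logits vector."""
    if method == "greedy":
        return int(np.argmax(logits))
    # temperature-scaled softmax
    probs = np.exp((logits - logits.max()) / temperature)
    probs /= probs.sum()
    if method == "top_k":
        k = max(1, int(value))
        top = np.argsort(probs)[::-1][:k]  # the k most likely tokens
        p = probs[top] / probs[top].sum()
        return int(np.random.choice(top, p=p))
    if method == "top_p":
        order = np.argsort(probs)[::-1]
        cum = np.cumsum(probs[order])
        keep = order[:int(np.searchsorted(cum, value)) + 1]  # smallest prefix covering value
        p = probs[keep] / probs[keep].sum()
        return int(np.random.choice(keep, p=p))
    raise ValueError(f"unknown sampling method: {method}")


next_id = sample_logits_sketch(np.random.randn(100).astype(np.float32))
```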