fixup some bug
Tlntin committed Jul 31, 2024
1 parent de419d7 commit 43498bf
Showing 10 changed files with 96 additions and 35 deletions.
17 changes: 12 additions & 5 deletions README.md
@@ -9,6 +9,7 @@
```bash
git clone https://github.com/Tlntin/qwen-ascend-llm.git
```

2. Download a Qwen1.5/Qwen2 model (choose a chat or instruct variant) and place it in the download folder. Only models downloaded from Hugging Face are supported; if your network is poor, you can use the mirror site: https://hf-mirror.com/Qwen
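For reference, a minimal Python sketch of one way to pull the model into the download folder (this helper is not part of the repo; it assumes the huggingface_hub package is installed, and the mirror endpoint is optional):

```python
import os

# Optional: point huggingface_hub at the mirror *before* importing it (for slow networks).
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

from huggingface_hub import snapshot_download  # noqa: E402

snapshot_download(
    repo_id="Qwen/Qwen2-1.5B-Instruct",          # or another Qwen1.5/Qwen2 chat/instruct model
    local_dir="./download/Qwen2-1.5B-Instruct",  # matches the --hf_model_dir used in later steps
)
```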


@@ -18,30 +19,35 @@

### Step-by-step run
##### Step 1: Compile the model (using Qwen2-1.5B-Instruct as an example).
1. In addition to the CANN environment setup described above, a few extra Python modules need to be installed.
```bash
cd qwen-ascend-llm
pip install -r ./requirements.txt
```
2. Export the ONNX model. The default kv-cache length is 1024; you can set a larger value based on your available RAM/VRAM.
```bash
python3 export/export_onnx.py \
--hf_model_dir="./download/Qwen2-1.5B-Instruct" \
--onnx_model_path="./output/onnx/qwen2_1.5b_chat.onnx" \
--kv_cache_length=1024
```

3. Verify the ONNX model: return to the project root and run cli_chat.py to check that chatting over ONNX works (note: it runs on the CPU, so it is slow; please be patient).
```bash
python3 ./cli_chat.py \
--session_type=onnx \
--hf_model_dir="./download/Qwen2-1.5B-Instruct" \
--onnx_model_path="./output/onnx/qwen2_1.5b_chat.onnx"
```

4. Modify the ONNX graph: the exported Trilu and Cast operators currently have issues that the atc command cannot recognize, so the graph structure needs to be adjusted.
```bash
python3 export/change_node.py \
--input_model_path="./output/onnx/qwen2_1.5b_chat.onnx" \
--output_model_path="./output/onnx2/qwen2_1.5b_chat.onnx"
```
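To give a feel for what such a graph rewrite looks like, here is a rough, generic sketch using the onnx Python API; the node checks and the INT64-to-INT32 rewrite are purely illustrative assumptions, and the real rules live in export/change_node.py:

```python
import onnx
from onnx import TensorProto

model = onnx.load("./output/onnx/qwen2_1.5b_chat.onnx")

for node in model.graph.node:
    if node.op_type == "Cast":
        # Hypothetical rewrite: change an unsupported Cast target dtype in place.
        for attr in node.attribute:
            if attr.name == "to" and attr.i == TensorProto.INT64:
                attr.i = TensorProto.INT32
    elif node.op_type == "Trilu":
        # Inspect Trilu nodes here and rebuild them as needed.
        print("found Trilu node:", node.name)

onnx.save(model, "./output/onnx2/qwen2_1.5b_chat.onnx")
```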

5. Convert ONNX to an OM model: use the atc command to convert the modified ONNX model to OM. **Note that om_model_path here does not carry the `.om` suffix.** The run may print some warnings or subgraph-fusion errors; as long as the final result reports `success`, everything is fine. kv_cache_length must stay consistent with the length used when exporting ONNX in the earlier step. `--max_prefill_length` is the maximum length that can be processed in a single pass during the prefill stage; a larger value lowers first-token latency but also lengthens the ONNX-to-OM conversion. It is normally set to a power of 2, e.g. 2, 4, 8, 16, and at inference time a suitable prefill length is matched recursively, e.g. an input of 12 is matched as [8, 4]. The current default is 8; if it is set to 1, dynamic-shape inference is disabled.
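The recursive prefill-length matching described above can be sketched as follows (a standalone illustration of the idea, not code from this repo); the actual atc conversion command is shown below:

```python
def split_prefill(n_tokens: int, max_prefill_length: int = 8) -> list:
    """Greedily split a prompt into power-of-two prefill chunks, e.g. 12 -> [8, 4]."""
    chunks, size = [], max_prefill_length
    while n_tokens > 0 and size > 1:
        if n_tokens >= size:
            chunks.append(size)
            n_tokens -= size
        else:
            size //= 2
    chunks.extend([1] * n_tokens)  # any remainder is decoded one token at a time
    return chunks


print(split_prefill(12))  # [8, 4]
```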
```bash
python3 export/onnx2om.py \
--hf_model_dir="./download/Qwen2-1.5B-Instruct" \
@@ -81,9 +87,10 @@
```bash
# openai_stream_client.py: streaming request, typewriter-style output
# openai_normal_client.py: non-streaming request, returns only after the model finishes inference
# openai_function_call.py: test function_call; when enabling this feature, it is recommended to increase max_input_length and kv_cache_length
```
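As a rough illustration, a streaming request like the one openai_stream_client.py makes boils down to something like this (the base_url and api_key values are assumptions for a locally started api.py):

```python
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="no-key-required")

stream = client.chat.completions.create(
    model="Qwen",
    messages=[{"role": "user", "content": "Hello, please introduce yourself."}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)  # typewriter-style output
print()
```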

- functional_call demo (using qwen2-1.5b-instruct) ![](./image/qwen2-1.5b-instruct-functional-call.jpg)

### Current features
- [x] Export ONNX and OM models
2 changes: 1 addition & 1 deletion api.py
@@ -350,7 +350,7 @@ def parse_response(response):


# completion mode, not chat mode
def text_complete_last_message(history, stop_words_ids, max_new_tokens): # sampling_config, :
def text_complete_last_message(history, stop_words_ids, sampling_config, max_new_tokens):
im_start = "<|im_start|>"
im_end = "<|im_end|>"
prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
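For context, the ChatML-style prompt that this completion path assembles looks roughly like the following hand-written sketch (based on the usual Qwen chat format; the history variable here is hypothetical, not taken from api.py):

```python
im_start, im_end = "<|im_start|>", "<|im_end|>"
history = [("What is the capital of France?", None)]  # hypothetical single-turn history

prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
for query, answer in history:
    prompt += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
    if answer is not None:
        prompt += f"{answer}{im_end}"

print(prompt)  # the open assistant block at the end is what the model completes
```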
42 changes: 26 additions & 16 deletions client/openai_function_call.py
@@ -5,6 +5,7 @@
import time
import random
import json
import asyncio


urllib3.disable_warnings()
@@ -100,7 +101,8 @@ def get_current_weather(location: str):
if len(location_data) > 0:
location_dict = location_data[0]
city_id = location_dict["id"]
weather_res = weather.get_weather_from_api(city_id)
# Due to the input length limit, only one day of weather is kept here; adjust this if your setup allows.
weather_res = weather.get_weather_from_api(city_id)[:1]
n_day = len(weather_res)
return f"查询到最近{n_day}天的天气。" + json.dumps(weather_res, ensure_ascii=False)
else:
def call_qwen(messages, functions=None):
# print(messages)
if functions:
response = client.chat.completions.create(
model="Qwen", messages=messages, functions=functions
return client.chat.completions.create(
model="Qwen",
messages=messages,
functions=functions,
temperature=0,
)
else:
response = client.chat.completions.create(
model="Qwen", messages=messages
return client.chat.completions.create(
model="Qwen",
messages=messages,
temperature=0
)
# print(response)
# print(response.choices[0].message.content)
return response


def chat(query: str):
@@ -148,6 +152,7 @@ def chat(query: str):
"content": query,
}
]
print("[INFO] Invoke AI and ask if it needs to call the plugin.")
response = call_qwen(messages, functions)
res = response.choices[0].message
message_dict = {
@@ -178,19 +183,23 @@
had_params = list(function_params.keys())
if len(had_params) != len(require_params):
raise Exception("ERROR, need to do other fill params")


response = eval(function_name)(**function_params)
print("[INFO] will call function {} with params {}".format(
function_name, function_params
))
fun_response = eval(function_name)(**function_params)
print("[INFO] call function response is: ", fun_response)
message = {
"role": "function",
"name": function_name,
}
if len(response) > 0:
message["content"] = response
if len(fun_response) > 0:
message["content"] = fun_response
else:
message["content"] = "未找到任何信息"
messages.append(message)
print("[INFO] send function response to AI")
response = call_qwen(messages, functions)
return response
return response


@@ -200,6 +209,7 @@ def chat(query: str):
print("目前已支持天气查询插件")
print("=" * 20)
query = "北京天气如何?穿短袖会不会冷?"
print("用户输入:", query)
res = chat(query)
print("回答结果:", res.choices[0].message.content)
print("User:", query)
response = chat(query)
res = response.choices[0].message.content
print("ChatBot: ", res)
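For readers new to the function-calling request format, a `functions` entry of the kind this client registers might look like the following (a hypothetical example; the actual schema defined in openai_function_call.py may differ):

```python
functions = [
    {
        "name": "get_current_weather",
        "description": "Look up the recent weather for a city by name.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "City name, e.g. Beijing"},
            },
            "required": ["location"],
        },
    }
]
```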
1 change: 1 addition & 0 deletions client/openai_normal_client.py
@@ -13,6 +13,7 @@
break
if prompt == 'clear':
messages = messages[:1]
print("ChatBot: 已清理历史对话信息。")
continue
messages.append({"role": "user", "content": prompt})
completion = client.chat.completions.create(
1 change: 1 addition & 0 deletions client/openai_stream_client.py
@@ -14,6 +14,7 @@
break
if prompt == 'clear':
messages = messages[:1]
print("ChatBot: 已清理历史对话信息。")
continue
messages.append({"role": "user", "content": prompt})
response = client.chat.completions.create(
4 changes: 2 additions & 2 deletions config.py
@@ -10,8 +10,8 @@ def __init__(
onnx_model_path: str,
session_type: str = "acl",  # either "acl" or "onnx"; acl refers to AscendCL (Ascend Computing Language)
device_id: int = 0,
sampling_method: str = "top_k",
sampling_value: float = 10,
sampling_method: str = "top_p",  # supports greedy, top_p, top_k
sampling_value: float = 0.8,
temperature: float = 0.7,
max_batch: int = 1,
max_input_length: int = 512,  # maximum input length
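Since the default sampling switches from top_k to top_p, here is a small generic sketch of nucleus (top-p) sampling with temperature over a logits vector; it illustrates the idea only and is not the repo's sample_logits implementation:

```python
import numpy as np

def top_p_sample(logits: np.ndarray, top_p: float = 0.8, temperature: float = 0.7) -> int:
    """Sample one token id from a 1-D logits vector using nucleus (top-p) sampling."""
    scaled = logits.astype(np.float64) / max(temperature, 1e-5)
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    order = np.argsort(-probs)                       # highest probability first
    cutoff = np.searchsorted(np.cumsum(probs[order]), top_p) + 1
    nucleus = order[:cutoff]                         # smallest set covering top_p mass
    nucleus_probs = probs[nucleus] / probs[nucleus].sum()
    return int(np.random.choice(nucleus, p=nucleus_probs))

print(top_p_sample(np.array([2.0, 1.0, 0.5, 0.1, -1.0])))
```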
Binary file added image/qwen2-1.5b-instruct-functional-call.jpg
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,10 +1,11 @@
onnx==1.16.1
onnxruntime==1.18.1
transformers==4.37.0
# onnxruntime-cann==1.18.1
torch==2.1.0
torch-npu==2.1.0.post6
tqdm
fastapi
uvicorn
sse_starlette
sse_starlette==1.6.5
openai
57 changes: 47 additions & 10 deletions utils/inference.py
@@ -124,7 +124,7 @@ def stream_predict(
if history is None:
history = []
sampling_value = sampling_config.get("sampling_value", self.sampling_value)
temperature = sampling_config.get("sampling_value", self.temperature)
temperature = sampling_config.get("temperature", self.temperature)
messages = [{"role": "system", "content": system_prompt}]
# print("prompt: ", prompt)
with self.lock:
@@ -144,6 +144,8 @@
input_ids = self.tokenizer(
[text], return_tensors="np"
)["input_ids"].astype(np.int64).reshape(1, -1)
input_ids = input_ids[:, -self.max_input_length:]
print("input_ids shape: ", input_ids.shape)
self.first = False
ids_list = []
text_length = 0
@@ -158,9 +160,20 @@
temp_list = trange(max_output_len, desc="decode")
else:
temp_list = range(max_output_len)
prefill_show_progress = False
for i in temp_list:
prefill_show_progress = (i == 0)
logits = self.session.run(input_ids, show_progress=prefill_show_progress)[0]
if i == 0:
if show_progress:
prefill_show_progress = True
# reset counter
self.session.run_times = 0
self.session.kv_cache.real_kv_size = 0
else:
prefill_show_progress = False
logits = self.session.run(
input_ids,
show_progress=prefill_show_progress
)[0]
input_ids = self.sample_logits(
logits[0][-1:],
self.sampling_method,
@@ -207,7 +220,7 @@ def predict(
if history is None:
history = []
sampling_value = sampling_config.get("sampling_value", self.sampling_value)
temperature = sampling_config.get("sampling_value", self.temperature)
temperature = sampling_config.get("temperature", self.temperature)
messages = [{"role": "system", "content": system_prompt}]
# print("prompt: ", prompt)
with self.lock:
@@ -227,6 +240,7 @@
input_ids = self.tokenizer(
[text], return_tensors="np"
)["input_ids"].astype(np.int64).reshape(1, -1)
input_ids = input_ids[:, -self.max_input_length:]
self.first = False
ids_list = []
# text_length = 0
@@ -240,9 +254,20 @@
temp_list = trange(max_output_len, desc="decode")
else:
temp_list = range(max_output_len)
prefill_show_progress = False
for i in temp_list:
prefill_show_progress = (i == 0)
logits = self.session.run(input_ids, show_progress=prefill_show_progress)[0]
if i == 0:
if show_progress:
prefill_show_progress = True
# reset counter
self.session.run_times = 0
self.session.kv_cache.real_kv_size = 0
else:
prefill_show_progress = False
logits = self.session.run(
input_ids,
show_progress=prefill_show_progress
)[0]
input_ids = self.sample_logits(
logits[0][-1:],
self.sampling_method,
@@ -275,19 +300,31 @@ def generate(
show_progress: bool = False,
):
sampling_value = sampling_config.get("sampling_value", self.sampling_value)
temperature = sampling_config.get("sampling_value", self.temperature)
temperature = sampling_config.get("temperature", self.temperature)
self.first = False
ids_list = []
input_ids = input_ids[:, -self.max_input_length:]
input_length = input_ids.shape[1]
max_output_len = self.max_output_length - input_length
max_output_len = min(max_output_len, max_new_tokens)
if show_progress:
temp_list = trange(max_output_len, desc="decode")
else:
temp_list = range(max_output_len)
prefill_show_progress = False
for i in temp_list:
prefill_show_progress = (i == 0)
logits = self.session.run(input_ids, show_progress=prefill_show_progress)[0]
if i == 0:
if show_progress:
prefill_show_progress = True
# reset counter
self.session.run_times = 0
self.session.kv_cache.real_kv_size = 0
else:
prefill_show_progress = False
logits = self.session.run(
input_ids,
show_progress=prefill_show_progress
)[0]
input_ids = self.sample_logits(
logits[0][-1:],
self.sampling_method,
@@ -302,7 +339,7 @@
break
ids_list.append(input_ids[0].item())
text_out = self.tokenizer.decode(ids_list)
print("Debug: ", text_out)
# print("Debug: ", text_out)
# stop_word = is_stop_word_or_prefix(text_out, ["[|Human|]", "[|AI|]"])
self.state['message'] = text_out
with self.lock:
4 changes: 4 additions & 0 deletions utils/session.py
@@ -181,6 +181,10 @@ def run_some(
seq_length: int = 1,
is_dynamic: bool = False
):
# print(
# "self.run_times: ", self.run_times,
# "real kv size: ", self.kv_cache.real_kv_size
# )
self.run_times += seq_length
cache, mask, pos_ids = self.kv_cache.get_inputs(seq_length)
result:List[np.ndarray] = self.model.inference(
