diff --git a/README.md b/README.md index ffa67c7..d58de6a 100644 --- a/README.md +++ b/README.md @@ -68,13 +68,30 @@ ![](./image/qwen2-1.5b-instruct.gif) +##### 步骤3:部署兼容OpenAI的api +- 使用下面的命令直接运行api,`--max_prefill_length`需要和上面编译的时候使用的数值相同。 + ```bash + python3 ./api.py \ + --hf_model_dir="./download/Qwen2-1.5B-Instruct" \ + --om_model_path="./output/model/qwen2_1.5b_chat.om" \ + --max_prefill_length=8 + ``` + +- 进入client目录,可以运行里面的文件请求服务端。 + ```bash + # openai_stream_client.py 流式请求,类似打字机效果 + # openai_normal_client.py 非流式请求,需要等模型推理完再返回 + # openai_function_call.py 测试function_call + ``` + + ### 当前功能 - [x] 导出onnx, om模型 - [x] 模型推理,支持onnx推理(仅支持CPU)。 - [x] 模型推理,支持CANN推理。 - [x] CANN推理时使用动态shape推理以降低首字延迟。 - [x] 流式传输 -- [ ] 兼容OpenAI的api搭建 -- [ ] 支持functional call +- [x] 兼容OpenAI的api搭建 +- [x] 支持functional call - [ ] 支持模型量化,如weight only, smooth quant等 - [ ] 支持Docker快速部署 \ No newline at end of file diff --git a/api.py b/api.py new file mode 100644 index 0000000..2b5faf2 --- /dev/null +++ b/api.py @@ -0,0 +1,551 @@ +import json +import numpy as np +import time +from typing import List, Literal, Optional, Union, Dict +import uvicorn +from contextlib import asynccontextmanager +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +from sse_starlette.sse import EventSourceResponse +from cli_chat import parser_args +from config import InferenceConfig +import copy +import math +import re +from utils.inference import Inference + + +def _gc(forced: bool = False): + import gc + + gc.collect() + # if torch.cuda.is_available(): + # torch.cuda.empty_cache() + +@asynccontextmanager +async def lifespan(app: FastAPI): + yield + _gc(forced=True) + +app = FastAPI(lifespan=lifespan) +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) +args = parser_args() +max_prefill_log2 = int(math.log2(args.max_prefill_length)) +max_prefill_length = 2 ** max_prefill_log2 +config = InferenceConfig( + hf_model_dir=args.hf_model_dir, + om_model_path=args.om_model_path, + onnx_model_path=args.onnx_model_path, + session_type=args.session_type, + max_batch=args.max_batch, + max_output_length=args.max_output_length, + max_input_length=args.max_input_length, + kv_cache_length=args.max_output_length, + max_prefill_length=max_prefill_length, +) + +# init Inference +infer_engine = Inference(config) +show_progress=True + + +@app.get("/") +async def root(): + return "Hello! This is QWen-Chat-7B API." + + +class Data(BaseModel): + query: str + system: str = "You are a helpful assistant." + history: List[List[str]] = [], + max_input_length: Optional[int] = config.max_input_length + max_new_tokens: Optional[int] = config.max_output_length + temperature: Optional[float] = config.temperature + + + +# --- Compatible with OpenAI ChatGPT --- # +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "owner" + root: Optional[str] = None + parent: Optional[str] = None + permission: Optional[list] = None + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class ChatMessage(BaseModel): + role: Literal["user", "assistant", "system", "function"] + content: Optional[str] + function_call: Optional[Dict] = None + + +class DeltaMessage(BaseModel): + role: Optional[Literal["user", "assistant", "system"]] = None + content: Optional[str] = "" + + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[ChatMessage] + functions: Optional[List[Dict]] = None + temperature: Optional[float] = config.temperature + top_p: Optional[float] = config.sampling_value + max_tokens: Optional[int] = config.max_output_length + stream: Optional[bool] = False + stop: Optional[List[str]] = None + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Literal["stop", "length", "function_call"] + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] + + +class ChatCompletionResponse(BaseModel): + model: str + object: Literal["chat.completion", "chat.completion.chunk"] + choices: List[ + Union[ + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ] + ] + created: Optional[int] = Field( + default_factory=lambda: int(time.time()) + ) + + +@app.get("/v1/models", response_model=ModelList) +async def list_models(): + global model_args + model_card = ModelCard(id="gpt-3.5-turbo") + return ModelList(data=[model_card]) + + +def add_extra_stop_words(stop_words): + if stop_words: + _stop_words = [] + _stop_words.extend(stop_words) + for x in stop_words: + s = x.lstrip("\n") + if s and (s not in _stop_words): + _stop_words.append(s) + return _stop_words + return stop_words + + +TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}""" + +REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs: + +{tools_text} + +Use the following format: + +Question: the input question you must answer +Thought: you should always think about what to do +Action: the action to take, should be one of [{tools_name_text}] +Action Input: the input to the action +Observation: the result of the action +... (this Thought/Action/Action Input/Observation can be repeated zero or more times) +Thought: I now know the final answer +Final Answer: the final answer to the original input question + +Begin!""" + +_TEXT_COMPLETION_CMD = object() + + +def trim_stop_words(response, stop_words): + if stop_words: + for stop in stop_words: + idx = response.find(stop) + if idx != -1: + response = response[:idx] + return response + + +def parse_messages(messages, functions): + if all(m.role != "user" for m in messages): + raise HTTPException( + status_code=400, + detail=f"Invalid request: Expecting at least one user message.", + ) + + messages = copy.deepcopy(messages) + default_system = "You are a helpful assistant." + system = "" + if messages[0].role == "system": + system = messages.pop(0).content.lstrip("\n").rstrip() + if system == default_system: + system = "" + + if functions: + tools_text = [] + tools_name_text = [] + for func_info in functions: + name = func_info.get("name", "") + name_m = func_info.get("name_for_model", name) + name_h = func_info.get("name_for_human", name) + desc = func_info.get("description", "") + desc_m = func_info.get("description_for_model", desc) + tool = TOOL_DESC.format( + name_for_model=name_m, + name_for_human=name_h, + # Hint: You can add the following format requirements in description: + # "Format the arguments as a JSON object." + # "Enclose the code within triple backticks (`) at the beginning and end of the code." + description_for_model=desc_m, + parameters=json.dumps(func_info["parameters"], ensure_ascii=False), + ) + tools_text.append(tool) + tools_name_text.append(name_m) + tools_text = "\n\n".join(tools_text) + tools_name_text = ", ".join(tools_name_text) + system += "\n\n" + REACT_INSTRUCTION.format( + tools_text=tools_text, + tools_name_text=tools_name_text, + ) + system = system.lstrip("\n").rstrip() + + dummy_thought = { + "en": "\nThought: I now know the final answer.\nFinal answer: ", + "zh": "\nThought: 我会作答了。\nFinal answer: ", + } + + _messages = messages + messages = [] + for m_idx, m in enumerate(_messages): + role, content, func_call = m.role, m.content, m.function_call + if content: + content = content.lstrip("\n").rstrip() + if role == "function": + if (len(messages) == 0) or (messages[-1].role != "assistant"): + raise HTTPException( + status_code=400, + detail=f"Invalid request: Expecting role assistant before role function.", + ) + messages[-1].content += f"\nObservation: {content}" + if m_idx == len(_messages) - 1: + messages[-1].content += "\nThought:" + elif role == "assistant": + if len(messages) == 0: + raise HTTPException( + status_code=400, + detail=f"Invalid request: Expecting role user before role assistant.", + ) + last_msg = messages[-1].content + last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0 + if func_call is None: + if functions: + content = dummy_thought["zh" if last_msg_has_zh else "en"] + content + else: + f_name, f_args = func_call["name"], func_call["arguments"] + if not content: + if last_msg_has_zh: + content = f"Thought: 我可以使用 {f_name} API。" + else: + content = f"Thought: I can use {f_name}." + content = f"\n{content}\nAction: {f_name}\nAction Input: {f_args}" + if messages[-1].role == "user": + messages.append( + ChatMessage(role="assistant", content=content.lstrip("\n").rstrip()) + ) + else: + messages[-1].content += content + elif role == "user": + messages.append( + ChatMessage(role="user", content=content.lstrip("\n").rstrip()) + ) + else: + raise HTTPException( + status_code=400, detail=f"Invalid request: Incorrect role {role}." + ) + + query = _TEXT_COMPLETION_CMD + if messages[-1].role == "user": + query = messages[-1].content + messages = messages[:-1] + + if len(messages) % 2 != 0: + print(376) + raise HTTPException(status_code=400, detail="Invalid request") + + history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)] + for i in range(0, len(messages), 2): + if messages[i].role == "user" and messages[i + 1].role == "assistant": + usr_msg = messages[i].content.lstrip("\n").rstrip() + bot_msg = messages[i + 1].content.lstrip("\n").rstrip() + if system and (i == len(messages) - 2): + usr_msg = f"{system}\n\nQuestion: {usr_msg}" + system = "" + for t in dummy_thought.values(): + t = t.lstrip("\n") + if bot_msg.startswith(t) and ("\nAction: " in bot_msg): + bot_msg = bot_msg[len(t) :] + history.append([usr_msg, bot_msg]) + else: + raise HTTPException( + status_code=400, + detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.", + ) + if system: + assert query is not _TEXT_COMPLETION_CMD + query = f"{system}\n\nQuestion: {query}" + return query, history + +def parse_response(response): + func_name, func_args = "", "" + i = response.rfind("\nAction:") + j = response.rfind("\nAction Input:") + k = response.rfind("\nObservation:") + if 0 <= i < j: # If the text has `Action` and `Action input`, + if k < j: # but does not contain `Observation`, + # then it is likely that `Observation` is omitted by the LLM, + # because the output text may have discarded the stop word. + response = response.rstrip() + "\nObservation:" # Add it back. + k = response.rfind("\nObservation:") + func_name = response[i + len("\nAction:") : j].strip() + func_args = response[j + len("\nAction Input:") : k].strip() + if func_name: + choice_data = ChatCompletionResponseChoice( + index=0, + message=ChatMessage( + role="assistant", + content=response[:i], + function_call={"name": func_name, "arguments": func_args}, + ), + finish_reason="function_call", + ) + return choice_data + z = response.rfind("\nFinal Answer: ") + if z >= 0: + response = response[z + len("\nFinal Answer: ") :] + choice_data = ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + return choice_data + + +# completion mode, not chat mode +def text_complete_last_message(history, stop_words_ids, max_new_tokens): # sampling_config, : + im_start = "<|im_start|>" + im_end = "<|im_end|>" + prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}" + for i, (query, response) in enumerate(history): + query = query.lstrip("\n").rstrip() + response = response.lstrip("\n").rstrip() + prompt += f"\n{im_start}user\n{query}{im_end}" + prompt += f"\n{im_start}assistant\n{response}{im_end}" + prompt = prompt[: -len(im_end)] + input_ids = infer_engine.tokenizer( + [prompt], return_tensors="np" + )["input_ids"].astype(np.int64).reshape(1, -1) + # _stop_words_ids = [infer_engine.tokenizer.encode(im_end)] + # if stop_words_ids: + # for s in stop_words_ids[0]: + # _stop_words_ids[0].append(s) + + # stop_words_ids = torch.tensor(_stop_words_ids, dtype=torch.int32, device="cuda") + # input_lengths=torch.tensor([input_ids.shape[-1]], dtype=torch.int32, device="cuda") + # output = model.generate(input_ids, stop_words_ids=stop_words_ids, **gen_kwargs).tolist()[0] + output = infer_engine.generate( + input_ids=input_ids, + max_new_tokens=max_new_tokens, + sampling_config=sampling_config, + show_progress=show_progress, + ) + # assert output.startswith(prompt) + # output = output[len(prompt) :] + # output = trim_stop_words(output, ["<|endoftext|>", im_end]) + # print(f"\n{prompt}\n\n{output}\n") + return output + + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def create_chat_completion(request: ChatCompletionRequest): + # print("Debug, top_p: ", request.top_p) + # print("Debug, temperature: ", request.temperature) + # print("Debug, max_tokens: ", request.max_tokens) + sampling_config = { + "top": config.sampling_value, + "temperature": config.temperature, + } + if request.top_p is not None: + sampling_config["top_p"] = request.top_p + if request.temperature is not None: + sampling_config["temperature"] = request.temperature + if request.max_tokens is not None: + max_new_tokens = min(request.max_tokens, config.max_output_length) + else: + max_new_tokens = config.max_output_length + if request.messages[-1].role not in ["user", "function"]: + print(454) + raise HTTPException(status_code=400, detail="Invalid request") + # query = request.messages[-1].content + + prev_messages = request.messages[:-1] + if len(prev_messages) > 0 and prev_messages[0].role == "system": + system = prev_messages.pop(0).content + else: + system = "You are a helpful assistant." + + # history = [] + # if len(prev_messages) % 2 == 0: + # for i in range(0, len(prev_messages), 2): + # if ( + # prev_messages[i].role == "user" + # and prev_messages[i + 1].role == "assistant" + # ): + # history.append( + # [ + # prev_messages[i].content, + # prev_messages[i + 1].content, + # ] + # ) + stop_words = add_extra_stop_words(request.stop) + if request.functions: + stop_words = stop_words or [] + if "Observation:" not in stop_words: + stop_words.append("Observation:") + + query, history = parse_messages(request.messages, request.functions) + # print("query: ", query) + # print("history: ", history) + + if request.stream: + if request.functions: + raise HTTPException( + status_code=400, + detail="Invalid request: Function calling is not yet implemented for stream mode.", + ) + return EventSourceResponse( + stream_predict(query, system, history, sampling_config, max_new_tokens, request.model), + media_type="text/event-stream" + ) + stop_words_ids = [infer_engine.tokenizer.encode(s) for s in stop_words] if stop_words else None + # print("gen kwargs",gen_kwargs) + if query is _TEXT_COMPLETION_CMD: + response = text_complete_last_message( + history, + stop_words_ids=stop_words_ids, + sampling_config=sampling_config, + max_new_tokens=max_new_tokens + ) + else: + query_text = query.lstrip("\n").strip() + response = infer_engine.predict( + prompt=query_text, + system_prompt=system, + history=history, + sampling_config=sampling_config, + max_new_tokens=max_new_tokens, + show_progress=show_progress, + ) + response = trim_stop_words(response, stop_words) + if request.functions: + choice_data = parse_response(response) + else: + choice_data = ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=response), + finish_reason="stop", + ) + + return ChatCompletionResponse( + model=request.model, + choices=[choice_data], + object="chat.completion", + ) + + +async def stream_predict( + query: str, + system: str, + history: List[List[str]], + sampling_config: dict, + max_new_tokens: int, + model_id: str +): + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role="assistant"), + finish_reason=None, + ) + chunk = ChatCompletionResponse( + model=model_id, + choices=[choice_data], + object="chat.completion.chunk", + ) + yield "{}".format( + chunk.model_dump_json(exclude_unset=True) + ) + # print("Debug system", system) + # print("Debug query", query) + # print("Debug history", history) + for new_text in infer_engine.stream_predict( + prompt=query, + history=history, + sampling_config=sampling_config, + system_prompt=system, + max_new_tokens=max_new_tokens, + show_progress=show_progress, + ): + if len(new_text) == 0: + continue + # print("Debug, new_text[0]: ", new_text[0]) + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(content=new_text), + finish_reason=None, + ) + chunk = ChatCompletionResponse( + model=model_id, + choices=[choice_data], + object="chat.completion.chunk", + ) + yield "{}".format( + chunk.model_dump_json(exclude_unset=True) + ) + + choice_data = ChatCompletionResponseStreamChoice( + index=0, delta=DeltaMessage(), finish_reason="stop" + ) + chunk = ChatCompletionResponse( + model=model_id, + choices=[choice_data], + object="chat.completion.chunk", + ) + yield "{}".format( + chunk.model_dump_json(exclude_unset=True) + ) + yield "[DONE]" + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) + # uvicorn.run(app, host="localhost", port=8000, workers=1) + diff --git a/cli_chat.py b/cli_chat.py index d6809fc..7010972 100644 --- a/cli_chat.py +++ b/cli_chat.py @@ -7,78 +7,66 @@ import os project_dir = os.path.dirname(os.path.abspath(__file__)) -parser = argparse.ArgumentParser() -parser.add_argument( - '--hf_model_dir', - type=str, - help="model and tokenizer path, only support huggingface model", - default=os.path.join(project_dir, "download", "Qwen2-1.5B-Instruct") -) -parser.add_argument( - "--session_type", - type=str, - default="acl", - help="acl or onnx", - choices=["acl", "onnx"], -) -parser.add_argument( - '--onnx_model_path', - type=str, - help="onnx_model_path", - default=os.path.join(project_dir, "output", "onnx", "qwen2_1.5b_chat.onnx") -) -parser.add_argument( - "--om_model_path", - help="mindspore model path", - type=str, - default= os.path.join(project_dir, "output", "model", "qwen2_1.5b_chat.om") -) -parser.add_argument( - "--max_batch", - help="max batch", - type=int, - default=1, -) -parser.add_argument( - "--max_input_length", - help="max input length", - type=int, - default=512, -) -parser.add_argument( - "--max_prefill_length", - help="max prefill length in first inference. " - "Attention max_prefill_length + max_output_length <= kv_cache_length. " - "the number must by 2^xx, like 1, 2, 4, 8, 16, 32, 64, 128, 256... " - "Note! The higher this number, the longer it will take to compile.", - type=int, - default=8, -) +def parser_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--hf_model_dir', + type=str, + help="model and tokenizer path, only support huggingface model", + default=os.path.join(project_dir, "download", "Qwen2-1.5B-Instruct") + ) + parser.add_argument( + "--session_type", + type=str, + default="acl", + help="acl or onnx", + choices=["acl", "onnx"], + ) + parser.add_argument( + '--onnx_model_path', + type=str, + help="onnx_model_path", + default=os.path.join(project_dir, "output", "onnx", "qwen2_1.5b_chat.onnx") + ) + parser.add_argument( + "--om_model_path", + help="mindspore model path", + type=str, + default= os.path.join(project_dir, "output", "model", "qwen2_1.5b_chat.om") + ) + parser.add_argument( + "--max_batch", + help="max batch", + type=int, + default=1, + ) + parser.add_argument( + "--max_input_length", + help="max input length", + type=int, + default=512, + ) + parser.add_argument( + "--max_prefill_length", + help="max prefill length in first inference. " + "Attention max_prefill_length + max_output_length <= kv_cache_length. " + "the number must by 2^xx, like 1, 2, 4, 8, 16, 32, 64, 128, 256... " + "Note! The higher this number, the longer it will take to compile.", + type=int, + default=8, + ) -parser.add_argument( - "--max_output_length", - help="max output length (contain input + new token)", - type=int, - default=1024, -) + parser.add_argument( + "--max_output_length", + help="max output length (contain input + new token)", + type=int, + default=1024, + ) + return parser.parse_args() -args = parser.parse_args() -max_prefill_log2 = int(math.log2(args.max_prefill_length)) -max_prefill_length = 2 ** max_prefill_log2 -config = InferenceConfig( - hf_model_dir=args.hf_model_dir, - om_model_path=args.om_model_path, - onnx_model_path=args.onnx_model_path, - session_type=args.session_type, - max_batch=args.max_batch, - max_output_length=args.max_output_length, - max_input_length=args.max_input_length, - kv_cache_length=args.max_output_length, - max_prefill_length=max_prefill_length, -) -infer_engine=Inference(config) def inference_cli(): + infer_engine = Inference(config) print("\n欢迎使用Qwen聊天机器人,输入exit或者quit退出,输入clear清空历史记录") history = [] while True: @@ -98,7 +86,7 @@ def inference_cli(): first_token_lantency, decode_speed, total_speed - ) in infer_engine.stream_predict(input_text, history=history): + ) in infer_engine.stream_predict(input_text, history=history, do_speed_test=True): if is_first: if len(new_text.strip()) == 0: continue @@ -112,7 +100,21 @@ def inference_cli(): " total_speed(prefill+decode): {:.2f} token/s".format(total_speed), ) - history.append({"role": "assistant", "content": response}) + history.append([input_text, response]) if __name__ == '__main__': + args = parser_args() + max_prefill_log2 = int(math.log2(args.max_prefill_length)) + max_prefill_length = 2 ** max_prefill_log2 + config = InferenceConfig( + hf_model_dir=args.hf_model_dir, + om_model_path=args.om_model_path, + onnx_model_path=args.onnx_model_path, + session_type=args.session_type, + max_batch=args.max_batch, + max_output_length=args.max_output_length, + max_input_length=args.max_input_length, + kv_cache_length=args.max_output_length, + max_prefill_length=max_prefill_length, + ) # main() inference_cli() \ No newline at end of file diff --git a/client/openai_function_call.py b/client/openai_function_call.py new file mode 100644 index 0000000..574cadd --- /dev/null +++ b/client/openai_function_call.py @@ -0,0 +1,205 @@ +from openai import OpenAI +import os +import requests +import urllib3 +import time +import random +import json + + +urllib3.disable_warnings() + +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="no api" +) + +# get api from here https://dev.qweather.com/ +weather_key = "" +if len(weather_key) == 0: + weather_key = os.environ["WEATHER_KEY"] +assert len(weather_key) > 0, print("please get weather query api in https://dev.qweather.com/") + + +class Weather: + def __init__(self, api_key): + self.api_key = api_key + + def get_location_from_api(self, location, adm=None, + location_range="world", lang="zh"): + """ + Get api based on https:dev.qweather.com + params location: the location to be queried + params adm: superior region, for example, the superior region of Yuexiu is Guangzhou + params location_range: query range, default global, supports cn: China, us: United States, fr: France, + uk: United Kingdom, please check the iso-3166 standard for more information + params lang: language, default zh, support en + """ + url = "https://geoapi.qweather.com/v2/city/lookup?" + params = { + "key": self.api_key, + "location": location, + "range": location_range, + "lang": lang, + } + if adm is not None: + if len(adm) > 0: + params["adm"] = adm + session = requests.session() + try: + res2 = session.get(url, params=params, verify=False, timeout=15) + if res2.status_code == 200: + data = res2.json() + if data.get("code", None) == '200': + return data.get("location", []) + else: + print(data) + else: + print(res2) + time.sleep(1 + random.random()) + session.close() + except Exception as err: + print("request error", err) + time.sleep(3 + random.random()) + session.close() + return [] + + def get_weather_from_api(self, location: str): + """ + Get weather information from Zefeng weather api + :param location: location information, which can be location_id or a latitude and longitude (format: "longitude, latitude") + """ + url = "https://devapi.qweather.com/v7/weather/3d?" + params = { + "location": location, + "key": self.api_key + } + session = requests.session() + try: + res1 = session.get(url, params=params, verify=False, timeout=15) + if res1.status_code == 200: + data = res1.json() + if data.get("code", "") == "200": + return data.get("daily", []) + else: + print(data) + else: + print(res1) + time.sleep(1 + random.random()) + session.close() + except Exception as err: + print("get api error,", err) + time.sleep(3 + random.random()) + session.close() + return [] + + +def get_current_weather(location: str): + weather = Weather(weather_key) + location_data = weather.get_location_from_api(location) + if len(location_data) > 0: + location_dict = location_data[0] + city_id = location_dict["id"] + weather_res = weather.get_weather_from_api(city_id) + n_day = len(weather_res) + return f"查询到最近{n_day}天的天气。" + json.dumps(weather_res, ensure_ascii=False) + else: + return "" + +def call_qwen(messages, functions=None): + # print(messages) + if functions: + response = client.chat.completions.create( + model="Qwen", messages=messages, functions=functions + ) + else: + response = client.chat.completions.create( + model="Qwen", messages=messages + ) + # print(response) + # print(response.choices[0].message.content) + return response + + +def chat(query: str): + functions = [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + + messages = [ + { + "role": "user", + # Note: The current version of Qwen-7B-Chat (as of 2023.08) performs okay with Chinese tool-use prompts, + # but performs terribly when it comes to English tool-use prompts, due to a mistake in data collecting. + "content": query, + } + ] + response = call_qwen(messages, functions) + res = response.choices[0].message + message_dict = { + "role": res.role, + "content": res.content, + "function_call": res.function_call, + } + messages.append(message_dict) + # --- call function --- # + if res.function_call is not None: + function_call = res.function_call + function_name = function_call.name + try: + function_params = json.loads(function_call.arguments) + except: + print(f"{function_name}解析对应参数失败,请检查, 参数信息:", function_call) + return + for temp_dict in functions: + if temp_dict["name"] == function_name: + require_params = temp_dict["parameters"]["required"] + # require_params.sort() + had_params = list(function_params.keys()) + # had_params.sort() + for param in had_params: + if param not in require_params: + del function_params[param] + # recompute + had_params = list(function_params.keys()) + if len(had_params) != len(require_params): + raise Exception("ERROR, need to do other fill params") + + + response = eval(function_name)(**function_params) + message = { + "role": "function", + "name": function_name, + } + if len(response) > 0: + message["content"] = response + else: + message["content"] = "未找到任何信息" + messages.append(message) + response = call_qwen(messages, functions) + return response + + +messages = [{"role": "system", "content": "You are a helpful assistant."}] +print("=" * 20) +# print("欢迎使用Qwen聊天机器人,输入exit退出,输入clear清空历史记录") +print("目前已支持天气查询插件") +print("=" * 20) +query = "北京天气如何?穿短袖会不会冷?" +print("用户输入:", query) +res = chat(query) +print("回答结果:", res.choices[0].message.content) diff --git a/client/openai_normal_client.py b/client/openai_normal_client.py new file mode 100644 index 0000000..f11c745 --- /dev/null +++ b/client/openai_normal_client.py @@ -0,0 +1,30 @@ +from openai import OpenAI + +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="no api" +) + +messages = [{"role": "system", "content": "You are a helpful assistant."}] +print("欢迎使用Qwen聊天机器人,输入exit退出,输入clear清空历史记录") +while True: + prompt = input('Human:') + if prompt == 'exit' or prompt == "exit()": + break + if prompt == 'clear': + messages = messages[:1] + continue + messages.append({"role": "user", "content": prompt}) + completion = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=messages, + top_p=0.5, + temperature=0, + n=1, + max_tokens=4096, + stream=False, + ) + message = completion.choices[0].message + response_text = message.content + print('ChatBot: {}'.format(response_text)) + messages.append({"role": "assistant", "content": response_text}) \ No newline at end of file diff --git a/client/openai_stream_client.py b/client/openai_stream_client.py new file mode 100644 index 0000000..191be54 --- /dev/null +++ b/client/openai_stream_client.py @@ -0,0 +1,39 @@ +from openai import OpenAI + +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="no api" +) + + +messages = [{"role": "system", "content": "You are a helpful assistant."}] +print("欢迎使用Qwen聊天机器人,输入exit退出,输入clear清空历史记录") +while True: + prompt = input('Human:') + if prompt == 'exit' or prompt == "exit()": + break + if prompt == 'clear': + messages = messages[:1] + continue + messages.append({"role": "user", "content": prompt}) + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=messages, + top_p=0.5, + temperature=0, + n=1, + max_tokens=4096, + stream=True, + ) + print("ChatBot:", end='', flush=True) + response_text = "" + for event in response: + # print(event) + event_text = event.choices[0].delta.content # extract the text + if event_text is None: + event_text = "" + response_text += event_text + print(event_text, end='', flush=True) + messages.append({"role": "assistant", "content": response_text}) + print("") + diff --git a/requirements.txt b/requirements.txt index 2ff5eac..9ef3435 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,9 @@ onnx==1.16.1 onnxruntime==1.18.1 # onnxruntime-cann==1.18.1 torch==2.1.0 -torch-npu==2.1.0.post6 \ No newline at end of file +torch-npu==2.1.0.post6 +tqdm +fastapi +uvicorn +sse_starlette +openai \ No newline at end of file diff --git a/utils/inference.py b/utils/inference.py index f2a410a..469cae0 100644 --- a/utils/inference.py +++ b/utils/inference.py @@ -7,6 +7,7 @@ from threading import Lock from utils.session import Session from config import InferenceConfig +from tqdm import trange, tqdm @@ -114,21 +115,29 @@ def stream_predict( self, prompt, history=None, - system_prompt: str="You are a helpful assistant.", + sampling_config: dict = {}, + system_prompt: str = "You are a helpful assistant.", + max_new_tokens: int = 512, + do_speed_test: bool = False, + show_progress: bool = False, ): if history is None: - history = [] - if len(history) == 0: - history = [{"role": "system", "content": system_prompt}] + history = [] + sampling_value = sampling_config.get("sampling_value", self.sampling_value) + temperature = sampling_config.get("sampling_value", self.temperature) + messages = [{"role": "system", "content": system_prompt}] # print("prompt: ", prompt) with self.lock: self.state['isEnd'],self.state['message'] = False,"" if prompt == "": - return - history.append({"role": "user", "content": prompt}) + return + for (use_msg, bot_msg) in history: + messages.append({"role": "user", "content": use_msg}) + messages.append({"role": "assistant", "content": bot_msg}) + messages.append({"role": "user", "content": prompt}) # print("history: ", history) text = self.tokenizer.apply_chat_template( - history, + messages, tokenize=False, add_generation_prompt=True ) @@ -139,20 +148,27 @@ def stream_predict( ids_list = [] text_length = 0 input_length = input_ids.shape[1] - start = time.time() - first_token_latency = 0 - decode_speed = 0 + if do_speed_test: + start = time.time() + first_token_latency = 0 + decode_speed = 0 max_output_len = self.max_output_length - input_length - for i in range(max_output_len): - logits = self.session.run(input_ids)[0] + max_output_len = min(max_output_len, max_new_tokens) + if show_progress: + temp_list = trange(max_output_len, desc="decode") + else: + temp_list = range(max_output_len) + for i in temp_list: + prefill_show_progress = (i == 0) + logits = self.session.run(input_ids, show_progress=prefill_show_progress)[0] input_ids = self.sample_logits( logits[0][-1:], self.sampling_method, - self.sampling_value, - self.temperature + sampling_value, + temperature ) input_ids = input_ids.reshape(1, -1) - if i == 0: + if do_speed_test and i == 0: first_token_latency = time.time() - start with self.lock: # early stop @@ -164,17 +180,135 @@ def stream_predict( # stop_word = is_stop_word_or_prefix(text_out, ["[|Human|]", "[|AI|]"]) self.state['message'] = text_out new_text = text_out[text_length: ] - # decode_speed = - duration = time.time() - start - decode_speed = len(ids_list) / duration - totol_speed = (input_length + len(ids_list)) / duration + if do_speed_test: + duration = time.time() - start + decode_speed = len(ids_list) / duration + totol_speed = (input_length + len(ids_list)) / duration if b"\xef\xbf\xbd" in new_text.encode(): continue if len(new_text) > 0: - yield new_text, first_token_latency, decode_speed, totol_speed + if do_speed_test: + yield new_text, first_token_latency, decode_speed, totol_speed + else: + yield new_text text_length = len(text_out) + with self.lock: + self.state['isEnd'] = True + + def predict( + self, + prompt, + history=None, + sampling_config: dict = {}, + system_prompt: str="You are a helpful assistant.", + max_new_tokens: int = 512, + show_progress: bool = False, + ): + if history is None: + history = [] + sampling_value = sampling_config.get("sampling_value", self.sampling_value) + temperature = sampling_config.get("sampling_value", self.temperature) + messages = [{"role": "system", "content": system_prompt}] + # print("prompt: ", prompt) + with self.lock: + self.state['isEnd'],self.state['message'] = False,"" + if prompt == "": + return + for (use_msg, bot_msg) in history: + messages.append({"role": "user", "content": use_msg}) + messages.append({"role": "assistant", "content": bot_msg}) + messages.append({"role": "user", "content": prompt}) + # print("history: ", history) + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + input_ids = self.tokenizer( + [text], return_tensors="np" + )["input_ids"].astype(np.int64).reshape(1, -1) + self.first = False + ids_list = [] + # text_length = 0 + input_length = input_ids.shape[1] + # start = time.time() + # first_token_latency = 0 + # decode_speed = 0 + max_output_len = self.max_output_length - input_length + max_output_len = min(max_output_len, max_new_tokens) + if show_progress: + temp_list = trange(max_output_len, desc="decode") + else: + temp_list = range(max_output_len) + for i in temp_list: + prefill_show_progress = (i == 0) + logits = self.session.run(input_ids, show_progress=prefill_show_progress)[0] + input_ids = self.sample_logits( + logits[0][-1:], + self.sampling_method, + sampling_value, + temperature + ) + input_ids = input_ids.reshape(1, -1) + # if i == 0: + # first_token_latency = time.time() - start + with self.lock: + # early stop + if input_ids[0] == self.tokenizer.eos_token_id: + self.state['message'],self.state['isEnd'] = self.tokenizer.decode(ids_list),True + break + ids_list.append(input_ids[0].item()) + # text_out = self.tokenizer.decode(ids_list) + # stop_word = is_stop_word_or_prefix(text_out, ["[|Human|]", "[|AI|]"]) + # self.state['message'] = text_out + # decode_speed = + with self.lock: + self.state['isEnd'] = True + text_out = self.tokenizer.decode(ids_list) + return text_out + + def generate( + self, + input_ids, + sampling_config: dict = {}, + max_new_tokens: int = 512, + show_progress: bool = False, + ): + sampling_value = sampling_config.get("sampling_value", self.sampling_value) + temperature = sampling_config.get("sampling_value", self.temperature) + self.first = False + ids_list = [] + input_length = input_ids.shape[1] + max_output_len = self.max_output_length - input_length + max_output_len = min(max_output_len, max_new_tokens) + if show_progress: + temp_list = trange(max_output_len, desc="decode") + else: + temp_list = range(max_output_len) + for i in temp_list: + prefill_show_progress = (i == 0) + logits = self.session.run(input_ids, show_progress=prefill_show_progress)[0] + input_ids = self.sample_logits( + logits[0][-1:], + self.sampling_method, + sampling_value, + temperature + ) + input_ids = input_ids.reshape(1, -1) + with self.lock: + # early stop + if input_ids[0] == self.tokenizer.eos_token_id: + self.state['message'],self.state['isEnd'] = self.tokenizer.decode(ids_list),True + break + ids_list.append(input_ids[0].item()) + text_out = self.tokenizer.decode(ids_list) + print("Debug: ", text_out) + # stop_word = is_stop_word_or_prefix(text_out, ["[|Human|]", "[|AI|]"]) + self.state['message'] = text_out with self.lock: self.state['isEnd'] = True + text_out = self.tokenizer.decode(ids_list) + return text_out def reset(self): self.first = True diff --git a/utils/session.py b/utils/session.py index 53be187..e76b361 100644 --- a/utils/session.py +++ b/utils/session.py @@ -7,12 +7,15 @@ import sys from utils.engine import ACLModel, init_resource, destroy_resource import onnxruntime as ort +from tqdm import tqdm, trange + class Session: def __init__(self, config: InferenceConfig) -> None: self.kv_cache = create_kv_cache(config) self.run_times = 0 - def run(self,input_ids:np.ndarray): + + def run(self,input_ids:np.ndarray, show_progress: bool = False): pass @staticmethod @@ -46,7 +49,7 @@ def __init__(self,config:InferenceConfig)->None: ], ) - def run(self, input_ids:np.ndarray): + def run(self, input_ids:np.ndarray, show_progress=False): seq_len=input_ids.shape[-1] cache, mask, pos_ids = self.kv_cache.get_inputs(seq_len) result = self.llm_session.run(None,{ @@ -142,29 +145,42 @@ def decompose_number(self, n, start_index=0): return [power] + self.decompose_number(n - power, i) return [] - def run(self, input_ids: np.ndarray): + def run(self, input_ids: np.ndarray, show_progress:bool=False): seq_len = input_ids.shape[-1] logits = None is_dynamic = bool(self.max_prefill_length > 1) # dynamic inference if is_dynamic: seq_list = self.decompose_number(seq_len) + if show_progress: + seq_list = tqdm(seq_list, desc="prefill") start_i = 0 for seq in seq_list: end_i = start_i + seq logits = self.run_some( input_ids[:, start_i: end_i], seq, - is_dynamic + is_dynamic, ) start_i += seq + # if show_progress: + # seq_list.update(seq) # static inference else: - for i in range(seq_len): + if show_progress: + idx_list = trange(seq_len, desc="prefill") + else: + idx_list = range(seq_len) + for i in idx_list: logits = self.run_some(input_ids[:,i]) return [logits] - def run_some(self, input_ids: np.ndarray, seq_length: int = 1, is_dynamic: bool = False): + def run_some( + self, + input_ids: np.ndarray, + seq_length: int = 1, + is_dynamic: bool = False + ): self.run_times += seq_length cache, mask, pos_ids = self.kv_cache.get_inputs(seq_length) result:List[np.ndarray] = self.model.inference(