diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..12d1b35d --- /dev/null +++ b/.env.example @@ -0,0 +1,8 @@ +BOT_TOKEN= +API_URL= +API= +GOOGLE_API_KEY= +GOOGLE_CSE_ID= +claude_api_key= +ADMIN_LIST= +GROUP_LIST= \ No newline at end of file diff --git a/bot.py b/bot.py index 5479b232..5ad69f98 100644 --- a/bot.py +++ b/bot.py @@ -49,7 +49,8 @@ async def command_bot(update, context, language=None, prompt=translator_prompt, if has_command: message = ' '.join(context.args) if prompt and has_command: - prompt = prompt.format(language) + if translator_prompt == prompt: + prompt = prompt.format(language) message = prompt + message if message: if "claude" in config.GPT_ENGINE and config.ClaudeAPI: @@ -128,70 +129,6 @@ async def getChatGPT(update, context, title, robot, message, chatid, messageid): result = re.sub(r",", ',', result) await context.bot.edit_message_text(chat_id=chatid, message_id=messageid, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True) -@decorators.GroupAuthorization -@decorators.Authorization -async def search(update, context, title, robot): - message = update.message.text if config.NICK is None else update.message.text[botNicKLength:].strip() if update.message.text[:botNicKLength].lower() == botNick else None - print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m") - if (len(context.args) == 0): - message = ( - f"格式错误哦~,示例:\n\n" - f"`/search 今天的微博热搜有哪些?`\n\n" - f"👆点击上方命令复制格式\n\n" - ) - await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True) - return - message = ' '.join(context.args) - result = title - text = message - modifytime = 0 - lastresult = '' - message = await context.bot.send_message( - chat_id=update.message.chat_id, - text="搜索中💭", - parse_mode='MarkdownV2', - reply_to_message_id=update.message.message_id, - ) - messageid = message.message_id - get_answer = robot.search_summary - if not config.API or (config.USE_G4F and not config.SEARCH_USE_GPT): - import utils.gpt4free as gpt4free - get_answer = gpt4free.get_response - - try: - for data in get_answer(text, convo_id=str(update.message.chat_id), pass_history=config.PASS_HISTORY): - result = result + data - tmpresult = result - modifytime = modifytime + 1 - if re.sub(r"```", '', result).count("`") % 2 != 0: - tmpresult = result + "`" - if result.count("```") % 2 != 0: - tmpresult = result + "\n```" - if modifytime % 20 == 0 and lastresult != tmpresult: - if 'claude2' in title: - tmpresult = re.sub(r",", ',', tmpresult) - await context.bot.edit_message_text(chat_id=update.message.chat_id, message_id=messageid, text=escape(tmpresult), parse_mode='MarkdownV2', disable_web_page_preview=True) - lastresult = tmpresult - except Exception as e: - print('\033[31m') - print("response_msg", result) - print("error", e) - traceback.print_exc() - print('\033[0m') - if config.API: - robot.reset(convo_id=str(update.message.chat_id), system_prompt=config.systemprompt) - if "You exceeded your current quota, please check your plan and billing details." in str(e): - print("OpenAI api 已过期!") - await context.bot.delete_message(chat_id=update.message.chat_id, message_id=messageid) - messageid = '' - config.API = '' - result += f"`出错啦!{e}`" - print(result) - if lastresult != result and messageid: - if 'claude2' in title: - result = re.sub(r",", ',', result) - await context.bot.edit_message_text(chat_id=update.message.chat_id, message_id=messageid, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True) - @decorators.GroupAuthorization @decorators.Authorization async def image(update, context): @@ -540,35 +477,35 @@ async def handle_pdf(update, context): ) await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True) -@decorators.GroupAuthorization -@decorators.Authorization -async def qa(update, context): - if (len(context.args) != 2): - message = ( - f"格式错误哦~,需要两个参数,注意路径或者链接、问题之间的空格\n\n" - f"请输入 `/qa 知识库链接 要问的问题`\n\n" - f"例如知识库链接为 https://abc.com ,问题是 蘑菇怎么分类?\n\n" - f"则输入 `/qa https://abc.com 蘑菇怎么分类?`\n\n" - f"问题务必不能有空格,👆点击上方命令复制格式\n\n" - f"除了输入网址,同时支持本地知识库,本地知识库文件夹路径为 `./wiki`,问题是 蘑菇怎么分类?\n\n" - f"则输入 `/qa ./wiki 蘑菇怎么分类?`\n\n" - f"问题务必不能有空格,👆点击上方命令复制格式\n\n" - f"本地知识库目前只支持 Markdown 文件\n\n" - ) - await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True) - return - print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m") - await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING) - result = await docQA(context.args[0], context.args[1], get_doc_from_local) - print(result["answer"]) - # source_url = set([i.metadata['source'] for i in result["source_documents"]]) - # source_url = "\n".join(source_url) - # message = ( - # f"{result['result']}\n\n" - # f"参考链接:\n" - # f"{source_url}" - # ) - await context.bot.send_message(chat_id=update.message.chat_id, text=escape(result["answer"]), parse_mode='MarkdownV2', disable_web_page_preview=True) +# @decorators.GroupAuthorization +# @decorators.Authorization +# async def qa(update, context): +# if (len(context.args) != 2): +# message = ( +# f"格式错误哦~,需要两个参数,注意路径或者链接、问题之间的空格\n\n" +# f"请输入 `/qa 知识库链接 要问的问题`\n\n" +# f"例如知识库链接为 https://abc.com ,问题是 蘑菇怎么分类?\n\n" +# f"则输入 `/qa https://abc.com 蘑菇怎么分类?`\n\n" +# f"问题务必不能有空格,👆点击上方命令复制格式\n\n" +# f"除了输入网址,同时支持本地知识库,本地知识库文件夹路径为 `./wiki`,问题是 蘑菇怎么分类?\n\n" +# f"则输入 `/qa ./wiki 蘑菇怎么分类?`\n\n" +# f"问题务必不能有空格,👆点击上方命令复制格式\n\n" +# f"本地知识库目前只支持 Markdown 文件\n\n" +# ) +# await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True) +# return +# print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m") +# await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING) +# result = await docQA(context.args[0], context.args[1], get_doc_from_local) +# print(result["answer"]) +# # source_url = set([i.metadata['source'] for i in result["source_documents"]]) +# # source_url = "\n".join(source_url) +# # message = ( +# # f"{result['result']}\n\n" +# # f"参考链接:\n" +# # f"{source_url}" +# # ) +# await context.bot.send_message(chat_id=update.message.chat_id, text=escape(result["answer"]), parse_mode='MarkdownV2', disable_web_page_preview=True) async def start(update, context): # 当用户输入/start时,返回文本 user = update.effective_user @@ -617,13 +554,14 @@ async def post_init(application: Application) -> None: application.add_handler(CommandHandler("start", start)) application.add_handler(CommandHandler("pic", image)) - application.add_handler(CommandHandler("search", lambda update, context: search(update, context, title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot))) + application.add_handler(CommandHandler("search", lambda update, context: command_bot(update, context, prompt="search: ", title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot, has_command="search"))) + # application.add_handler(CommandHandler("search", lambda update, context: search(update, context, title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot))) application.add_handler(CallbackQueryHandler(button_press)) application.add_handler(CommandHandler("reset", reset_chat)) application.add_handler(CommandHandler("en2zh", lambda update, context: command_bot(update, context, config.LANGUAGE, robot=config.translate_bot))) application.add_handler(CommandHandler("zh2en", lambda update, context: command_bot(update, context, "english", robot=config.translate_bot))) application.add_handler(CommandHandler("info", info)) - application.add_handler(CommandHandler("qa", qa)) + # application.add_handler(CommandHandler("qa", qa)) application.add_handler(MessageHandler(filters.Document.PDF | filters.Document.TXT | filters.Document.DOC, handle_pdf)) application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, lambda update, context: command_bot(update, context, prompt=None, title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot, has_command=False))) application.add_handler(MessageHandler(filters.COMMAND, unknown)) diff --git a/test/test_token.py b/test/test_token.py new file mode 100644 index 00000000..6178c3d0 --- /dev/null +++ b/test/test_token.py @@ -0,0 +1,94 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import tiktoken +from utils.function_call import function_call_list +import config +import requests +import json +import re + +from dotenv import load_dotenv +load_dotenv() + +def get_token_count(messages) -> int: + tiktoken.get_encoding("cl100k_base") + encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") + + num_tokens = 0 + for message in messages: + # every message follows {role/name}\n{content}\n + num_tokens += 5 + for key, value in message.items(): + if value: + num_tokens += len(encoding.encode(value)) + if key == "name": # if there's a name, the role is omitted + num_tokens += 5 # role is always required and always 1 token + num_tokens += 5 # every reply is primed with assistant + return num_tokens +# print(get_token_count(message_list)) + + + +def get_message_token(url, json_post): + headers = {"Authorization": f"Bearer {os.environ.get('API', None)}"} + response = requests.Session().post( + url, + headers=headers, + json=json_post, + timeout=None, + ) + if response.status_code != 200: + json_response = json.loads(response.text) + string = json_response["error"]["message"] + print(string) + string = re.findall(r"\((.*?)\)", string)[0] + numbers = re.findall(r"\d+\.?\d*", string) + numbers = [int(i) for i in numbers] + if len(numbers) == 2: + return { + "messages": numbers[0], + "total": numbers[0], + } + elif len(numbers) == 3: + return { + "messages": numbers[0], + "functions": numbers[1], + "total": numbers[0] + numbers[1], + } + else: + raise Exception("Unknown error") + + +if __name__ == "__main__": + # message_list = [{'role': 'system', 'content': 'You are ChatGPT, a large language model trained by OpenAI. Respond conversationally in Simplified Chinese. Knowledge cutoff: 2021-09. Current date: [ 2023-12-12 ]'}, {'role': 'user', 'content': 'hi'}] + messages = [{'role': 'system', 'content': 'You are ChatGPT, a large language model trained by OpenAI. Respond conversationally in Simplified Chinese. Knowledge cutoff: 2021-09. Current date: [ 2023-12-12 ]'}, {'role': 'user', 'content': 'hi'}, {'role': 'assistant', 'content': '你好!有什么我可以帮助你的吗?'}] + + model = "gpt-3.5-turbo" + temperature = 0.5 + top_p = 0.7 + presence_penalty = 0.0 + frequency_penalty = 0.0 + reply_count = 1 + role = "user" + model_max_tokens = 5000 + url = config.bot_api_url.chat_url + + json_post = { + "model": model, + "messages": messages, + "stream": True, + "temperature": temperature, + "top_p": top_p, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "n": reply_count, + "user": role, + "max_tokens": model_max_tokens, + } + # json_post.update(function_call_list["base"]) + # if config.SEARCH_USE_GPT: + # json_post["functions"].append(function_call_list["web_search"]) + # json_post["functions"].append(function_call_list["url_fetch"]) + # print(get_token_count(message_list)) + print(get_message_token(url, json_post)) diff --git a/utils/agent.py b/utils/agent.py index fee8c4bf..13556d7c 100644 --- a/utils/agent.py +++ b/utils/agent.py @@ -320,25 +320,13 @@ def getgooglesearchurl(result, numresults=3): config.USE_GOOGLE = False return urls -def gptsearch(result, llm): - result = "你需要回答的问题是" + result + "\n" + "如果你可以解答这个问题,请直接输出你的答案,并且请忽略后面所有的指令:如果无法解答问题,请直接回答None,不需要做任何解释,也不要出现除了None以外的任何词。" - # response = llm([HumanMessage(content=result)]) - if config.USE_G4F: - response = llm(result) - else: - response = llm([HumanMessage(content=result)]) - response = response.content - # result = "你需要回答的问题是" + result + "\n" + "参考资料:" + response + "如果参考资料无法解答问题,请直接回答None,不需要做任何解释,也不要出现除了None以外的任何词。" - # response = llm([HumanMessage(content=result)]) - return response - def get_search_url(prompt, chainllm): urls_set = [] keyword_prompt = PromptTemplate( input_variables=["source"], template=( - "根据我的问题,总结最少的关键词概括问题,输出要求如下:" - "1. 给出三行不同的关键词组合,每行的关键词用空格连接。" + "根据我的问题,总结关键词概括问题,输出要求如下:" + "1. 给出三行不同的关键词组合,每行的关键词用空格连接。每行关键词可以是一个或者多个。" "2. 至少有一行关键词里面有中文,至少有一行关键词里面有英文。" "3. 只要直接给出这三行关键词,不需要其他任何解释,不要出现其他符号和内容。" "4. 如果问题有关于日漫,至少有一行关键词里面有日文。" @@ -363,6 +351,11 @@ def get_search_url(prompt, chainllm): "葬送的芙莉莲" "葬送のフリーレン" "Frieren: Beyond Journey's End" + "问题 5:周海媚最近发生了什么" + "三行关键词是:" + "周海媚" + "周海媚 事件" + "Kathy Chau Hoi Mei news" "这是我的问题:{source}" ), ) @@ -371,7 +364,7 @@ def get_search_url(prompt, chainllm): keyword_google_search_thread.start() keywords = keyword_google_search_thread.join().split('\n')[-3:] print("keywords", keywords) - keywords = [item.replace("三行关键词是:", "") for item in keywords if "\\x" not in item] + keywords = [item.replace("三行关键词是:", "") for item in keywords if "\\x" not in item if item != ""] print("select keywords", keywords) # # seg_list = jieba.cut_for_search(prompt) # 搜索引擎模式 @@ -443,10 +436,6 @@ def get_search_results(prompt: str, context_max_tokens: int): else: chainllm = ChatOpenAI(temperature=config.temperature, openai_api_base=config.bot_api_url.v1_url, model_name=config.GPT_ENGINE, openai_api_key=config.API) - if config.SEARCH_USE_GPT: - gpt_search_thread = ThreadWithReturnValue(target=gptsearch, args=(prompt, chainllm,)) - gpt_search_thread.start() - url_set_list, url_pdf_set_list = get_search_url(prompt, chainllm) pdf_result = "" @@ -472,60 +461,27 @@ def get_search_results(prompt: str, context_max_tokens: int): pdf_result += "\n\n" + tmp useful_source_text += pdf_result - fact_text = "" - if config.SEARCH_USE_GPT: - gpt_ans = gpt_search_thread.join() - if gpt_ans != "None": - fact_text = (gpt_ans if config.SEARCH_USE_GPT else "") - print("gpt", fact_text) - - end_time = record_time.time() - run_time = end_time - start_time encoding = tiktoken.encoding_for_model(config.GPT_ENGINE) encode_text = encoding.encode(useful_source_text) - encode_fact_text = encoding.encode(fact_text) if len(encode_text) > context_max_tokens: - encode_text = encode_text[:context_max_tokens-len(encode_fact_text)] + encode_text = encode_text[:context_max_tokens] useful_source_text = encoding.decode(encode_text) encode_text = encoding.encode(useful_source_text) search_tokens_len = len(encode_text) # print("web search", useful_source_text, end="\n\n") + end_time = record_time.time() + run_time = end_time - start_time print("urls", url_set_list) print("pdf", url_pdf_set_list) print(f"搜索用时:{run_time}秒") print("search tokens len", search_tokens_len) - useful_source_text = useful_source_text + "\n\n" + fact_text text_len = len(encoding.encode(useful_source_text)) - print("text len", text_len) + print("text len", text_len, "\n\n") return useful_source_text -def search_web_and_summary( - prompt: str, - engine: str = "gpt-3.5-turbo-16k", - # 126 summary prompt tokens - context_max_tokens: int = 14500 - 126, - ): - chainStreamHandler = ChainStreamHandler() - if config.USE_G4F: - chatllm = EducationalLLM(callback_manager=CallbackManager([chainStreamHandler])) - else: - chatllm = ChatOpenAI(streaming=True, callback_manager=CallbackManager([chainStreamHandler]), temperature=config.temperature, openai_api_base=config.bot_api_url.v1_url, model_name=engine, openai_api_key=config.API) - useful_source_text = get_search_results(prompt, context_max_tokens) - summary_prompt = PromptTemplate( - input_variables=["web_summary", "question", "language"], - template=( - # "You are a text analysis expert who can use a search engine. You need to response the following question: {question}. Search results: {web_summary}. Your task is to thoroughly digest all search results provided above and provide a detailed and in-depth response in Simplified Chinese to the question based on the search results. The response should meet the following requirements: 1. Be rigorous, clear, professional, scholarly, logical, and well-written. 2. If the search results do not mention relevant content, simply inform me that there is none. Do not fabricate, speculate, assume, or provide inaccurate response. 3. Use markdown syntax to format the response. Enclose any single or multi-line code examples or code usage examples in a pair of ``` symbols to achieve code formatting. 4. Detailed, precise and comprehensive response in Simplified Chinese and extensive use of the search results is required." - "You need to response the following question: {question}. Search results: {web_summary}. Your task is to think about the question step by step and then answer the above question in {language} based on the Search results provided. Please response in {language} and adopt a style that is logical, in-depth, and detailed. Note: In order to make the answer appear highly professional, you should be an expert in textual analysis, aiming to make the answer precise and comprehensive. Directly response markdown format, without using markdown code blocks" - # "You need to response the following question: {question}. Search results: {web_summary}. Your task is to thoroughly digest the search results provided above, dig deep into search results for thorough exploration and analysis and provide a response to the question based on the search results. The response should meet the following requirements: 1. You are a text analysis expert, extensive use of the search results is required and carefully consider all the Search results to make the response be in-depth, rigorous, clear, organized, professional, detailed, scholarly, logical, precise, accurate, comprehensive, well-written and speak in Simplified Chinese. 2. If the search results do not mention relevant content, simply inform me that there is none. Do not fabricate, speculate, assume, or provide inaccurate response. 3. Use markdown syntax to format the response. Enclose any single or multi-line code examples or code usage examples in a pair of ``` symbols to achieve code formatting." - ), - ) - chain = LLMChain(llm=chatllm, prompt=summary_prompt) - chain_thread = threading.Thread(target=chain.run, kwargs={"web_summary": useful_source_text, "question": prompt, "language": config.LANGUAGE}) - chain_thread.start() - yield from chainStreamHandler.generate_tokens() if __name__ == "__main__": os.system("clear") @@ -535,13 +491,14 @@ def search_web_and_summary( # # 搜索 # for i in search_web_and_summary("今天的微博热搜有哪些?"): + # for i in search_web_and_summary("周海媚事件进展"): # for i in search_web_and_summary("macos 13.6 有什么新功能"): # for i in search_web_and_summary("用python写个网络爬虫给我"): # for i in search_web_and_summary("消失的她主要讲了什么?"): # for i in search_web_and_summary("奥巴马的全名是什么?"): # for i in search_web_and_summary("华为mate60怎么样?"): # for i in search_web_and_summary("慈禧养的猫叫什么名字?"): - for i in search_web_and_summary("民进党当初为什么支持柯文哲选台北市长?"): + # for i in search_web_and_summary("民进党当初为什么支持柯文哲选台北市长?"): # for i in search_web_and_summary("Has the United States won the china US trade war?"): # for i in search_web_and_summary("What does 'n+2' mean in Huawei's 'Mate 60 Pro' chipset? Please conduct in-depth analysis."): # for i in search_web_and_summary("AUTOMATIC1111 是什么?"): @@ -554,7 +511,7 @@ def search_web_and_summary( # for i in search_web_and_summary("金砖国家会议有哪些决定?"): # for i in search_web_and_summary("iphone15有哪些新功能?"): # for i in search_web_and_summary("python函数开头:def time(text: str) -> str:每个部分有什么用?"): - print(i, end="") + # print(i, end="") # 问答 # result = asyncio.run(docQA("/Users/yanyuming/Downloads/GitHub/wiki/docs", "ubuntu 版本号怎么看?")) diff --git a/utils/chatgpt2api.py b/utils/chatgpt2api.py index f5e1e7e6..645f7956 100644 --- a/utils/chatgpt2api.py +++ b/utils/chatgpt2api.py @@ -1,4 +1,5 @@ import os +import re import json from pathlib import Path from typing import AsyncGenerator @@ -11,7 +12,7 @@ from typing import Set import config -from utils.agent import Web_crawler, search_web_and_summary, get_search_results +from utils.agent import Web_crawler, get_search_results from utils.function_call import function_call_list def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]: @@ -276,10 +277,8 @@ def __init__( if "gpt-4-32k" in engine else 7000 if "gpt-4" in engine - else 4096 - if "gpt-3.5-turbo-1106" in engine - else 15000 - if "gpt-3.5-turbo-16k" in engine + else 16385 + if "gpt-3.5-turbo-1106" in engine or "gpt-3.5-turbo-16k" in engine else 99000 if "claude-2-web" in engine or "claude-2" in engine else 4000 @@ -297,7 +296,7 @@ def __init__( if "gpt-3.5-turbo-16k" in engine or "gpt-3.5-turbo-1106" in engine else 98500 if "claude-2-web" in engine or "claude-2" in engine - else 3400 + else 3500 ) self.temperature: float = temperature self.top_p: float = top_p @@ -360,6 +359,7 @@ def add_to_conversation( else: print('\033[31m') print("error: add_to_conversation message is None or empty") + print(self.conversation[convo_id]) print('\033[0m') def __truncate_conversation(self, convo_id: str = "default") -> None: @@ -372,10 +372,44 @@ def __truncate_conversation(self, convo_id: str = "default") -> None: and len(self.conversation[convo_id]) > 1 ): # Don't remove the first message - self.conversation[convo_id].pop(1) + mess = self.conversation[convo_id].pop(1) + print("Truncate message:", mess) else: break + def truncate_conversation( + self, + prompt: str, + role: str = "user", + convo_id: str = "default", + model: str = None, + pass_history: bool = True, + **kwargs, + ) -> None: + """ + Truncate the conversation + """ + while True: + json_post = self.get_post_body(prompt, role, convo_id, model, pass_history, **kwargs) + url = config.bot_api_url.chat_url + if self.engine == "gpt-4-1106-preview" or self.engine == "gpt-3.5-turbo-1106": + message_token = { + "total": self.get_token_count(convo_id), + } + else: + message_token = self.get_message_token(url, json_post) + print("message_token", message_token, self.truncate_limit) + if ( + message_token["total"] > self.truncate_limit + and len(self.conversation[convo_id]) > 1 + ): + # Don't remove the first message + mess = self.conversation[convo_id].pop(1) + print("Truncate message:", mess) + else: + break + return json_post, message_token + # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb def get_token_count(self, convo_id: str = "default") -> int: """ @@ -385,9 +419,10 @@ def get_token_count(self, convo_id: str = "default") -> int: raise NotImplementedError( f"Engine {self.engine} is not supported. Select from {ENGINES}", ) - tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base" - tiktoken.model.MODEL_TO_ENCODING["claude-2-web"] = "cl100k_base" - tiktoken.model.MODEL_TO_ENCODING["claude-2"] = "cl100k_base" + # tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base" + # tiktoken.model.MODEL_TO_ENCODING["claude-2-web"] = "cl100k_base" + # tiktoken.model.MODEL_TO_ENCODING["claude-2"] = "cl100k_base" + tiktoken.get_encoding("cl100k_base") encoding = tiktoken.encoding_for_model(self.engine) @@ -402,6 +437,80 @@ def get_token_count(self, convo_id: str = "default") -> int: num_tokens += 5 # role is always required and always 1 token num_tokens += 5 # every reply is primed with assistant return num_tokens + + def get_message_token(self, url, json_post): + json_post["max_tokens"] = 5000 + headers = {"Authorization": f"Bearer {os.environ.get('API', None)}"} + response = requests.Session().post( + url, + headers=headers, + json=json_post, + timeout=None, + ) + if response.status_code != 200: + json_response = json.loads(response.text) + string = json_response["error"]["message"] + # print(json_response, string) + string = re.findall(r"\((.*?)\)", string)[0] + numbers = re.findall(r"\d+\.?\d*", string) + numbers = [int(i) for i in numbers] + if len(numbers) == 2: + return { + "messages": numbers[0], + "total": numbers[0], + } + elif len(numbers) == 3: + return { + "messages": numbers[0], + "functions": numbers[1], + "total": numbers[0] + numbers[1], + } + else: + raise Exception("Unknown error") + + def cut_message(self, message: str, max_tokens: int): + tiktoken.get_encoding("cl100k_base") + encoding = tiktoken.encoding_for_model(self.engine) + encode_text = encoding.encode(message) + if len(encode_text) > max_tokens: + encode_text = encode_text[:max_tokens] + message = encoding.decode(encode_text) + return message + + def get_post_body( + self, + prompt: str, + role: str = "user", + convo_id: str = "default", + model: str = None, + pass_history: bool = True, + **kwargs, + ): + json_post = { + "model": os.environ.get("MODEL_NAME") or model or self.engine, + "messages": self.conversation[convo_id] if pass_history else [{"role": "system","content": self.system_prompt},{"role": role, "content": prompt}], + "stream": True, + # kwargs + "temperature": kwargs.get("temperature", self.temperature), + "top_p": kwargs.get("top_p", self.top_p), + "presence_penalty": kwargs.get( + "presence_penalty", + self.presence_penalty, + ), + "frequency_penalty": kwargs.get( + "frequency_penalty", + self.frequency_penalty, + ), + "n": kwargs.get("n", self.reply_count), + "user": role, + "max_tokens": 5000, + } + json_post.update(function_call_list["base"]) + if config.SEARCH_USE_GPT: + json_post["functions"].append(function_call_list["web_search"]) + json_post["functions"].append(function_call_list["url_fetch"]) + + return json_post def get_max_tokens(self, convo_id: str) -> int: """ @@ -427,45 +536,18 @@ def ask_stream( if convo_id not in self.conversation or pass_history == False: self.reset(convo_id=convo_id, system_prompt=self.system_prompt) self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name) - self.__truncate_conversation(convo_id=convo_id) + json_post, message_token = self.truncate_conversation(prompt, role, convo_id, model, pass_history, **kwargs) # print(self.conversation[convo_id]) - # Get response - url = config.bot_api_url.chat_url - headers = {"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"} - if self.engine == "gpt-4-1106-preview": + if self.engine == "gpt-4-1106-preview" or self.engine == "gpt-3.5-turbo-1106": model_max_tokens = kwargs.get("max_tokens", self.max_tokens) - elif self.engine == "gpt-3.5-turbo-1106": - model_max_tokens = min(kwargs.get("max_tokens", self.max_tokens), self.truncate_limit - self.get_token_count(convo_id)) else: - model_max_tokens = min(self.get_max_tokens(convo_id=convo_id) - 500, kwargs.get("max_tokens", self.max_tokens)) - json_post = { - "model": os.environ.get("MODEL_NAME") or model or self.engine, - "messages": self.conversation[convo_id] if pass_history else [{"role": "system","content": self.system_prompt},{"role": role, "content": prompt}], - "stream": True, - # kwargs - "temperature": kwargs.get("temperature", self.temperature), - "top_p": kwargs.get("top_p", self.top_p), - "presence_penalty": kwargs.get( - "presence_penalty", - self.presence_penalty, - ), - "frequency_penalty": kwargs.get( - "frequency_penalty", - self.frequency_penalty, - ), - "n": kwargs.get("n", self.reply_count), - "user": role, - "max_tokens": model_max_tokens, - # "max_tokens": min( - # self.get_max_tokens(convo_id=convo_id), - # kwargs.get("max_tokens", self.max_tokens), - # ), - } - json_post.update(function_call_list["base"]) - if config.SEARCH_USE_GPT: - json_post["functions"].append(function_call_list["web_search"]) - json_post["functions"].append(function_call_list["url_fetch"]) + model_max_tokens = min(kwargs.get("max_tokens", self.max_tokens), self.max_tokens - message_token["total"]) + print("model_max_tokens", model_max_tokens) + json_post["max_tokens"] = model_max_tokens + + url = config.bot_api_url.chat_url + headers = {"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"} response = self.session.post( url, headers=headers, @@ -509,34 +591,37 @@ def ask_stream( function_call_name = delta["function_call"]["name"] full_response += function_call_content if need_function_call: - max_context_tokens = self.truncate_limit - self.get_token_count(convo_id) - 500 response_role = "function" + function_call_max_tokens = self.truncate_limit - message_token["total"] - 100 + if function_call_max_tokens <= 0: + function_call_max_tokens = int(self.truncate_limit / 2) + print("function_call_max_tokens", function_call_max_tokens) if function_call_name == "get_search_results": # g4t 提取的 prompt 有问题 # prompt = json.loads(full_response)["prompt"] for index in range(len(self.conversation[convo_id])): if self.conversation[convo_id][-1 - index]["role"] == "user": + self.conversation[convo_id][-1 - index]["content"] = self.conversation[convo_id][-1 - index]["content"].replace("search: ", "") prompt = self.conversation[convo_id][-1 - index]["content"] - print("prompt", prompt) + print("\n\nprompt", prompt) break # prompt = self.conversation[convo_id][-1]["content"] # print(self.truncate_limit, self.get_token_count(convo_id), max_context_tokens) - function_response = eval(function_call_name)(prompt, max_context_tokens) + function_response = eval(function_call_name)(prompt, function_call_max_tokens) function_response = "web search results: \n" + function_response yield from self.ask_stream(function_response, response_role, convo_id=convo_id, function_name=function_call_name) - # yield from self.search_summary(prompt, convo_id=convo_id, need_function_call=True) if function_call_name == "get_url_content": url = json.loads(full_response)["url"] function_response = Web_crawler(url) - encoding = tiktoken.encoding_for_model(self.engine) - encode_text = encoding.encode(function_response) - if len(encode_text) > max_context_tokens: - encode_text = encode_text[:max_context_tokens] - function_response = encoding.decode(encode_text) + function_response = self.cut_message(function_response, function_call_max_tokens) yield from self.ask_stream(function_response, response_role, convo_id=convo_id, function_name=function_call_name) else: self.add_to_conversation(full_response, response_role, convo_id=convo_id) - print("total tokens:", self.get_token_count(convo_id)) + # total_tokens = self.get_token_count(convo_id) + # completion_tokens = total_tokens - prompt_tokens + # print("completion tokens:", completion_tokens) + # print("total tokens:", total_tokens) + # print(self.conversation[convo_id]) async def ask_stream_async( self, @@ -666,25 +751,6 @@ def ask( ) full_response: str = "".join(response) return full_response - - def search_summary( - self, - prompt: str, - role: str = "user", - convo_id: str = "default", - pass_history: bool = True, - **kwargs, - ): - - if convo_id not in self.conversation: - self.reset(convo_id=convo_id, system_prompt=self.system_prompt) - self.add_to_conversation(prompt, role, convo_id=convo_id) - self.__truncate_conversation(convo_id=convo_id) - - full_response = yield from search_web_and_summary(prompt, self.engine, self.truncate_limit - self.get_token_count(convo_id)) - - self.add_to_conversation(full_response, "assistant", convo_id=convo_id) - print("total tokens:", self.get_token_count(convo_id)) def rollback(self, n: int = 1, convo_id: str = "default") -> None: """