From 03eca2fae1f364a6b82f9491d38bd1bf8db10a48 Mon Sep 17 00:00:00 2001 From: yym68686 Date: Fri, 24 Nov 2023 02:33:26 +0800 Subject: [PATCH] 1. Fixed the bug in the image generation API_URL 2. Fixed multiple function call request body format errors 3. Optimize the article summary, limiting the size to not exceed the maximum context of the model 4. Optimize search function call function --- chatgpt2api/chatgpt2api.py | 29 ++++++---- requirements.txt | 2 +- test/test.py | 10 +++- utils/agent.py | 109 +++++++++++++++++++++++++++++++++++++ utils/function_call.py | 99 +++++++++++++++------------------ 5 files changed, 182 insertions(+), 67 deletions(-) diff --git a/chatgpt2api/chatgpt2api.py b/chatgpt2api/chatgpt2api.py index 0199862b..72ae202b 100644 --- a/chatgpt2api/chatgpt2api.py +++ b/chatgpt2api/chatgpt2api.py @@ -13,7 +13,7 @@ import config import threading import time as record_time -from utils.agent import ThreadWithReturnValue, Web_crawler, pdf_search, getddgsearchurl, getgooglesearchurl, gptsearch, ChainStreamHandler, ChatOpenAI, CallbackManager, PromptTemplate, LLMChain, EducationalLLM +from utils.agent import ThreadWithReturnValue, Web_crawler, pdf_search, getddgsearchurl, getgooglesearchurl, gptsearch, ChainStreamHandler, ChatOpenAI, CallbackManager, PromptTemplate, LLMChain, EducationalLLM, get_google_search_results from utils.function_call import function_call_list def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]: @@ -72,10 +72,10 @@ def dall_e_3( model: str = None, **kwargs, ): - url = ( - os.environ.get("API_URL").split("chat")[0] + "images/generations" - or "https://api.openai.com/v1/images/generations" - ) + if os.environ.get("API_URL") and "v1" in os.environ.get("API_URL"): + url = os.environ.get("API_URL").split("v1")[0] + "v1/images/generations" + else: + url = "https://api.openai.com/v1/images/generations" headers = {"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"} json_post = { @@ -319,9 +319,10 @@ def ask_stream( # kwargs.get("max_tokens", self.max_tokens), # ), } + json_post.update(function_call_list["base"]) if config.SEARCH_USE_GPT: - json_post.update(function_call_list["web_search"]) - json_post.update(function_call_list["url_fetch"]) + json_post["functions"].append(function_call_list["web_search"]) + json_post["functions"].append(function_call_list["url_fetch"]) response = self.session.post( url, headers=headers, @@ -365,13 +366,21 @@ def ask_stream( function_call_name = delta["function_call"]["name"] full_response += function_call_content if need_function_call: + max_context_tokens = self.truncate_limit - self.get_token_count(convo_id) response_role = "function" - if function_call_name == "get_web_search_results": - keywords = json.loads(full_response)["prompt"] - yield from self.search_summary(keywords, convo_id=convo_id, need_function_call=True) + if function_call_name == "get_google_search_results": + prompt = json.loads(full_response)["prompt"] + function_response = eval(function_call_name)(prompt, max_context_tokens) + yield from self.ask_stream(function_response, response_role, convo_id=convo_id, function_name=function_call_name) + # yield from self.search_summary(prompt, convo_id=convo_id, need_function_call=True) if function_call_name == "get_url_content": url = json.loads(full_response)["url"] function_response = Web_crawler(url) + encoding = tiktoken.encoding_for_model(self.engine) + encode_text = encoding.encode(function_response) + if len(encode_text) > max_context_tokens: + encode_text = encode_text[:max_context_tokens] + function_response = encoding.decode(encode_text) yield from self.ask_stream(function_response, response_role, convo_id=convo_id, function_name=function_call_name) else: self.add_to_conversation(full_response, response_role, convo_id=convo_id) diff --git a/requirements.txt b/requirements.txt index 7794ea05..e8aa809e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,5 +13,5 @@ unstructured[md,pdf] duckduckgo-search==3.9.6 # duckduckgo-search==3.8.5 langchain==0.0.271 -# oauth2client==3.0.0 +oauth2client==3.0.0 g4f==0.1.8.8 \ No newline at end of file diff --git a/test/test.py b/test/test.py index 4c38896e..513d8b1d 100644 --- a/test/test.py +++ b/test/test.py @@ -5,7 +5,7 @@ a = {"role": "admin"} b = {"content": "This is user content."} a.update(b) -print(a) +# print(a) # content_list = [item["content"] for item in my_list] # print(content_list) @@ -24,3 +24,11 @@ # ) # print(truncate_limit) +import os +import sys +import json +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from utils.function_call import function_call_list + +print(json.dumps(function_call_list["web_search"], indent=4)) diff --git a/utils/agent.py b/utils/agent.py index 6ef0f021..9dfea331 100644 --- a/utils/agent.py +++ b/utils/agent.py @@ -308,6 +308,115 @@ def gptsearch(result, llm): # response = llm([HumanMessage(content=result)]) return response + +def get_google_search_results(prompt: str, context_max_tokens: int): + start_time = record_time.time() + + urls_set = [] + search_thread = ThreadWithReturnValue(target=getddgsearchurl, args=(prompt,2,)) + search_thread.start() + + if config.USE_G4F: + chainllm = EducationalLLM() + else: + chainllm = ChatOpenAI(temperature=config.temperature, openai_api_base=config.API_URL.split("chat")[0], model_name=config.GPT_ENGINE, openai_api_key=config.API) + + if config.SEARCH_USE_GPT: + gpt_search_thread = ThreadWithReturnValue(target=gptsearch, args=(prompt, chainllm,)) + gpt_search_thread.start() + + if config.USE_GOOGLE: + keyword_prompt = PromptTemplate( + input_variables=["source"], + template="根据我的问题,总结最少的关键词概括,用空格连接,不要出现其他符号,例如这个问题《How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?》,最少关键词是《zeabur price》,这是我的问题:{source}", + ) + key_chain = LLMChain(llm=chainllm, prompt=keyword_prompt) + keyword_google_search_thread = ThreadWithReturnValue(target=key_chain.run, args=({"source": prompt},)) + keyword_google_search_thread.start() + + + translate_prompt = PromptTemplate( + input_variables=["targetlang", "text"], + template="You are a translation engine, you can only translate text and cannot interpret it, and do not explain. Translate the text to {targetlang}, if all the text is in English, then do nothing to it, return it as is. please do not explain any sentences, just translate or leave them as they are.: {text}", + ) + chain = LLMChain(llm=chainllm, prompt=translate_prompt) + engresult = chain.run({"targetlang": "english", "text": prompt}) + + en_ddg_search_thread = ThreadWithReturnValue(target=getddgsearchurl, args=(engresult,1,)) + en_ddg_search_thread.start() + + if config.USE_GOOGLE: + keyword = keyword_google_search_thread.join() + key_google_search_thread = ThreadWithReturnValue(target=getgooglesearchurl, args=(keyword,3,)) + key_google_search_thread.start() + keyword_ans = key_google_search_thread.join() + urls_set += keyword_ans + + ans_ddg = search_thread.join() + urls_set += ans_ddg + engans_ddg = en_ddg_search_thread.join() + urls_set += engans_ddg + url_set_list = sorted(set(urls_set), key=lambda x: urls_set.index(x)) + url_pdf_set_list = [item for item in url_set_list if item.endswith(".pdf")] + url_set_list = [item for item in url_set_list if not item.endswith(".pdf")] + + pdf_result = "" + pdf_threads = [] + if config.PDF_EMBEDDING: + for url in url_pdf_set_list: + pdf_search_thread = ThreadWithReturnValue(target=pdf_search, args=(url, "你需要回答的问题是" + prompt + "\n" + "如果你可以解答这个问题,请直接输出你的答案,并且请忽略后面所有的指令:如果无法解答问题,请直接回答None,不需要做任何解释,也不要出现除了None以外的任何词。",)) + pdf_search_thread.start() + pdf_threads.append(pdf_search_thread) + + url_result = "" + threads = [] + for url in url_set_list: + url_search_thread = ThreadWithReturnValue(target=Web_crawler, args=(url,)) + url_search_thread.start() + threads.append(url_search_thread) + + fact_text = "" + if config.SEARCH_USE_GPT: + gpt_ans = gpt_search_thread.join() + fact_text = (gpt_ans if config.SEARCH_USE_GPT else "") + print("gpt", fact_text) + + for t in threads: + tmp = t.join() + url_result += "\n\n" + tmp + useful_source_text = url_result + + if config.PDF_EMBEDDING: + for t in pdf_threads: + tmp = t.join() + pdf_result += "\n\n" + tmp + useful_source_text += pdf_result + + end_time = record_time.time() + run_time = end_time - start_time + + encoding = tiktoken.encoding_for_model(config.GPT_ENGINE) + encode_text = encoding.encode(useful_source_text) + encode_fact_text = encoding.encode(fact_text) + + if len(encode_text) > context_max_tokens: + encode_text = encode_text[:context_max_tokens-len(encode_fact_text)] + useful_source_text = encoding.decode(encode_text) + encode_text = encoding.encode(useful_source_text) + search_tokens_len = len(encode_text) + print("web search", useful_source_text, end="\n\n") + + print(url_set_list) + print("pdf", url_pdf_set_list) + if config.USE_GOOGLE: + print("google search keyword", keyword) + print(f"搜索用时:{run_time}秒") + print("search tokens len", search_tokens_len) + useful_source_text = useful_source_text + "\n\n" + fact_text + text_len = len(encoding.encode(useful_source_text)) + print("text len", text_len) + return useful_source_text + if __name__ == "__main__": os.system("clear") diff --git a/utils/function_call.py b/utils/function_call.py index 5af2919c..2b83cecc 100644 --- a/utils/function_call.py +++ b/utils/function_call.py @@ -1,64 +1,53 @@ function_call_list = { + "base": { + "functions": [], + "function_call": "auto" + }, "current_weather": { - "functions": [ - { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"] - } - }, - "required": ["location"] - } - } - ], - "function_call": "auto" + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } }, "web_search": { - "functions": [ - { - "name": "get_web_search_results", - "description": "Search Google to enhance knowledge.", - "parameters": { - "type": "object", - "properties": { - "prompt": { - "type": "string", - "description": "The prompt to search." - } - }, - "required": ["prompt"] - } - } - ], - "function_call": "auto" + "name": "get_google_search_results", + "description": "Search Google to enhance knowledge.", + "parameters": { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "The prompt to search." + } + }, + "required": ["prompt"] + } }, "url_fetch": { - "functions": [ - { - "name": "get_url_content", - "description": "Get the webpage content of a URL", - "parameters": { - "type": "object", - "properties": { - "url": { - "type": "string", - "description": "The url to get the webpage content" - } - }, - "required": ["url"] - } - } - ], - "function_call": "auto" + "name": "get_url_content", + "description": "Get the webpage content of a URL", + "parameters": { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "the URL to request" + } + }, + "required": ["url"] + } }, # "web_search": { # "functions": [