From 3610fd9f697ab839649cbba4a68cf0726028edcf Mon Sep 17 00:00:00 2001 From: yym68686 Date: Mon, 18 Sep 2023 21:55:22 +0800 Subject: [PATCH] fixed bug: doc qa reach openai maximum context length --- agent.py | 175 +++++++++++++++++++++++++++++-------------------------- bot.py | 27 +++++++-- 2 files changed, 114 insertions(+), 88 deletions(-) diff --git a/agent.py b/agent.py index 41f686ae..b22f182c 100644 --- a/agent.py +++ b/agent.py @@ -2,6 +2,7 @@ import re import config import chardet +import asyncio import tiktoken import requests import threading @@ -13,7 +14,7 @@ from bs4 import BeautifulSoup from langchain.llms import OpenAI -from langchain.chains import LLMChain, RetrievalQA +from langchain.chains import LLMChain, RetrievalQA, RetrievalQAWithSourcesChain from langchain.agents import AgentType, load_tools, initialize_agent, tool from langchain.schema import HumanMessage from langchain.schema.output import LLMResult @@ -57,6 +58,41 @@ async def get_doc_from_local(docpath, doctype="md"): documents = loader.load() return documents +system_template="""Use the following pieces of context to answer the users question. +If you don't know the answer, just say "Hmm..., I'm not sure.", don't try to make up an answer. +ALWAYS return a "Sources" part in your answer. +The "Sources" part should be a reference to the source of the document from which you got your answer. + +Example of your response should be: + +``` +The answer is foo + +Sources: +1. abc +2. xyz +``` +Begin! +---------------- +{summaries} +""" +messages = [ + SystemMessagePromptTemplate.from_template(system_template), + HumanMessagePromptTemplate.from_template("{question}") +] +prompt = ChatPromptTemplate.from_messages(messages) + +def get_chain(store, llm): + chain_type_kwargs = {"prompt": prompt} + chain = RetrievalQAWithSourcesChain.from_chain_type( + llm, + chain_type="stuff", + retriever=store.as_retriever(), + chain_type_kwargs=chain_type_kwargs, + reduce_k_below_max_tokens=True + ) + return chain + async def docQA(docpath, query_message, persist_db_path="db", model = "gpt-3.5-turbo"): chatllm = ChatOpenAI(temperature=0.5, openai_api_base=os.environ.get('API_URL', None).split("chat")[0], model_name=model, openai_api_key=config.API) embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get('API_URL', None).split("chat")[0], openai_api_key=config.API) @@ -74,9 +110,8 @@ async def docQA(docpath, query_message, persist_db_path="db", model = "gpt-3.5-t documents = await doc_method(docpath) # 初始化加载器 text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=50) - # 切割加载的 document - split_docs = text_splitter.split_documents(documents) # 持久化数据 + split_docs = text_splitter.split_documents(documents) vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path) vector_store.persist() else: @@ -84,59 +119,12 @@ async def docQA(docpath, query_message, persist_db_path="db", model = "gpt-3.5-t vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings) # 创建问答对象 - qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(),return_source_documents=True) + qa = get_chain(vector_store, chatllm) + # qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True) # 进行问答 - result = qa({"query": query_message}) + result = qa({"question": query_message}) return result -@tool -def time(text: str) -> str: - """Returns todays date, use this for any \ - questions related to knowing todays date. \ - The input should always be an empty string, \ - and this function will always return todays \ - date - any date mathmatics should occur \ - outside this function.""" - return str(date.today()) - -def today_date(): - return str(date.today()) - -def duckduckgo_search(result, model="gpt-3.5-turbo", temperature=0.5): - try: - translate_prompt = PromptTemplate( - input_variables=["targetlang", "text"], - template="You are a translation engine, you can only translate text and cannot interpret it, and do not explain. Translate the text to {targetlang}, please do not explain any sentences, just translate or leave them as they are.: {text}", - ) - chatllm = ChatOpenAI(temperature=temperature, openai_api_base=os.environ.get('API_URL', None).split("chat")[0], model_name=model, openai_api_key=config.API) - - # # 翻译成英文 带聊天模型的链 方法一 - # translate_template="You are a translation engine, you can only translate text and cannot interpret it, and do not explain. Translate the text from {sourcelang} to {targetlang}, please do not explain any sentences, just translate or leave them as they are." - # system_message_prompt = SystemMessagePromptTemplate.from_template(translate_template) - # human_template = "{text}" - # human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) - # chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt]) - # zh2enchain = LLMChain(llm=chatllm, prompt=chat_prompt) - # result = zh2enchain.run(sourcelang="Simplified Chinese", targetlang="English", text=searchtext) - - # # 翻译成英文 方法二 - # chain = LLMChain(llm=chatllm, prompt=translate_prompt) - # result = chain.run({"targetlang": "english", "text": searchtext}) - - # 搜索 - tools = load_tools(["ddg-search", "llm-math", "wikipedia"], llm=chatllm) - agent = initialize_agent(tools + [time], chatllm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, max_iterations=2, early_stopping_method="generate", handle_parsing_errors=True) - result = agent.run(result) - - # 翻译成中文 - en2zhchain = LLMChain(llm=chatllm, prompt=translate_prompt) - result = en2zhchain.run(targetlang="Simplified Chinese", text=result) - result = en2zhchain.run({"targetlang": "Simplified Chinese", "text": result}) - - return result - except Exception as e: - traceback.print_exc() - def get_doc_from_url(url): filename = url.split("/")[-1] response = requests.get(url, stream=True) @@ -145,10 +133,33 @@ def get_doc_from_url(url): f.write(chunk) return filename -def pdf_search(docpath, query_message, model="gpt-3.5-turbo"): +def persist_emdedding_pdf(docpath, persist_db_path, embeddings): + loader = UnstructuredPDFLoader(docpath) + documents = loader.load() + # 初始化加载器 + text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25) + # 切割加载的 document + split_docs = text_splitter.split_documents(documents) + vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path) + vector_store.persist() + return vector_store + +def pdfQA(docpath, query_message, model="gpt-3.5-turbo"): + chatllm = ChatOpenAI(temperature=0.5, openai_api_base=os.environ.get('API_URL', None).split("chat")[0], model_name=model, openai_api_key=os.environ.get('API', None)) + embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get('API_URL', None).split("chat")[0], openai_api_key=os.environ.get('API', None)) + persist_db_path = getmd5(docpath) + if not os.path.exists(persist_db_path): + vector_store = persist_emdedding_pdf(docpath, persist_db_path, embeddings) + else: + vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings) + qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(),return_source_documents=True) + result = qa({"query": query_message}) + return result['result'] + +def pdf_search(docurl, query_message, model="gpt-3.5-turbo"): chatllm = ChatOpenAI(temperature=0.5, openai_api_base=os.environ.get('API_URL', None).split("chat")[0], model_name=model, openai_api_key=os.environ.get('API', None)) embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get('API_URL', None).split("chat")[0], openai_api_key=os.environ.get('API', None)) - filename = get_doc_from_url(docpath) + filename = get_doc_from_url(docurl) docpath = os.getcwd() + "/" + filename loader = UnstructuredPDFLoader(docpath) documents = loader.load() @@ -389,34 +400,34 @@ def search_summary(result, model=config.DEFAULT_SEARCH_MODEL, temperature=config # from langchain.agents import get_all_tool_names # print(get_all_tool_names()) - # 搜索 - # print(duckduckgo_search("凡凡还有多久出狱?")) - # print(search_summary("凡凡还有多久出狱?")) - - # for i in search_summary("今天的微博热搜有哪些?"): - # for i in search_summary("用python写个网络爬虫给我"): - # for i in search_summary("消失的她主要讲了什么?"): - # for i in search_summary("奥巴马的全名是什么?"): - # for i in search_summary("华为mate60怎么样?"): - # for i in search_summary("慈禧养的猫叫什么名字?"): - # for i in search_summary("民进党当初为什么支持柯文哲选台北市长?"): - # for i in search_summary("Has the United States won the china US trade war?"): - # for i in search_summary("What does 'n+2' mean in Huawei's 'Mate 60 Pro' chipset? Please conduct in-depth analysis."): - # for i in search_summary("AUTOMATIC1111 是什么?"): - for i in search_summary("python telegram bot 怎么接收pdf文件"): - # for i in search_summary("中国利用外资指标下降了 87% ?真的假的。"): - # for i in search_summary("How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?"): - # for i in search_summary("英国脱欧没有好处,为什么英国人还是要脱欧?"): - # for i in search_summary("2022年俄乌战争为什么发生?"): - # for i in search_summary("卡罗尔与星期二讲的啥?"): - # for i in search_summary("金砖国家会议有哪些决定?"): - # for i in search_summary("iphone15有哪些新功能?"): - # for i in search_summary("python函数开头:def time(text: str) -> str:每个部分有什么用?"): - print(i, end="") - - # # 问答 + # # 搜索 + + # # for i in search_summary("今天的微博热搜有哪些?"): + # # for i in search_summary("用python写个网络爬虫给我"): + # # for i in search_summary("消失的她主要讲了什么?"): + # # for i in search_summary("奥巴马的全名是什么?"): + # # for i in search_summary("华为mate60怎么样?"): + # # for i in search_summary("慈禧养的猫叫什么名字?"): + # # for i in search_summary("民进党当初为什么支持柯文哲选台北市长?"): + # # for i in search_summary("Has the United States won the china US trade war?"): + # # for i in search_summary("What does 'n+2' mean in Huawei's 'Mate 60 Pro' chipset? Please conduct in-depth analysis."): + # # for i in search_summary("AUTOMATIC1111 是什么?"): + # for i in search_summary("python telegram bot 怎么接收pdf文件"): + # # for i in search_summary("中国利用外资指标下降了 87% ?真的假的。"): + # # for i in search_summary("How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?"): + # # for i in search_summary("英国脱欧没有好处,为什么英国人还是要脱欧?"): + # # for i in search_summary("2022年俄乌战争为什么发生?"): + # # for i in search_summary("卡罗尔与星期二讲的啥?"): + # # for i in search_summary("金砖国家会议有哪些决定?"): + # # for i in search_summary("iphone15有哪些新功能?"): + # # for i in search_summary("python函数开头:def time(text: str) -> str:每个部分有什么用?"): + # print(i, end="") + + # 问答 # result = asyncio.run(docQA("/Users/yanyuming/Downloads/GitHub/wiki/docs", "ubuntu 版本号怎么看?")) - # # result = asyncio.run(docQA("https://yym68686.top", "reid可以怎么分类?")) + # result = asyncio.run(docQA("https://yym68686.top", "说一下HSTL pipeline")) + result = asyncio.run(docQA("https://wiki.yym68686.top", "PyTorch to MindSpore翻译思路是什么?")) + print(result['answer']) # source_url = set([i.metadata['source'] for i in result["source_documents"]]) # source_url = "\n".join(source_url) # message = ( diff --git a/bot.py b/bot.py index 4608addc..522c4419 100644 --- a/bot.py +++ b/bot.py @@ -89,7 +89,7 @@ async def getChatGPT(title, robot, message, update, context): print("error", e) print('\033[0m') if config.API: - robot.reset(convo_id=str(update.message.chat_id), system_prompt=systemprompt) + robot.reset(convo_id=str(update.message.chat_id), system_prompt=config.systemprompt) if "You exceeded your current quota, please check your plan and billing details." in str(e): print("OpenAI api 已过期!") await context.bot.delete_message(chat_id=update.message.chat_id, message_id=messageid) @@ -288,6 +288,25 @@ async def search(update, context, has_command=True): reply_to_message_id=update.message.message_id, ) +from agent import pdfQA, getmd5 +def handle_pdf(update, context): + # 获取接收到的文件 + pdf_file = update.message.document + question = update.message.text + + # 下载文件到本地 + file_name = pdf_file.file_name + docpath = os.getcwd() + "/" + file_name + match_embedding = os.path.exists(getmd5(docpath)) + if not match_embedding: + file_id = pdf_file.file_id + new_file = context.bot.get_file(file_id) + new_file.download(docpath) + result = pdfQA(docpath, question) + if not match_embedding: + os.remove(docpath) + context.bot.send_message(chat_id=update.message.chat_id, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True) + async def qa(update, context): if (len(context.args) != 2): message = ( @@ -338,9 +357,6 @@ def setup(token): application = ApplicationBuilder().read_timeout(10).connection_pool_size(50000).pool_timeout(1200.0).token(token).build() run_async(application.bot.set_my_commands([ - # BotCommand('gpt4', 'use gpt4'), - # BotCommand('claude2', 'use claude2'), - # BotCommand('search', 'search the web with google and duckduckgo'), BotCommand('info', 'basic information'), BotCommand('gpt_use_search', 'open or close gpt use search'), BotCommand('history', 'open or close chat history'), @@ -357,13 +373,12 @@ def setup(token): application.add_handler(CommandHandler("reset", reset_chat)) application.add_handler(CommandHandler("en2zh", lambda update, context: command_bot(update, context, "simplified chinese", robot=config.ChatGPTbot))) application.add_handler(CommandHandler("zh2en", lambda update, context: command_bot(update, context, "english", robot=config.ChatGPTbot))) - # application.add_handler(CommandHandler("gpt4", lambda update, context: command_bot(update, context, prompt=None, title="`🤖️ gpt-4`\n\n", robot=config.ChatGPT4bot))) - # application.add_handler(CommandHandler("claude2", lambda update, context: command_bot(update, context, prompt=None, title="`🤖️ claude2`\n\n", robot=config.Claude2bot))) application.add_handler(CommandHandler("info", info)) application.add_handler(CommandHandler("history", history)) application.add_handler(CommandHandler("google", google)) application.add_handler(CommandHandler("gpt_use_search", gpt_use_search)) application.add_handler(CommandHandler("qa", qa)) + application.add_handler(MessageHandler(filters.Document.mime_type('application/pdf') & filters.TEXT, handle_pdf)) application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, lambda update, context: command_bot(update, context, prompt=None, title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot, has_command=False))) application.add_handler(MessageHandler(filters.COMMAND, unknown)) application.add_error_handler(error)