From 5f95832c4b083a183a3c42eff9835088888bb55a Mon Sep 17 00:00:00 2001 From: yym68686 Date: Tue, 19 Sep 2023 11:55:31 +0800 Subject: [PATCH] Add feature: support uploading PDF, automatically perform vector search on PDF documents, and extract content related to PDF based on vector database --- agent.py | 20 +++++--- bot.py | 140 ++++++++++++++++++++++++------------------------------- 2 files changed, 74 insertions(+), 86 deletions(-) diff --git a/agent.py b/agent.py index b22f182c..2487ac4e 100644 --- a/agent.py +++ b/agent.py @@ -133,7 +133,9 @@ def get_doc_from_url(url): f.write(chunk) return filename -def persist_emdedding_pdf(docpath, persist_db_path, embeddings): +def persist_emdedding_pdf(docurl, persist_db_path, embeddings): + filename = get_doc_from_url(docurl) + docpath = os.getcwd() + "/" + filename loader = UnstructuredPDFLoader(docpath) documents = loader.load() # 初始化加载器 @@ -142,18 +144,20 @@ def persist_emdedding_pdf(docpath, persist_db_path, embeddings): split_docs = text_splitter.split_documents(documents) vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path) vector_store.persist() + os.remove(docpath) return vector_store -def pdfQA(docpath, query_message, model="gpt-3.5-turbo"): +async def pdfQA(docurl, query_message, model="gpt-3.5-turbo"): chatllm = ChatOpenAI(temperature=0.5, openai_api_base=os.environ.get('API_URL', None).split("chat")[0], model_name=model, openai_api_key=os.environ.get('API', None)) embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get('API_URL', None).split("chat")[0], openai_api_key=os.environ.get('API', None)) - persist_db_path = getmd5(docpath) + persist_db_path = getmd5(docurl) if not os.path.exists(persist_db_path): - vector_store = persist_emdedding_pdf(docpath, persist_db_path, embeddings) + vector_store = persist_emdedding_pdf(docurl, persist_db_path, embeddings) else: vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings) - qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(),return_source_documents=True) + qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True) result = qa({"query": query_message}) + print(2) return result['result'] def pdf_search(docurl, query_message, model="gpt-3.5-turbo"): @@ -426,8 +430,10 @@ def search_summary(result, model=config.DEFAULT_SEARCH_MODEL, temperature=config # 问答 # result = asyncio.run(docQA("/Users/yanyuming/Downloads/GitHub/wiki/docs", "ubuntu 版本号怎么看?")) # result = asyncio.run(docQA("https://yym68686.top", "说一下HSTL pipeline")) - result = asyncio.run(docQA("https://wiki.yym68686.top", "PyTorch to MindSpore翻译思路是什么?")) - print(result['answer']) + # result = asyncio.run(docQA("https://wiki.yym68686.top", "PyTorch to MindSpore翻译思路是什么?")) + # print(result['answer']) + result = asyncio.run(pdfQA("https://api.telegram.org/file/bot5569497961:AAHobhUuydAwD8SPkXZiVFybvZJOmGrST_w/documents/file_1.pdf", "HSTL的pipeline详细讲一下")) + print(result) # source_url = set([i.metadata['source'] for i in result["source_documents"]]) # source_url = "\n".join(source_url) # message = ( diff --git a/bot.py b/bot.py index c8b27996..48d0d03c 100644 --- a/bot.py +++ b/bot.py @@ -112,51 +112,6 @@ async def delete_message(update, context, messageid, delay=10): print("error", e) print('\033[0m') -# async def history(update, context): -# config.PASS_HISTORY = not config.PASS_HISTORY -# status = "打开" if config.PASS_HISTORY else "关闭" 
-# message = ( -# f"当前已{status}聊天记录!\n" -# f"**PASS_HISTORY:** `{config.PASS_HISTORY}`" -# ) -# message = await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2') - -# messageid = message.message_id -# await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=update.message.message_id) -# thread = threading.Thread(target=run_async, args=(delete_message(update, context, messageid),)) -# thread.start() - -# async def gpt_use_search(update, context): -# config.SEARCH_USE_GPT = not config.SEARCH_USE_GPT -# status = "打开" if config.SEARCH_USE_GPT else "关闭" -# message = ( -# f"当前已{status}gpt默认搜索🔍!\n" -# f"**SEARCH_USE_GPT:** `{config.SEARCH_USE_GPT}`" -# ) -# message = await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2') - -# messageid = message.message_id -# await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=update.message.message_id) -# thread = threading.Thread(target=run_async, args=(delete_message(update, context, messageid),)) -# thread.start() - -# async def google(update, context): -# if os.environ.get('GOOGLE_API_KEY', None) == None and os.environ.get('GOOGLE_CSE_ID', None) == None: -# await context.bot.send_message(chat_id=update.message.chat_id, text=escape("GOOGLE_API_KEY or GOOGLE_CSE_ID not found"), parse_mode='MarkdownV2') -# return -# config.USE_GOOGLE = not config.USE_GOOGLE -# status = "打开" if config.USE_GOOGLE else "关闭" -# message = ( -# f"当前已{status}google搜索!\n" -# f"**USE_GOOGLE:** `{config.USE_GOOGLE}`" -# ) -# message = await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2') - -# messageid = message.message_id -# await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=update.message.message_id) -# thread = threading.Thread(target=run_async, args=(delete_message(update, context, messageid),)) -# thread.start() - buttons = [ [ InlineKeyboardButton("gpt-3.5-turbo", callback_data="gpt-3.5-turbo"), @@ -386,10 +341,10 @@ async def search(update, context, has_command=True): ) from agent import pdfQA, getmd5 -def handle_pdf(update, context): +async def handle_pdf(update, context): # 获取接收到的文件 pdf_file = update.message.document - question = update.message.text + question = update.message.caption # 下载文件到本地 file_name = pdf_file.file_name @@ -397,41 +352,68 @@ def handle_pdf(update, context): match_embedding = os.path.exists(getmd5(docpath)) if not match_embedding: file_id = pdf_file.file_id - new_file = context.bot.get_file(file_id) - new_file.download(docpath) - result = pdfQA(docpath, question) - if not match_embedding: - os.remove(docpath) - context.bot.send_message(chat_id=update.message.chat_id, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True) + new_file = await context.bot.get_file(file_id) + file_url = new_file.file_path + result = await pdfQA(file_url, question) + print(result) + await context.bot.send_message(chat_id=update.message.chat_id, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True) async def qa(update, context): - if (len(context.args) != 2): + if update.message.reply_to_message is None: + if (len(context.args) != 2): + message = ( + f"格式错误哦~,需要两个参数,注意路径或者链接、问题之间的空格\n\n" + f"请输入 `/qa 知识库链接 要问的问题`\n\n" + f"例如知识库链接为 https://abc.com ,问题是 蘑菇怎么分类?\n\n" + f"则输入 `/qa https://abc.com 蘑菇怎么分类?`\n\n" + f"问题务必不能有空格,👆点击上方命令复制格式\n\n" + f"除了输入网址,同时支持本地知识库,本地知识库文件夹路径为 `./wiki`,问题是 蘑菇怎么分类?\n\n" + f"则输入 
`/qa ./wiki 蘑菇怎么分类?`\n\n"
+                f"问题务必不能有空格,👆点击上方命令复制格式\n\n"
+                f"本地知识库目前只支持 Markdown 文件\n\n"
+            )
+            await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
+            return
+        print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m")
+        await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING)
+        result = await docQA(context.args[0], context.args[1], get_doc_from_local)
+        source_url = set([i.metadata['source'] for i in result["source_documents"]])
+        source_url = "\n".join(source_url)
         message = (
-            f"格式错误哦~,需要两个参数,注意路径或者链接、问题之间的空格\n\n"
-            f"请输入 `/qa 知识库链接 要问的问题`\n\n"
-            f"例如知识库链接为 https://abc.com ,问题是 蘑菇怎么分类?\n\n"
-            f"则输入 `/qa https://abc.com 蘑菇怎么分类?`\n\n"
-            f"问题务必不能有空格,👆点击上方命令复制格式\n\n"
-            f"除了输入网址,同时支持本地知识库,本地知识库文件夹路径为 `./wiki`,问题是 蘑菇怎么分类?\n\n"
-            f"则输入 `/qa ./wiki 蘑菇怎么分类?`\n\n"
-            f"问题务必不能有空格,👆点击上方命令复制格式\n\n"
-            f"本地知识库目前只支持 Markdown 文件\n\n"
+            f"{result['result']}\n\n"
+            f"参考链接:\n"
+            f"{source_url}"
         )
-        await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
-        return
-    print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m")
-    await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING)
-    # result = docQA(context.args[0], context.args[1], get_doc_from_sitemap)
-    result = await docQA(context.args[0], context.args[1], get_doc_from_local)
-    source_url = set([i.metadata['source'] for i in result["source_documents"]])
-    source_url = "\n".join(source_url)
-    message = (
-        f"{result['result']}\n\n"
-        f"参考链接:\n"
-        f"{source_url}"
-    )
-    print(message)
-    await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
+        print(message)
+        await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
+    else:
+        if update.message.reply_to_message.document is None:
+            message = (
+                f"格式错误哦~,需要回复一个文件,我才知道你要针对哪个文件提问,注意命令与问题之间的空格\n\n"
+                f"请输入 `/qa 要问的问题`\n\n"
+                f"例如已经上传某文档 ,问题是 蘑菇怎么分类?\n\n"
+                f"先左滑文档进入回复模式,在聊天框里面输入 `/qa 蘑菇怎么分类?`\n\n"
+                f"问题务必不能有空格,👆点击上方命令复制格式\n\n"
+                # f"本地知识库目前只支持 Markdown 文件\n\n"
+            )
+            await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
+            return
+        print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m")
+        await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING)
+        pdf_file = update.message.reply_to_message.document
+        question = update.message.text
+        # 获取文件直链,由 pdfQA 下载并检索
+        file_id = pdf_file.file_id
+        new_file = await context.bot.get_file(file_id)
+        file_url = new_file.file_path
+        result = await pdfQA(file_url, question)
+        await context.bot.send_message(chat_id=update.message.chat_id, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True)
 
 async def start(update, context): # 当用户输入/start时,返回文本
     user = update.effective_user
@@ -472,7 +454,7 @@ def setup(token):
     application.add_handler(CommandHandler("zh2en", lambda update, context: command_bot(update, context, 
"english", robot=config.ChatGPTbot))) application.add_handler(CommandHandler("info", info)) application.add_handler(CommandHandler("qa", qa)) - application.add_handler(MessageHandler(filters.Document.MimeType('application/pdf') & filters.TEXT, handle_pdf)) + application.add_handler(MessageHandler(filters.Document.MimeType('application/pdf'), handle_pdf)) application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, lambda update, context: command_bot(update, context, prompt=None, title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot, has_command=False))) application.add_handler(MessageHandler(filters.COMMAND, unknown)) application.add_error_handler(error)