Skip to content

Commit

Permalink
Add feature: support uploading PDF, automatically perform vector sear…
Browse files Browse the repository at this point in the history
…ch on PDF documents, and extract content related to PDF based on vector database
  • Loading branch information
yym68686 committed Sep 19, 2023
1 parent 46c7b05 commit 5f95832
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 86 deletions.
20 changes: 13 additions & 7 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ def get_doc_from_url(url):
f.write(chunk)
return filename

def persist_emdedding_pdf(docpath, persist_db_path, embeddings):
def persist_emdedding_pdf(docurl, persist_db_path, embeddings):
filename = get_doc_from_url(docurl)
docpath = os.getcwd() + "/" + filename
loader = UnstructuredPDFLoader(docpath)
documents = loader.load()
# 初始化加载器
Expand All @@ -142,18 +144,20 @@ def persist_emdedding_pdf(docpath, persist_db_path, embeddings):
split_docs = text_splitter.split_documents(documents)
vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path)
vector_store.persist()
os.remove(docpath)
return vector_store

def pdfQA(docpath, query_message, model="gpt-3.5-turbo"):
async def pdfQA(docurl, query_message, model="gpt-3.5-turbo"):
chatllm = ChatOpenAI(temperature=0.5, openai_api_base=os.environ.get('API_URL', None).split("chat")[0], model_name=model, openai_api_key=os.environ.get('API', None))
embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get('API_URL', None).split("chat")[0], openai_api_key=os.environ.get('API', None))
persist_db_path = getmd5(docpath)
persist_db_path = getmd5(docurl)
if not os.path.exists(persist_db_path):
vector_store = persist_emdedding_pdf(docpath, persist_db_path, embeddings)
vector_store = persist_emdedding_pdf(docurl, persist_db_path, embeddings)
else:
vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings)
qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(),return_source_documents=True)
qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
result = qa({"query": query_message})
print(2)
return result['result']

def pdf_search(docurl, query_message, model="gpt-3.5-turbo"):
Expand Down Expand Up @@ -426,8 +430,10 @@ def search_summary(result, model=config.DEFAULT_SEARCH_MODEL, temperature=config
# 问答
# result = asyncio.run(docQA("/Users/yanyuming/Downloads/GitHub/wiki/docs", "ubuntu 版本号怎么看?"))
# result = asyncio.run(docQA("https://yym68686.top", "说一下HSTL pipeline"))
result = asyncio.run(docQA("https://wiki.yym68686.top", "PyTorch to MindSpore翻译思路是什么?"))
print(result['answer'])
# result = asyncio.run(docQA("https://wiki.yym68686.top", "PyTorch to MindSpore翻译思路是什么?"))
# print(result['answer'])
result = asyncio.run(pdfQA("https://api.telegram.org/file/bot5569497961:AAHobhUuydAwD8SPkXZiVFybvZJOmGrST_w/documents/file_1.pdf", "HSTL的pipeline详细讲一下"))
print(result)
# source_url = set([i.metadata['source'] for i in result["source_documents"]])
# source_url = "\n".join(source_url)
# message = (
Expand Down
140 changes: 61 additions & 79 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,51 +112,6 @@ async def delete_message(update, context, messageid, delay=10):
print("error", e)
print('\033[0m')

# async def history(update, context):
# config.PASS_HISTORY = not config.PASS_HISTORY
# status = "打开" if config.PASS_HISTORY else "关闭"
# message = (
# f"当前已{status}聊天记录!\n"
# f"**PASS_HISTORY:** `{config.PASS_HISTORY}`"
# )
# message = await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2')

# messageid = message.message_id
# await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=update.message.message_id)
# thread = threading.Thread(target=run_async, args=(delete_message(update, context, messageid),))
# thread.start()

# async def gpt_use_search(update, context):
# config.SEARCH_USE_GPT = not config.SEARCH_USE_GPT
# status = "打开" if config.SEARCH_USE_GPT else "关闭"
# message = (
# f"当前已{status}gpt默认搜索🔍!\n"
# f"**SEARCH_USE_GPT:** `{config.SEARCH_USE_GPT}`"
# )
# message = await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2')

# messageid = message.message_id
# await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=update.message.message_id)
# thread = threading.Thread(target=run_async, args=(delete_message(update, context, messageid),))
# thread.start()

# async def google(update, context):
# if os.environ.get('GOOGLE_API_KEY', None) == None and os.environ.get('GOOGLE_CSE_ID', None) == None:
# await context.bot.send_message(chat_id=update.message.chat_id, text=escape("GOOGLE_API_KEY or GOOGLE_CSE_ID not found"), parse_mode='MarkdownV2')
# return
# config.USE_GOOGLE = not config.USE_GOOGLE
# status = "打开" if config.USE_GOOGLE else "关闭"
# message = (
# f"当前已{status}google搜索!\n"
# f"**USE_GOOGLE:** `{config.USE_GOOGLE}`"
# )
# message = await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2')

# messageid = message.message_id
# await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=update.message.message_id)
# thread = threading.Thread(target=run_async, args=(delete_message(update, context, messageid),))
# thread.start()

buttons = [
[
InlineKeyboardButton("gpt-3.5-turbo", callback_data="gpt-3.5-turbo"),
Expand Down Expand Up @@ -386,52 +341,79 @@ async def search(update, context, has_command=True):
)

from agent import pdfQA, getmd5
def handle_pdf(update, context):
async def handle_pdf(update, context):
# 获取接收到的文件
pdf_file = update.message.document
question = update.message.text
question = update.message.caption

# 下载文件到本地
file_name = pdf_file.file_name
docpath = os.getcwd() + "/" + file_name
match_embedding = os.path.exists(getmd5(docpath))
if not match_embedding:
file_id = pdf_file.file_id
new_file = context.bot.get_file(file_id)
new_file.download(docpath)
result = pdfQA(docpath, question)
if not match_embedding:
os.remove(docpath)
context.bot.send_message(chat_id=update.message.chat_id, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True)
new_file = await context.bot.get_file(file_id)
file_url = new_file.file_path
result = await pdfQA(file_url, question)
print(result)
await context.bot.send_message(chat_id=update.message.chat_id, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True)

async def qa(update, context):
if (len(context.args) != 2):
if update.message.reply_to_message is None:
if (len(context.args) != 2):
message = (
f"格式错误哦~,需要两个参数,注意路径或者链接、问题之间的空格\n\n"
f"请输入 `/qa 知识库链接 要问的问题`\n\n"
f"例如知识库链接为 https://abc.com ,问题是 蘑菇怎么分类?\n\n"
f"则输入 `/qa https://abc.com 蘑菇怎么分类?`\n\n"
f"问题务必不能有空格,👆点击上方命令复制格式\n\n"
f"除了输入网址,同时支持本地知识库,本地知识库文件夹路径为 `./wiki`,问题是 蘑菇怎么分类?\n\n"
f"则输入 `/qa ./wiki 蘑菇怎么分类?`\n\n"
f"问题务必不能有空格,👆点击上方命令复制格式\n\n"
f"本地知识库目前只支持 Markdown 文件\n\n"
)
await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
return
print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m")
await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING)
result = await docQA(context.args[0], context.args[1], get_doc_from_local)
source_url = set([i.metadata['source'] for i in result["source_documents"]])
source_url = "\n".join(source_url)
message = (
f"格式错误哦~,需要两个参数,注意路径或者链接、问题之间的空格\n\n"
f"请输入 `/qa 知识库链接 要问的问题`\n\n"
f"例如知识库链接为 https://abc.com ,问题是 蘑菇怎么分类?\n\n"
f"则输入 `/qa https://abc.com 蘑菇怎么分类?`\n\n"
f"问题务必不能有空格,👆点击上方命令复制格式\n\n"
f"除了输入网址,同时支持本地知识库,本地知识库文件夹路径为 `./wiki`,问题是 蘑菇怎么分类?\n\n"
f"则输入 `/qa ./wiki 蘑菇怎么分类?`\n\n"
f"问题务必不能有空格,👆点击上方命令复制格式\n\n"
f"本地知识库目前只支持 Markdown 文件\n\n"
f"{result['result']}\n\n"
f"参考链接:\n"
f"{source_url}"
)
await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
return
print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m")
await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING)
# result = docQA(context.args[0], context.args[1], get_doc_from_sitemap)
result = await docQA(context.args[0], context.args[1], get_doc_from_local)
source_url = set([i.metadata['source'] for i in result["source_documents"]])
source_url = "\n".join(source_url)
message = (
f"{result['result']}\n\n"
f"参考链接:\n"
f"{source_url}"
)
print(message)
await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
print(message)
await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
else:
if update.message.reply_to_message.document is None:
message = (
f"格式错误哦~,需要回复一个文件,我才知道你要针对哪个文件提问,注意命令与问题之间的空格\n\n"
f"请输入 `/qa 要问的问题`\n\n"
f"例如已经上传某文档 ,问题是 蘑菇怎么分类?\n\n"
f"先左滑文档进入回复模式,在聊天框里面输入 `/qa 蘑菇怎么分类?`\n\n"
f"问题务必不能有空格,👆点击上方命令复制格式\n\n"
# f"本地知识库目前只支持 Markdown 文件\n\n"
)
await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
return
print("\033[32m", update.effective_user.username, update.effective_user.id, update.message.text, "\033[0m")
await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING)
pdf_file = update.message.reply_to_message.document
question = update.message.text
# 下载文件到本地
file_name = pdf_file.file_name
docpath = os.getcwd() + "/" + file_name
match_embedding = os.path.exists(getmd5(docpath))
if not match_embedding:
file_id = pdf_file.file_id
new_file = context.bot.get_file(file_id)
new_file.download(docpath)
result = await pdfQA(docpath, question)
if not match_embedding:
os.remove(docpath)
context.bot.send_message(chat_id=update.message.chat_id, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True)

async def start(update, context): # 当用户输入/start时,返回文本
user = update.effective_user
Expand Down Expand Up @@ -472,7 +454,7 @@ def setup(token):
application.add_handler(CommandHandler("zh2en", lambda update, context: command_bot(update, context, "english", robot=config.ChatGPTbot)))
application.add_handler(CommandHandler("info", info))
application.add_handler(CommandHandler("qa", qa))
application.add_handler(MessageHandler(filters.Document.MimeType('application/pdf') & filters.TEXT, handle_pdf))
application.add_handler(MessageHandler(filters.Document.MimeType('application/pdf'), handle_pdf))
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, lambda update, context: command_bot(update, context, prompt=None, title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot, has_command=False)))
application.add_handler(MessageHandler(filters.COMMAND, unknown))
application.add_error_handler(error)
Expand Down

0 comments on commit 5f95832

Please sign in to comment.