Skip to content

Commit

Permalink
Fixed bug: doc QA exceeded the OpenAI maximum context length
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Sep 18, 2023
1 parent 13f4fb8 commit 3610fd9
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 88 deletions.
175 changes: 93 additions & 82 deletions agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
import config
import chardet
import asyncio
import tiktoken
import requests
import threading
Expand All @@ -13,7 +14,7 @@
from bs4 import BeautifulSoup

from langchain.llms import OpenAI
from langchain.chains import LLMChain, RetrievalQA
from langchain.chains import LLMChain, RetrievalQA, RetrievalQAWithSourcesChain
from langchain.agents import AgentType, load_tools, initialize_agent, tool
from langchain.schema import HumanMessage
from langchain.schema.output import LLMResult
Expand Down Expand Up @@ -57,6 +58,41 @@ async def get_doc_from_local(docpath, doctype="md"):
documents = loader.load()
return documents

# Prompt for the sources-citing document-QA chain: retrieved chunks are
# interpolated into {summaries}, and the model is instructed to append a
# "Sources" section naming the documents the answer came from.
system_template="""Use the following pieces of context to answer the users question.
If you don't know the answer, just say "Hmm..., I'm not sure.", don't try to make up an answer.
ALWAYS return a "Sources" part in your answer.
The "Sources" part should be a reference to the source of the document from which you got your answer.
Example of your response should be:
```
The answer is foo
Sources:
1. abc
2. xyz
```
Begin!
----------------
{summaries}
"""
# System + human message pair consumed by the retrieval chain; the chain
# fills in {summaries} (retrieved context) and {question} at run time.
messages = [
SystemMessagePromptTemplate.from_template(system_template),
HumanMessagePromptTemplate.from_template("{question}")
]
prompt = ChatPromptTemplate.from_messages(messages)

def get_chain(store, llm):
    """Build a sources-citing retrieval QA chain over *store* using *llm*.

    ``reduce_k_below_max_tokens=True`` makes the chain drop retrieved
    documents until the stuffed prompt fits the model's context window —
    this is the guard against exceeding the OpenAI maximum context length.
    """
    return RetrievalQAWithSourcesChain.from_chain_type(
        llm,
        chain_type="stuff",
        retriever=store.as_retriever(),
        chain_type_kwargs={"prompt": prompt},
        reduce_k_below_max_tokens=True
    )

async def docQA(docpath, query_message, persist_db_path="db", model = "gpt-3.5-turbo"):
chatllm = ChatOpenAI(temperature=0.5, openai_api_base=os.environ.get('API_URL', None).split("chat")[0], model_name=model, openai_api_key=config.API)
embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get('API_URL', None).split("chat")[0], openai_api_key=config.API)
Expand All @@ -74,69 +110,21 @@ async def docQA(docpath, query_message, persist_db_path="db", model = "gpt-3.5-t
documents = await doc_method(docpath)
# 初始化加载器
text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=50)
# 切割加载的 document
split_docs = text_splitter.split_documents(documents)
# 持久化数据
split_docs = text_splitter.split_documents(documents)
vector_store = Chroma.from_documents(split_docs, embeddings, persist_directory=persist_db_path)
vector_store.persist()
else:
# 加载数据
vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings)

# 创建问答对象
qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(),return_source_documents=True)
qa = get_chain(vector_store, chatllm)
# qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(), return_source_documents=True)
# 进行问答
result = qa({"query": query_message})
result = qa({"question": query_message})
return result

@tool
def time(text: str) -> str:
    """Returns today's date, use this for any \
questions related to knowing today's date. \
The input should always be an empty string, \
and this function will always return today's \
date - any date mathematics should occur \
outside this function."""
    # The docstring doubles as the tool description shown to the agent.
    return str(date.today())

def today_date():
    """Return the current local date as an ISO-formatted string (YYYY-MM-DD)."""
    return date.today().isoformat()

def duckduckgo_search(result, model="gpt-3.5-turbo", temperature=0.5):
    """Answer *result* (a user query) with a ReAct agent armed with
    DuckDuckGo search, a calculator, and Wikipedia, then translate the
    agent's answer to Simplified Chinese.

    Returns the translated answer string, or None if anything raised
    (the traceback is printed and the exception swallowed).
    """
    try:
        translate_prompt = PromptTemplate(
            input_variables=["targetlang", "text"],
            template="You are a translation engine, you can only translate text and cannot interpret it, and do not explain. Translate the text to {targetlang}, please do not explain any sentences, just translate or leave them as they are.: {text}",
        )
        chatllm = ChatOpenAI(temperature=temperature, openai_api_base=os.environ.get('API_URL', None).split("chat")[0], model_name=model, openai_api_key=config.API)

        # # Translate to English — chain with a chat model (approach 1)
        # translate_template="You are a translation engine, you can only translate text and cannot interpret it, and do not explain. Translate the text from {sourcelang} to {targetlang}, please do not explain any sentences, just translate or leave them as they are."
        # system_message_prompt = SystemMessagePromptTemplate.from_template(translate_template)
        # human_template = "{text}"
        # human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
        # chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
        # zh2enchain = LLMChain(llm=chatllm, prompt=chat_prompt)
        # result = zh2enchain.run(sourcelang="Simplified Chinese", targetlang="English", text=searchtext)

        # # Translate to English (approach 2)
        # chain = LLMChain(llm=chatllm, prompt=translate_prompt)
        # result = chain.run({"targetlang": "english", "text": searchtext})

        # Search: ReAct agent capped at 2 iterations to bound latency/cost;
        # the custom `time` tool is added so date questions stay local.
        tools = load_tools(["ddg-search", "llm-math", "wikipedia"], llm=chatllm)
        agent = initialize_agent(tools + [time], chatllm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, max_iterations=2, early_stopping_method="generate", handle_parsing_errors=True)
        result = agent.run(result)

        # Translate the agent's answer to Simplified Chinese.
        en2zhchain = LLMChain(llm=chatllm, prompt=translate_prompt)
        result = en2zhchain.run(targetlang="Simplified Chinese", text=result)
        # NOTE(review): the next line re-translates the already-translated
        # text — it looks like a diff leftover duplicating the call above
        # (the two differ only in calling convention); confirm which one is
        # intended and remove the other.
        result = en2zhchain.run({"targetlang": "Simplified Chinese", "text": result})

        return result
    except Exception as e:
        traceback.print_exc()

def get_doc_from_url(url):
filename = url.split("/")[-1]
response = requests.get(url, stream=True)
Expand All @@ -145,10 +133,33 @@ def get_doc_from_url(url):
f.write(chunk)
return filename

def pdf_search(docpath, query_message, model="gpt-3.5-turbo"):
def persist_emdedding_pdf(docpath, persist_db_path, embeddings):
    """Embed the PDF at *docpath* and persist the vectors under *persist_db_path*.

    Returns the persisted Chroma vector store.
    (NOTE: the "emdedding" typo in the name is kept — callers use it.)
    """
    raw_documents = UnstructuredPDFLoader(docpath).load()
    # Split into small overlapping chunks before embedding.
    splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=25)
    chunks = splitter.split_documents(raw_documents)
    store = Chroma.from_documents(chunks, embeddings, persist_directory=persist_db_path)
    store.persist()
    return store

def pdfQA(docpath, query_message, model="gpt-3.5-turbo"):
    """Answer *query_message* against the PDF at *docpath* with retrieval QA.

    The embedding DB is persisted in a directory named by getmd5(docpath),
    so repeat queries against the same file skip re-embedding.
    Returns the answer text.
    """
    api_base = os.environ.get('API_URL', None).split("chat")[0]
    api_key = os.environ.get('API', None)
    chatllm = ChatOpenAI(temperature=0.5, openai_api_base=api_base, model_name=model, openai_api_key=api_key)
    embeddings = OpenAIEmbeddings(openai_api_base=api_base, openai_api_key=api_key)
    persist_db_path = getmd5(docpath)
    if os.path.exists(persist_db_path):
        # Reuse the previously persisted embeddings.
        vector_store = Chroma(persist_directory=persist_db_path, embedding_function=embeddings)
    else:
        vector_store = persist_emdedding_pdf(docpath, persist_db_path, embeddings)
    qa = RetrievalQA.from_chain_type(llm=chatllm, chain_type="stuff", retriever=vector_store.as_retriever(),return_source_documents=True)
    return qa({"query": query_message})['result']

def pdf_search(docurl, query_message, model="gpt-3.5-turbo"):
chatllm = ChatOpenAI(temperature=0.5, openai_api_base=os.environ.get('API_URL', None).split("chat")[0], model_name=model, openai_api_key=os.environ.get('API', None))
embeddings = OpenAIEmbeddings(openai_api_base=os.environ.get('API_URL', None).split("chat")[0], openai_api_key=os.environ.get('API', None))
filename = get_doc_from_url(docpath)
filename = get_doc_from_url(docurl)
docpath = os.getcwd() + "/" + filename
loader = UnstructuredPDFLoader(docpath)
documents = loader.load()
Expand Down Expand Up @@ -389,34 +400,34 @@ def search_summary(result, model=config.DEFAULT_SEARCH_MODEL, temperature=config
# --- Ad-hoc manual-test scratchpad --------------------------------------
# NOTE(review): the uncommented statements below execute at import time;
# this looks like leftover debugging code — consider guarding it with
# `if __name__ == "__main__":` so importing the module has no side effects.

# from langchain.agents import get_all_tool_names
# print(get_all_tool_names())

# Search
# print(duckduckgo_search("凡凡还有多久出狱?"))
# print(search_summary("凡凡还有多久出狱?"))

# for i in search_summary("今天的微博热搜有哪些?"):
# for i in search_summary("用python写个网络爬虫给我"):
# for i in search_summary("消失的她主要讲了什么?"):
# for i in search_summary("奥巴马的全名是什么?"):
# for i in search_summary("华为mate60怎么样?"):
# for i in search_summary("慈禧养的猫叫什么名字?"):
# for i in search_summary("民进党当初为什么支持柯文哲选台北市长?"):
# for i in search_summary("Has the United States won the china US trade war?"):
# for i in search_summary("What does 'n+2' mean in Huawei's 'Mate 60 Pro' chipset? Please conduct in-depth analysis."):
# for i in search_summary("AUTOMATIC1111 是什么?"):
# Streams a search-augmented answer chunk by chunk.
for i in search_summary("python telegram bot 怎么接收pdf文件"):
# for i in search_summary("中国利用外资指标下降了 87% ?真的假的。"):
# for i in search_summary("How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?"):
# for i in search_summary("英国脱欧没有好处,为什么英国人还是要脱欧?"):
# for i in search_summary("2022年俄乌战争为什么发生?"):
# for i in search_summary("卡罗尔与星期二讲的啥?"):
# for i in search_summary("金砖国家会议有哪些决定?"):
# for i in search_summary("iphone15有哪些新功能?"):
# for i in search_summary("python函数开头:def time(text: str) -> str:每个部分有什么用?"):
    print(i, end="")

# # Q&A
# # Search

# # for i in search_summary("今天的微博热搜有哪些?"):
# # for i in search_summary("用python写个网络爬虫给我"):
# # for i in search_summary("消失的她主要讲了什么?"):
# # for i in search_summary("奥巴马的全名是什么?"):
# # for i in search_summary("华为mate60怎么样?"):
# # for i in search_summary("慈禧养的猫叫什么名字?"):
# # for i in search_summary("民进党当初为什么支持柯文哲选台北市长?"):
# # for i in search_summary("Has the United States won the china US trade war?"):
# # for i in search_summary("What does 'n+2' mean in Huawei's 'Mate 60 Pro' chipset? Please conduct in-depth analysis."):
# # for i in search_summary("AUTOMATIC1111 是什么?"):
# for i in search_summary("python telegram bot 怎么接收pdf文件"):
# # for i in search_summary("中国利用外资指标下降了 87% ?真的假的。"):
# # for i in search_summary("How much does the 'zeabur' software service cost per month? Is it free to use? Any limitations?"):
# # for i in search_summary("英国脱欧没有好处,为什么英国人还是要脱欧?"):
# # for i in search_summary("2022年俄乌战争为什么发生?"):
# # for i in search_summary("卡罗尔与星期二讲的啥?"):
# # for i in search_summary("金砖国家会议有哪些决定?"):
# # for i in search_summary("iphone15有哪些新功能?"):
# # for i in search_summary("python函数开头:def time(text: str) -> str:每个部分有什么用?"):
#     print(i, end="")

# Q&A
# result = asyncio.run(docQA("/Users/yanyuming/Downloads/GitHub/wiki/docs", "ubuntu 版本号怎么看?"))
# # result = asyncio.run(docQA("https://yym68686.top", "reid可以怎么分类?"))
# result = asyncio.run(docQA("https://yym68686.top", "说一下HSTL pipeline"))
# Document-QA demo against a live wiki; prints only the answer body.
result = asyncio.run(docQA("https://wiki.yym68686.top", "PyTorch to MindSpore翻译思路是什么?"))
print(result['answer'])
# source_url = set([i.metadata['source'] for i in result["source_documents"]])
# source_url = "\n".join(source_url)
# message = (
Expand Down
27 changes: 21 additions & 6 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ async def getChatGPT(title, robot, message, update, context):
print("error", e)
print('\033[0m')
if config.API:
robot.reset(convo_id=str(update.message.chat_id), system_prompt=systemprompt)
robot.reset(convo_id=str(update.message.chat_id), system_prompt=config.systemprompt)
if "You exceeded your current quota, please check your plan and billing details." in str(e):
print("OpenAI api 已过期!")
await context.bot.delete_message(chat_id=update.message.chat_id, message_id=messageid)
Expand Down Expand Up @@ -288,6 +288,25 @@ async def search(update, context, has_command=True):
reply_to_message_id=update.message.message_id,
)

from agent import pdfQA, getmd5
async def handle_pdf(update, context):
    """Answer a question asked about an uploaded PDF document.

    Registered for messages carrying a PDF.  Downloads the file only when
    no persisted embedding DB exists for it yet (``getmd5(docpath)`` names
    the Chroma persist directory used by ``pdfQA``), runs ``pdfQA``, removes
    the temporary file again, and replies with the answer.

    Fixes vs. the original: the surrounding app is built with
    ``ApplicationBuilder`` (python-telegram-bot v20), where handlers must be
    coroutines and ``get_file``/``download_to_drive`` are awaitables — the
    sync body with the removed ``File.download`` API could never run.  A
    document message also has no ``.text``; the question arrives as the
    file caption.
    """
    # The received file object
    pdf_file = update.message.document
    # Question travels in the caption for document messages; keep .text as a
    # fallback in case this handler is ever reused for plain-text updates.
    question = update.message.caption or update.message.text

    # Download the file next to the CWD only if it was not embedded before.
    file_name = pdf_file.file_name
    docpath = os.getcwd() + "/" + file_name
    match_embedding = os.path.exists(getmd5(docpath))
    if not match_embedding:
        file_id = pdf_file.file_id
        new_file = await context.bot.get_file(file_id)
        await new_file.download_to_drive(docpath)
    result = pdfQA(docpath, question)
    if not match_embedding:
        # The embedding DB persists; the raw PDF is no longer needed.
        os.remove(docpath)
    await context.bot.send_message(chat_id=update.message.chat_id, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True)

async def qa(update, context):
if (len(context.args) != 2):
message = (
Expand Down Expand Up @@ -338,9 +357,6 @@ def setup(token):
application = ApplicationBuilder().read_timeout(10).connection_pool_size(50000).pool_timeout(1200.0).token(token).build()

run_async(application.bot.set_my_commands([
# BotCommand('gpt4', 'use gpt4'),
# BotCommand('claude2', 'use claude2'),
# BotCommand('search', 'search the web with google and duckduckgo'),
BotCommand('info', 'basic information'),
BotCommand('gpt_use_search', 'open or close gpt use search'),
BotCommand('history', 'open or close chat history'),
Expand All @@ -357,13 +373,12 @@ def setup(token):
application.add_handler(CommandHandler("reset", reset_chat))
application.add_handler(CommandHandler("en2zh", lambda update, context: command_bot(update, context, "simplified chinese", robot=config.ChatGPTbot)))
application.add_handler(CommandHandler("zh2en", lambda update, context: command_bot(update, context, "english", robot=config.ChatGPTbot)))
# application.add_handler(CommandHandler("gpt4", lambda update, context: command_bot(update, context, prompt=None, title="`🤖️ gpt-4`\n\n", robot=config.ChatGPT4bot)))
# application.add_handler(CommandHandler("claude2", lambda update, context: command_bot(update, context, prompt=None, title="`🤖️ claude2`\n\n", robot=config.Claude2bot)))
application.add_handler(CommandHandler("info", info))
application.add_handler(CommandHandler("history", history))
application.add_handler(CommandHandler("google", google))
application.add_handler(CommandHandler("gpt_use_search", gpt_use_search))
application.add_handler(CommandHandler("qa", qa))
application.add_handler(MessageHandler(filters.Document.mime_type('application/pdf') & filters.TEXT, handle_pdf))
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, lambda update, context: command_bot(update, context, prompt=None, title=f"`🤖️ {config.GPT_ENGINE}`\n\n", robot=config.ChatGPTbot, has_command=False)))
application.add_handler(MessageHandler(filters.COMMAND, unknown))
application.add_error_handler(error)
Expand Down

0 comments on commit 3610fd9

Please sign in to comment.