From fcef0adf57b7ef9ae347181d344249bdc94eb134 Mon Sep 17 00:00:00 2001
From: AquibPy
Date: Fri, 29 Mar 2024 02:51:19 +0530
Subject: [PATCH] ADDED: RAG using Groq

---
 api.py              | 20 +++++++++++++++++++-
 helper_functions.py | 30 +++++++++++++++++++++++++++++-
 2 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/api.py b/api.py
index b652743..1e924b9 100644
--- a/api.py
+++ b/api.py
@@ -9,7 +9,7 @@ from settings import invoice_prompt,youtube_transcribe_prompt,text2sql_prompt,EMPLOYEE_DB
 from mongo import MongoDB
 from helper_functions import get_qa_chain,get_gemini_response,get_url_doc_qa,extract_transcript_details,\
-    get_gemini_response_health,get_gemini_pdf,read_sql_query,remove_substrings,questions_generator
+    get_gemini_response_health,get_gemini_pdf,read_sql_query,remove_substrings,questions_generator,groq_pdf
 from langchain_groq import ChatGroq
 from langchain.chains import ConversationChain
 from langchain.chains.conversation.memory import ConversationBufferWindowMemory
@@ -286,5 +286,23 @@ async def groq_text_summary(input_text: str = Form(...)):
         mongo_data = {"Document": payload}
         result = db.insert_data(mongo_data)
         return {"Summary": summary_text}
+    except Exception as e:
+        return ResponseText(response=f"Error: {str(e)}")
+
+@app.post("/RAG_PDF_Groq",description="The endpoint uses the pdf and give the answer based on the prompt provided using groq")
+async def talk_pd_groq(pdf: UploadFile = File(...),prompt: str = Form(...)):
+    try:
+        rag_chain = groq_pdf(pdf.file)
+        out = rag_chain.invoke(prompt)
+        db = MongoDB()
+        payload = {
+            "endpoint" : "/RAG_PDF_Groq",
+            "prompt" : prompt,
+            "output" : out
+        }
+        mongo_data = {"Document": payload}
+        result = db.insert_data(mongo_data)
+        print(result)
+        return ResponseText(response=out)
     except Exception as e:
         return ResponseText(response=f"Error: {str(e)}")
\ No newline at end of file
diff --git a/helper_functions.py b/helper_functions.py
index 7388098..43b960c 100644
--- a/helper_functions.py
+++ b/helper_functions.py
@@ -10,6 +10,10 @@ from langchain.prompts import PromptTemplate
 from langchain.chains import RetrievalQA
 from langchain.chains.summarize import load_summarize_chain
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnablePassthrough
+from langchain_core.output_parsers import StrOutputParser
+from langchain_groq import ChatGroq
 from dotenv import load_dotenv
 import google.generativeai as genai
 from youtube_transcript_api import YouTubeTranscriptApi
 
@@ -61,7 +65,7 @@ def create_vector_db():
 
 def get_qa_chain():
     llm = GoogleGenerativeAI(model=PALM_MODEL, google_api_key=os.getenv("GOOGLE_API_KEY"),temperature=0.7)
-    vectordb = FAISS.load_local(VECTORDB_PATH,PaLM_embeddings)
+    vectordb = FAISS.load_local(VECTORDB_PATH,PaLM_embeddings,allow_dangerous_deserialization=True)
     retriever = vectordb.as_retriever(score_threshold=0.7)
     PROMPT = PromptTemplate(
         template=qa_prompt, input_variables=["context", "question"]
@@ -184,6 +188,30 @@ def questions_generator(doc):
     ques = ques_gen_chain.run(document_ques_gen)
     return ques
 
+def groq_pdf(pdf):
+    llm = ChatGroq(
+        api_key=os.environ['GROQ_API_KEY'],
+        model_name='mixtral-8x7b-32768'
+    )
+    text = "".join(page.extract_text() for page in PdfReader(pdf).pages)
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=1000)
+    chunks = text_splitter.split_text(text)
+    embeddings = GoogleGenerativeAIEmbeddings(model = "models/embedding-001")
+    vectorstore = FAISS.from_texts(chunks, embedding=embeddings)
+    retriever = vectorstore.as_retriever()
+    rag_template = """Answer the question based only on the following context:
+    {context}
+    Question: {question}
+    """
+    rag_prompt = ChatPromptTemplate.from_template(rag_template)
+    rag_chain = (
+        {"context": retriever, "question": RunnablePassthrough()}
+        | rag_prompt
+        | llm
+        | StrOutputParser()
+    )
+    return rag_chain
+
 if __name__ == "__main__":
     create_vector_db()
     chain = get_qa_chain()