Skip to content

Commit

Permalink
ADDED: RAG using Groq
Browse files Browse the repository at this point in the history
  • Loading branch information
AquibPy committed Mar 28, 2024
1 parent 40ca959 commit fcef0ad
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
20 changes: 19 additions & 1 deletion api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from settings import invoice_prompt,youtube_transcribe_prompt,text2sql_prompt,EMPLOYEE_DB
from mongo import MongoDB
from helper_functions import get_qa_chain,get_gemini_response,get_url_doc_qa,extract_transcript_details,\
get_gemini_response_health,get_gemini_pdf,read_sql_query,remove_substrings,questions_generator
get_gemini_response_health,get_gemini_pdf,read_sql_query,remove_substrings,questions_generator,groq_pdf
from langchain_groq import ChatGroq
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
Expand Down Expand Up @@ -286,5 +286,23 @@ async def groq_text_summary(input_text: str = Form(...)):
mongo_data = {"Document": payload}
result = db.insert_data(mongo_data)
return {"Summary": summary_text}
except Exception as e:
return ResponseText(response=f"Error: {str(e)}")

@app.post("/RAG_PDF_Groq",description="The endpoint uses the pdf and give the answer based on the prompt provided using groq")
async def talk_pd_groq(pdf: UploadFile = File(...), prompt: str = Form(...)):
    """Answer a prompt against an uploaded PDF using a Groq-backed RAG chain.

    Builds a retrieval chain over the PDF's text (see ``groq_pdf``), runs the
    prompt through it, records the prompt/answer pair in MongoDB, and returns
    the model's answer wrapped in a ``ResponseText``.

    Args:
        pdf: Uploaded PDF file; its raw stream is handed to ``groq_pdf``.
        prompt: The user's question about the PDF's content.

    Returns:
        ResponseText with the chain's answer, or an "Error: ..." message on failure.
    """
    try:
        rag_chain = groq_pdf(pdf.file)
        out = rag_chain.invoke(prompt)
        # Persist the interaction for auditing, mirroring the other endpoints.
        db = MongoDB()
        payload = {
            "endpoint": "/RAG_PDF_Groq",
            "prompt": prompt,
            "output": out,
        }
        mongo_data = {"Document": payload}
        # NOTE: the insert result was previously dumped via a stray debug
        # print(); removed — it is not needed by the caller.
        db.insert_data(mongo_data)
        return ResponseText(response=out)
    except Exception as e:
        # Broad catch at the API boundary: surface the message to the client
        # instead of returning a bare 500.
        return ResponseText(response=f"Error: {str(e)}")
30 changes: 29 additions & 1 deletion helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.chains.summarize import load_summarize_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_groq import ChatGroq
from dotenv import load_dotenv
import google.generativeai as genai
from youtube_transcript_api import YouTubeTranscriptApi
Expand Down Expand Up @@ -61,7 +65,7 @@ def create_vector_db():

def get_qa_chain():
llm = GoogleGenerativeAI(model=PALM_MODEL, google_api_key=os.getenv("GOOGLE_API_KEY"),temperature=0.7)
vectordb = FAISS.load_local(VECTORDB_PATH,PaLM_embeddings)
vectordb = FAISS.load_local(VECTORDB_PATH,PaLM_embeddings,allow_dangerous_deserialization=True)
retriever = vectordb.as_retriever(score_threshold=0.7)
PROMPT = PromptTemplate(
template=qa_prompt, input_variables=["context", "question"]
Expand Down Expand Up @@ -184,6 +188,30 @@ def questions_generator(doc):
ques = ques_gen_chain.run(document_ques_gen)
return ques

def groq_pdf(pdf, model_name='mixtral-8x7b-32768'):
    """Build a RAG chain that answers questions about a PDF with a Groq LLM.

    The PDF's text is extracted, chunked, embedded with Google's
    ``models/embedding-001``, and indexed in an in-memory FAISS store; the
    returned chain retrieves relevant chunks as context for the question.

    Args:
        pdf: A file path or binary file-like object accepted by ``PdfReader``.
        model_name: Groq model identifier (defaults to Mixtral 8x7B, matching
            the previous hard-coded value).

    Returns:
        A LangChain runnable: ``chain.invoke(question)`` yields a string answer.
    """
    llm = ChatGroq(
        api_key=os.environ['GROQ_API_KEY'],
        model_name=model_name
    )
    # extract_text() returns None for pages with no extractable text
    # (e.g. scanned/image-only pages); coalesce to "" so join() cannot
    # raise TypeError.
    text = "".join(page.extract_text() or "" for page in PdfReader(pdf).pages)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=1000)
    chunks = text_splitter.split_text(text)
    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    vectorstore = FAISS.from_texts(chunks, embedding=embeddings)
    retriever = vectorstore.as_retriever()
    rag_template = """Answer the question based only on the following context:
    {context}
    Question: {question}
    """
    rag_prompt = ChatPromptTemplate.from_template(rag_template)
    # Retriever fills {context}; the raw question passes through to {question}.
    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | rag_prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain

if __name__ == "__main__":
create_vector_db()
chain = get_qa_chain()
Expand Down

0 comments on commit fcef0ad

Please sign in to comment.