Skip to content

Commit

Permalink
ADDED: Authentication on get_data route and token generation
Browse files Browse the repository at this point in the history
  • Loading branch information
AquibPy committed May 13, 2024
1 parent 20c818f commit cdad173
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 39 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci-cd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ jobs:
echo "MONGO_PASSWORD=${{ secrets.MONGO_PASSWORD }}" >> $GITHUB_ENV
echo "MONGO_DBNAME=${{ secrets.MONGO_DBNAME }}" >> $GITHUB_ENV
echo "MONGO_COLLECTION=${{ secrets.MONGO_COLLECTION }}" >> $GITHUB_ENV
echo "MONGO_COLLECTION_USER=${{ secrets.MONGO_COLLECTION_USER }}" >> $GITHUB_ENV
echo "LANGCHAIN_API_KEY=${{ secrets.LANGCHAIN_API_KEY }}" >> $GITHUB_ENV
echo "REDIS_HOST=${{ secrets.REDIS_HOST }}" >> $GITHUB_ENV
echo "REDIS_PASSWORD=${{ secrets.REDIS_PASSWORD }}" >> $GITHUB_ENV
echo "TOKEN_SECRET_KEY=${{ secrets.TOKEN_SECRET_KEY }}" >> $GITHUB_ENV
- name: Install dependencies
run: |
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ __pycache__/
venv/
.env
.pytest_cache/
.vscode
.vscode
hit.py
85 changes: 66 additions & 19 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
from PIL import Image
from redis import Redis
import json
from fastapi import FastAPI,Form,File,UploadFile, Request ,Response
from fastapi import FastAPI,Form,File,UploadFile, Request ,Response, HTTPException, status, Depends
from fastapi.templating import Jinja2Templates
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse,RedirectResponse,StreamingResponse
from typing import List,Optional
from pydantic import BaseModel
import google.generativeai as genai
from fastapi.middleware.cors import CORSMiddleware
from settings import invoice_prompt,youtube_transcribe_prompt,text2sql_prompt,EMPLOYEE_DB,GEMINI_PRO,GEMINI_PRO_1_5, diffusion_models, REDIS_PORT
from mongo import MongoDB
from helper_functions import get_qa_chain,get_gemini_response,get_url_doc_qa,extract_transcript_details,\
get_gemini_response_health,get_gemini_pdf,read_sql_query,remove_substrings,questions_generator,groq_pdf,\
Expand All @@ -22,13 +20,24 @@
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain_core.prompts import ChatPromptTemplate
from auth import create_access_token
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from datetime import timedelta
from jose import jwt, JWTError
import settings
from models import UserCreate, ResponseText


os.environ["LANGCHAIN_TRACING_V2"]="true"
os.environ["LANGCHAIN_API_KEY"]=os.getenv("LANGCHAIN_API_KEY")
os.environ["LANGCHAIN_PROJECT"]="genify"
os.environ["LANGCHAIN_ENDPOINT"]="https://api.smith.langchain.com"

redis = Redis(host=os.getenv("REDIS_HOST"), port=REDIS_PORT, password=os.getenv("REDIS_PASSWORD"))
redis = Redis(host=os.getenv("REDIS_HOST"), port=settings.REDIS_PORT, password=os.getenv("REDIS_PASSWORD"))

mongo_client = MongoDB(collection_name=os.getenv("MONGO_COLLECTION_USER"))
users_collection = mongo_client.collection
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

app = FastAPI(title="Genify By Mohd Aquib",
summary="This API contains routes of different Gen AI usecases")
Expand All @@ -45,14 +54,42 @@
allow_headers=["*"],
)

class ResponseText(BaseModel):
response: str


@app.get("/", response_class=RedirectResponse)
async def home():
    """Redirect the bare root URL to the interactive API docs."""
    docs_url = "/docs"
    return RedirectResponse(docs_url)

@app.post("/signup")
async def signup(user: UserCreate):
    """Register a new user.

    Rejects the request with HTTP 400 when the email already exists in the
    users collection; otherwise inserts the user document and returns a
    confirmation message.
    """
    # Check if user already exists
    existing_user = users_collection.find_one({"email": user.email})
    if existing_user:
        raise HTTPException(status_code=400, detail="Email already registered")

    # Insert new user to database
    # NOTE(review): the password is persisted in PLAINTEXT (model_dump keeps
    # it as typed). It should be hashed (e.g. bcrypt/passlib) before insert;
    # the /token route's equality check would need the matching change.
    user_dict = user.model_dump()
    users_collection.insert_one(user_dict)

    return {"message": "User created successfully"}

# Signin route
@app.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Authenticate a user and issue a JWT bearer access token.

    The OAuth2 form's ``username`` field carries the email address.
    Responds 401 when the email is unknown or the password does not match.
    """
    # Check if user exists in database
    # NOTE(review): passwords are compared in PLAINTEXT with `!=`, which
    # assumes plaintext storage (see /signup) and is not constant-time;
    # a password hasher's verify() (or hmac.compare_digest) should be used.
    user = users_collection.find_one({"email": form_data.username})
    if not user or user["password"] != form_data.password:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Create access token
    access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user["email"]}, expires_delta=access_token_expires
    )

    return {"access_token": access_token, "token_type": "bearer"}

@app.get("/chatbot",description="Provides a simple web interface to interact with the chatbot")
async def chat(request: Request):
Expand All @@ -73,7 +110,7 @@ async def gemini(image_file: UploadFile = File(...), prompt: str = Form(...)):
"mime_type": "image/jpeg",
"data": image
}]
output = get_gemini_response(invoice_prompt, image_parts, prompt)
output = get_gemini_response(settings.invoice_prompt, image_parts, prompt)
db = MongoDB()
payload = {
"endpoint" : "/invoice_extractor",
Expand Down Expand Up @@ -155,9 +192,9 @@ async def youtube_video_transcribe_summarizer_gemini(url: str = Form(...)):
print("Retrieving response from Redis cache")
return ResponseText(response=cached_response.decode("utf-8"))

model = genai.GenerativeModel(GEMINI_PRO)
model = genai.GenerativeModel(settings.GEMINI_PRO)
transcript_text = extract_transcript_details(url)
response = model.generate_content(youtube_transcribe_prompt + transcript_text)
response = model.generate_content(settings.youtube_transcribe_prompt + transcript_text)
redis.set(cache_key, response.text, ex=60)
db = MongoDB()
payload = {
Expand Down Expand Up @@ -213,7 +250,7 @@ async def blogs(topic: str = Form("Generative AI")):
print("Retrieving response from Redis cache")
return ResponseText(response=cached_response.decode("utf-8"))

model = genai.GenerativeModel(GEMINI_PRO_1_5)
model = genai.GenerativeModel(settings.GEMINI_PRO_1_5)
blog_prompt = f""" You are expert in blog writing. Write a blog on the topic {topic}. Use a friendly and informative tone, and include examples and tips to encourage readers to get started with the topic provided. """
response = model.generate_content(blog_prompt)
redis.set(cache_key, response.text, ex=60)
Expand Down Expand Up @@ -260,11 +297,11 @@ async def sql_query(prompt: str = Form("Tell me the employees living in city Noi
cached_data = json.loads(cached_response)
return cached_data

model = genai.GenerativeModel(GEMINI_PRO_1_5)
response = model.generate_content([text2sql_prompt, prompt])
model = genai.GenerativeModel(settings.GEMINI_PRO_1_5)
response = model.generate_content([settings.text2sql_prompt, prompt])
output_query = remove_substrings(response.text)
print(output_query)
output = read_sql_query(remove_substrings(output_query), EMPLOYEE_DB)
output = read_sql_query(remove_substrings(output_query), settings.EMPLOYEE_DB)
cached_data = {"response": {"SQL Query": output_query, "Data": output}}
redis.set(cache_key, json.dumps(cached_data), ex=60)
db = MongoDB()
Expand Down Expand Up @@ -422,7 +459,7 @@ async def ats(resume_pdf: UploadFile = File(...), job_description: str = Form(..
return ResponseText(response=cached_response.decode("utf-8"))

text = extraxt_pdf_text(resume_pdf.file)
model = genai.GenerativeModel(GEMINI_PRO_1_5)
model = genai.GenerativeModel(settings.GEMINI_PRO_1_5)
ats_prompt = f"""
Hey Act Like a skilled or very experienced ATS (Application Tracking System)
with a deep understanding of the tech field, software engineering, data science, data analysis,
Expand Down Expand Up @@ -472,11 +509,11 @@ async def ats(resume_pdf: UploadFile = File(...), job_description: str = Form(..
""")
def generate_image(prompt: str = Form("Astronaut riding a horse"), model: str = Form("Stable_Diffusion_base")):
try:
if model in diffusion_models:
if model in settings.diffusion_models:
def query(payload):
api_key = os.getenv("HUGGINGFACE_API_KEY")
headers = {"Authorization": f"Bearer {api_key}"}
response = requests.post(diffusion_models[model], headers=headers, json=payload)
response = requests.post(settings.diffusion_models[model], headers=headers, json=payload)
return response.content

image_bytes = query({"inputs": prompt})
Expand All @@ -494,7 +531,17 @@ def query(payload):
return ResponseText(response="Busy server: Please try later")

@app.get("/get_data/{endpoint_name}")
async def get_data(endpoint_name: str):
async def get_data(endpoint_name: str, token: str = Depends(oauth2_scheme)):
try:
payload = jwt.decode(token, os.getenv("TOKEN_SECRET_KEY"), algorithms=[settings.ALGORITHM])
email = payload.get("sub")
if email is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
user = users_collection.find_one({"email": email})
if user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
except JWTError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
cache_key = f"{endpoint_name}"
cached_data = redis.get(cache_key)

Expand Down
16 changes: 16 additions & 0 deletions auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os
from datetime import datetime, timedelta, timezone
from typing import Optional

from jose import jwt

import settings

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token (e.g. ``{"sub": email}``).
        expires_delta: Optional token lifetime; defaults to 15 minutes.

    Returns:
        The encoded JWT string, signed with the TOKEN_SECRET_KEY env var
        using the algorithm configured in settings.
    """
    to_encode = data.copy()
    # Bug fix: was naive local-time datetime.now(). jose serialises `exp`
    # as a POSIX timestamp via utctimetuple(), so a naive local datetime on
    # a non-UTC server shifts the expiry by the UTC offset. Use aware UTC.
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, os.getenv("TOKEN_SECRET_KEY"), algorithm=settings.ALGORITHM)
    return encoded_jwt
29 changes: 14 additions & 15 deletions helper_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
from settings import GOOGLE_EMBEDDING,FAQ_FILE,INSTRUCTOR_EMBEDDING,VECTORDB_PATH,qa_prompt,\
prompt_pdf,question_prompt_template,question_refine_template, GEMINI_PRO
import settings
from langchain_google_genai import GoogleGenerativeAI,GoogleGenerativeAIEmbeddings,ChatGoogleGenerativeAI
from langchain_community.document_loaders import CSVLoader
from langchain_community.document_loaders import UnstructuredURLLoader,PyPDFLoader,WebBaseLoader
Expand Down Expand Up @@ -32,15 +31,15 @@

PaLM_embeddings = GooglePalmEmbeddings(google_api_key=os.getenv("GOOGLE_API_KEY"))

google_embedding = GoogleGenerativeAIEmbeddings(model = GOOGLE_EMBEDDING)
google_embedding = GoogleGenerativeAIEmbeddings(model = settings.GOOGLE_EMBEDDING)

'''
if you want you can try instructor embeddings also. Below is the code :
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
embeddings = HuggingFaceInferenceAPIEmbeddings(
api_key=os.getenv("HUGGINGFACE_API_KEY"), model_name=INSTRUCTOR_EMBEDDING,query_instruction="Represent the query for retrieval: "
api_key=os.getenv("HUGGINGFACE_API_KEY"), model_name=settings.INSTRUCTOR_EMBEDDING,query_instruction="Represent the query for retrieval: "
)
'''

Expand All @@ -64,17 +63,17 @@ def get_gemini_response_health(image_file, prompt):
return f"Error: {str(e)}"

def create_vector_db():
loader = CSVLoader(file_path=FAQ_FILE)
loader = CSVLoader(file_path=settings.FAQ_FILE)
data = loader.load()
vectordb = FAISS.from_documents(documents = data,embedding=PaLM_embeddings)
vectordb.save_local(VECTORDB_PATH)
vectordb.save_local(settings.VECTORDB_PATH)

def get_qa_chain():
llm = GoogleGenerativeAI(model= GEMINI_PRO, google_api_key=os.getenv("GOOGLE_API_KEY"),temperature=0.7)
vectordb = FAISS.load_local(VECTORDB_PATH,PaLM_embeddings,allow_dangerous_deserialization=True)
llm = GoogleGenerativeAI(model= settings.GEMINI_PRO, google_api_key=os.getenv("GOOGLE_API_KEY"),temperature=0.7)
vectordb = FAISS.load_local(settings.VECTORDB_PATH,PaLM_embeddings,allow_dangerous_deserialization=True)
retriever = vectordb.as_retriever(score_threshold=0.7)
PROMPT = PromptTemplate(
template=qa_prompt, input_variables=["context", "question"]
template=settings.qa_prompt, input_variables=["context", "question"]
)

chain = RetrievalQA.from_chain_type(llm=llm,
Expand All @@ -87,7 +86,7 @@ def get_qa_chain():
return chain

def get_url_doc_qa(url,doc):
llm = GoogleGenerativeAI(model= GEMINI_PRO, google_api_key=os.getenv("GOOGLE_API_KEY"),temperature=0.3)
llm = GoogleGenerativeAI(model= settings.GEMINI_PRO, google_api_key=os.getenv("GOOGLE_API_KEY"),temperature=0.3)
if url:
loader = WebBaseLoader(url)
data = loader.load()
Expand Down Expand Up @@ -129,10 +128,10 @@ def get_gemini_pdf(pdf):
text_splitter = RecursiveCharacterTextSplitter(chunk_size=10000, chunk_overlap=1000)
chunks = text_splitter.split_text(text)
vector_store = FAISS.from_texts(chunks, embedding=google_embedding)
llm = GoogleGenerativeAI(model= GEMINI_PRO, google_api_key=os.getenv("GOOGLE_API_KEY"),temperature=0.7)
llm = GoogleGenerativeAI(model= settings.GEMINI_PRO, google_api_key=os.getenv("GOOGLE_API_KEY"),temperature=0.7)
retriever = vector_store.as_retriever(score_threshold=0.7)
PROMPT = PromptTemplate(
template=prompt_pdf, input_variables=["context", "question"]
template=settings.prompt_pdf, input_variables=["context", "question"]
)

chain = RetrievalQA.from_chain_type(llm=llm,
Expand Down Expand Up @@ -182,9 +181,9 @@ def questions_generator(doc):
# splitter_ans_gen = TokenTextSplitter(chunk_size = 1000,chunk_overlap = 100)
# document_answer_gen = splitter_ans_gen.split_documents(document_ques_gen)

llm_ques_gen_pipeline = ChatGoogleGenerativeAI(model= GEMINI_PRO,google_api_key=os.getenv("GOOGLE_API_KEY"),temperature=0.3)
PROMPT_QUESTIONS = PromptTemplate(template=question_prompt_template, input_variables=["text"])
REFINE_PROMPT_QUESTIONS = PromptTemplate(input_variables=["existing_answer", "text"],template=question_refine_template)
llm_ques_gen_pipeline = ChatGoogleGenerativeAI(model= settings.GEMINI_PRO,google_api_key=os.getenv("GOOGLE_API_KEY"),temperature=0.3)
PROMPT_QUESTIONS = PromptTemplate(template=settings.question_prompt_template, input_variables=["text"])
REFINE_PROMPT_QUESTIONS = PromptTemplate(input_variables=["existing_answer", "text"],template=settings.question_refine_template)
ques_gen_chain = load_summarize_chain(llm = llm_ques_gen_pipeline,
chain_type = "refine",
verbose = False,
Expand Down
13 changes: 13 additions & 0 deletions models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pydantic import BaseModel, Field, EmailStr

class ResponseText(BaseModel):
    """Generic API response wrapper carrying a single text payload."""
    # The generated/model output text returned to the client.
    response: str

class UserCreate(BaseModel):
    """Request body for /signup: a validated email plus a password."""
    # EmailStr validates address syntax; Field(...) marks the field required.
    email: EmailStr = Field(...)
    # Minimum length is the only constraint enforced here.
    # NOTE(review): the signup route stores this value unhashed — hash
    # before persisting.
    password: str = Field(..., min_length=8)

# User model for database
class UserInDB(BaseModel):
    """Shape of a user document as stored in the MongoDB users collection."""
    email: EmailStr
    # NOTE(review): held as plaintext by the current signup flow; this
    # field should contain a password hash instead.
    password: str
6 changes: 3 additions & 3 deletions mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import os

class MongoDB():
def __init__(self):
def __init__(self, dbname=None, collection_name=None):
self.username = os.getenv("MONGO_USERNAME")
self.password = os.getenv("MONGO_PASSWORD")
self.dbname = os.getenv("MONGO_DBNAME")
self.collection_name = os.getenv("MONGO_COLLECTION")
self.dbname = os.getenv("MONGO_DBNAME") if dbname is None else dbname
self.collection_name = os.getenv("MONGO_COLLECTION") if collection_name is None else collection_name
try:
self.client = MongoClient(f"mongodb+srv://{self.username}:{self.password}@cluster0.sdx7i86.mongodb.net/{self.dbname}") ## i am using mongodb atlas. Use can use your local mongodb
database = self.dbname
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ groq
langchain-groq
jinja2
tiktoken
redis
redis
python-jose
3 changes: 3 additions & 0 deletions settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
INSTRUCTOR_EMBEDDING = "sentence-transformers/all-MiniLM-l6-v2"
VECTORDB_PATH = "faiss_index"
REDIS_PORT = 19061
# SECRET_KEY = "dOZfxDmHFSmrRlTTNcW0IlsfCkxEJ7-8x4xYFs_WQnE"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
qa_prompt = """
Given the following context and a question, generate an answer based on this context only.
In the answer try to provide as much text as possible from "response" section in the source document context without making much changes.
Expand Down

0 comments on commit cdad173

Please sign in to comment.