Commit 9eaf370: minor refactor & cleanup

KastanDay committed Nov 8, 2023
1 parent e6f7236 commit 9eaf370
Showing 1 changed file with 15 additions and 25 deletions.

40 changes: 15 additions & 25 deletions ai_ta_backend/vector_database.py
@@ -19,14 +19,19 @@
import supabase
from bs4 import BeautifulSoup
from git.repo import Repo
from langchain import hub
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.document_loaders import (Docx2txtLoader, GitLoader,
PythonLoader, SRTLoader, TextLoader,
UnstructuredExcelLoader,
UnstructuredPowerPointLoader)
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.document_loaders.image import UnstructuredImageLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.load import dumps, loads
from langchain.schema import Document
from langchain.schema.output_parser import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.qdrant import Qdrant
from pydub import AudioSegment
@@ -36,11 +41,6 @@
from ai_ta_backend.aws import upload_data_files_to_s3
from ai_ta_backend.extreme_context_stuffing import OpenAIAPIProcessor
from ai_ta_backend.utils_tokenization import count_tokens_and_cost
from langchain import hub
from langchain.llms.openai import OpenAI
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.schema.output_parser import StrOutputParser
from langchain.load import dumps, loads

MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation")
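For context: this hub prompt is the front of a small query-generation chain. A minimal sketch of typical usage, not code from this commit (the model choice, temperature, and newline-split output format are assumptions; the "original_query" input key matches the LCEL comment removed further down in this diff):

    from langchain import hub
    from langchain.chat_models import ChatOpenAI
    from langchain.schema.output_parser import StrOutputParser

    prompt = hub.pull("langchain-ai/rag-fusion-query-generation")
    generate_queries = (
        prompt
        | ChatOpenAI(temperature=0)   # model choice is an assumption
        | StrOutputParser()
        | (lambda x: x.split("\n"))   # one generated query per line (assumed format)
    )
    queries = generate_queries.invoke({"original_query": "What is RAG fusion?"})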

@@ -989,7 +989,7 @@ def vector_search(self, search_query, course_name):
print("found_docs", found_docs)
return found_docs

def batch_vector_search(self, search_queries: List[str], course_name: str):
def batch_vector_search(self, search_queries: List[str], course_name: str, top_n: int=20):
from qdrant_client.http import models as rest
o = OpenAIEmbeddings()
# Prepare the filter for the course name
@@ -1006,15 +1006,14 @@ def batch_vector_search(self, search_queries: List[str], course_name: str):
for query in search_queries:
user_query_embedding = o.embed_query(query)
search_requests.append(
rest.SearchRequest(vector=user_query_embedding, filter=myfilter, limit=5, with_payload=True)
rest.SearchRequest(vector=user_query_embedding, filter=myfilter, limit=top_n, with_payload=True)
)

# Perform the batch search
search_results = self.qdrant_client.search_batch(
collection_name=os.environ['QDRANT_COLLECTION_NAME'],
requests=search_requests
)
print(f"Search results: {search_results}")
# Process the search results
found_docs: list[list[Document]] = []
for result in search_results:
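The filter construction and the result-unpacking loop are collapsed in this view. For reference, a Qdrant course-name filter conventionally has the shape below; the payload key "course_name" is inferred from the surrounding code, not shown in this hunk:

    from qdrant_client.http import models as rest

    myfilter = rest.Filter(must=[
        rest.FieldCondition(key="course_name",
                            match=rest.MatchValue(value=course_name)),
    ])

The new top_n parameter (default 20, up from the hard-coded limit of 5) lets callers widen each per-query candidate list before rank fusion.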
@@ -1133,6 +1132,10 @@ def context_padding(self, found_docs, search_query, course_name):
return result_contexts

def reciprocal_rank_fusion(self, results: list[list], k=60):
"""
Since we have multiple queries, and n documents returned per query, we need to go through all the results
and collect the documents with the highest overall score, as scored by qdrant similarity matching.
"""
fused_scores = {}
for docs in results:
# Assumes the docs are returned in sorted order of relevance
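The remainder of the method is collapsed here. Standard reciprocal rank fusion adds 1 / (rank + k) to a document's score for every result list it appears in. A sketch consistent with the dumps/loads imports above and with the (doc, score) tuples unpacked later in this diff, though the actual collapsed body may differ:

    from langchain.load import dumps, loads

    def reciprocal_rank_fusion(results: list[list], k=60):
        fused_scores = {}
        for docs in results:
            # Assumes the docs are returned in sorted order of relevance
            for rank, doc in enumerate(docs):
                doc_str = dumps(doc)  # serialize: Documents aren't hashable dict keys
                fused_scores[doc_str] = fused_scores.get(doc_str, 0) + 1 / (rank + k)
        # Highest fused score first; yields (Document, score) tuples
        reranked = sorted(fused_scores.items(), key=lambda kv: kv[1], reverse=True)
        return [(loads(doc_str), score) for doc_str, score in reranked]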
@@ -1184,26 +1187,13 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int =
batch_found_docs: list[list[Document]] = self.batch_vector_search(search_queries=generated_queries, course_name=course_name)

found_docs = self.reciprocal_rank_fusion(batch_found_docs)

# LCEL implementation (need custom one for batch vector search since it doesn't implement langchain Runnable)
# retriever = self.vectorstore.as_retriever(
# search_kwargs={"k": top_n, "filter": {"course_name": course_name}}
# )
# chain = generate_queries | retriever.map() | self.reciprocal_rank_fusion
# docs = chain.invoke({"original_query": search_query})

# Only for debugging
# print(f"Docs found with multiple queries: {found_docs}")
found_docs = [doc for doc, score in found_docs]
print(f"Number of docs found with multiple queries: {len(found_docs)}")

if len(found_docs) == 0:
return []

# Extract only the Document objects from the tuples to pass them to context padding
found_docs = [doc for doc, score in found_docs]
print(f"Number of docs found with multiple queries: {len(found_docs)}")

# call context padding function here
# 'context padding' // 'parent document retriever'
# TODO maybe only do context padding for top 5 docs? Otherwise it's wasteful imo.
final_docs = self.context_padding(found_docs, search_query, course_name)
print(f"Number of final docs after context padding: {len(final_docs)}")

@@ -1228,13 +1218,13 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int =
break

for v in valid_docs:
print("FINAL VALID DOCS:")
#print("valid doc text: ", v['text'])
print("s3_path: ", v['s3_path'])
print("url: ", v['url'])
print("readable_filename: ", v['readable_filename'])
print("\n")


print(f"Total tokens used: {token_counter} total docs: {len(found_docs)} num docs used: {len(valid_docs)}")
print(f"Course: {course_name} ||| search_query: {search_query}")
print(f"⏰ ^^ Runtime of getTopContexts: {(time.monotonic() - start_time_overall):.2f} seconds")