Commit

minor changes
star-nox committed Dec 13, 2023
1 parent 5fae565 commit a0b7ce1
Showing 2 changed files with 47 additions and 20 deletions.
2 changes: 1 addition & 1 deletion ai_ta_backend/filtering_contexts.py
@@ -69,7 +69,7 @@ def batch_context_filtering(batch_docs, user_query, max_time_before_return=45, m
"""

start_time = time.monotonic()

partial_func = partial(list_context_filtering, user_query=user_query, max_time_before_return=max_time_before_return, max_concurrency=max_concurrency)
with ProcessPoolExecutor(max_workers=5) as executor:
processed_docs = list(executor.map(partial_func, batch_docs))
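The pattern in this hunk, freezing the shared keyword arguments with functools.partial so that executor.map only varies the document batch, can be reproduced in isolation. A minimal runnable sketch, where filter_one is a hypothetical stand-in for the repository's list_context_filtering:

from concurrent.futures import ProcessPoolExecutor
from functools import partial

def filter_one(doc: str, user_query: str, max_time_before_return: int) -> str:
  # Hypothetical stand-in: keep a doc only if it mentions the query.
  return doc if user_query.lower() in doc.lower() else ""

if __name__ == "__main__":  # guard required for multiprocessing on some platforms
  batch_docs = ["Qdrant basics", "course logistics", "filtering with Qdrant"]
  # Bind the shared keyword arguments once; map only varies the document.
  partial_func = partial(filter_one, user_query="qdrant", max_time_before_return=45)
  with ProcessPoolExecutor(max_workers=5) as executor:
    processed_docs = list(executor.map(partial_func, batch_docs))
  print(processed_docs)  # ['Qdrant basics', '', 'filtering with Qdrant']

Because worker processes pickle their tasks, both the function and the partial must be built from module-level callables, as they are in the committed code.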
65 changes: 46 additions & 19 deletions ai_ta_backend/vector_database.py
@@ -12,6 +12,9 @@
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from collections import OrderedDict

import ray
import boto3
@@ -999,6 +1002,9 @@ def vector_search(self, search_query, course_name):
  return found_docs

def batch_vector_search(self, search_queries: List[str], course_name: str, top_n: int=50):
+  """
+  Perform a similarity search for all the generated queries at once.
+  """
  from qdrant_client.http import models as rest
  o = OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE)
  # Prepare the filter for the course name
@@ -1023,50 +1029,71 @@ def batch_vector_search(self, search_queries: List[str], course_name: str, top_n
    collection_name=os.environ['QDRANT_COLLECTION_NAME'],
    requests=search_requests
  )
-  # Process the search results
+  # process search results
  found_docs: list[list[Document]] = []
  for result in search_results:
-    docs = []
-    for d in result:
-      try:
-        metadata = d.payload
-        page_content = metadata['page_content']
-        del metadata['page_content']
-        if "pagenumber" not in metadata.keys() and "pagenumber_or_timestamp" in metadata.keys():  # type: ignore
-          # aiding in the database migration...
-          metadata["pagenumber"] = metadata["pagenumber_or_timestamp"]  # type: ignore
-        docs.append(Document(page_content=page_content, metadata=metadata))  # type: ignore
-      except Exception as e:
-        # print(f"Error in batch_vector_search(), for course: `{course_name}`. Error: {e}")
-        print(traceback.print_exc())
-    found_docs.append(docs)
+    docs = []
+    for doc in result:
+      try:
+        metadata = doc.payload
+        page_content = metadata['page_content']
+        del metadata['page_content']
+
+        if "pagenumber" not in metadata.keys() and "pagenumber_or_timestamp" in metadata.keys():
+          metadata["pagenumber"] = metadata["pagenumber_or_timestamp"]
+
+        docs.append(Document(page_content=page_content, metadata=metadata))
+      except Exception as e:
+        print(traceback.print_exc())
+    found_docs.append(docs)
  return found_docs
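For reference, the Qdrant call underneath this method takes one SearchRequest per query and answers them all in a single round trip. A minimal sketch, assuming a local Qdrant server and a hypothetical collection named courses with 1536-dimensional vectors; the query vectors below are placeholders for real embeddings:

from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

client = QdrantClient(url="http://localhost:6333")

# Placeholders for real embeddings (e.g. from OpenAIEmbeddings).
query_vectors = [[0.1] * 1536, [0.2] * 1536]

course_filter = rest.Filter(must=[
  rest.FieldCondition(key="course_name", match=rest.MatchValue(value="CS-101"))
])

# One request per generated query, all answered in a single round trip.
requests = [
  rest.SearchRequest(vector=v, filter=course_filter, limit=50, with_payload=True)
  for v in query_vectors
]
results = client.search_batch(collection_name="courses", requests=requests)

for hits in results:  # one list of scored points per request, in request order
  print([(hit.id, round(hit.score, 3)) for hit in hits])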


+  # Process the search results - remove duplicates and update pagenumber in metadata
+  # unique_nested_list: list[list[Document]] = []
+  # for sublist in search_results:
+  #   unique_sublist = OrderedDict()
+  #   for doc in sublist:
+  #     metadata = doc.payload
+  #     page_content = metadata['page_content']
+  #     del metadata['page_content']
+  #     if "pagenumber" not in metadata.keys() and "pagenumber_or_timestamp" in metadata.keys():
+  #       metadata["pagenumber"] = metadata["pagenumber_or_timestamp"]
+  #     doc_key = (metadata.get('readable_filename'), metadata.get('pagenumber'))
+  #     processed_doc = Document(page_content=page_content, metadata=metadata)
+  #     unique_sublist[doc_key] = processed_doc
+
+  #   unique_nested_list.append(list(unique_sublist.values()))
+
+  # return unique_nested_list
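The commented-out block above sketches the planned deduplication: an OrderedDict keyed on (readable_filename, pagenumber) keeps each page once while preserving insertion order, which is why the OrderedDict import was added. The same idea in isolation, a hedged sketch with plain dicts standing in for Document objects:

from collections import OrderedDict

def dedupe_by_page(docs: list[dict]) -> list[dict]:
  """Collapse documents sharing a (filename, pagenumber) key, keeping order."""
  unique: OrderedDict[tuple, dict] = OrderedDict()
  for doc in docs:
    key = (doc.get("readable_filename"), doc.get("pagenumber"))
    # As in the commented block: a repeated key overwrites the stored value
    # but keeps the slot's original position in the ordering.
    unique[key] = doc
  return list(unique.values())

docs = [
  {"readable_filename": "notes.pdf", "pagenumber": 3, "chunk": "a"},
  {"readable_filename": "slides.pdf", "pagenumber": 1, "chunk": "b"},
  {"readable_filename": "notes.pdf", "pagenumber": 3, "chunk": "c"},  # duplicate key
]
print(dedupe_by_page(docs))  # two entries; the duplicate collapsed onto chunk "c"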


def reciprocal_rank_fusion(self, results: list[list], k=60):
  """
  Since we have multiple queries, and n documents are returned per query, we need to go through all the results
  and collect the documents with the highest overall fused score, based on their rank in Qdrant's similarity ordering.
  """
  fused_scores = {}
+  count = 0
+  unique_count = 0
  for docs in results:
    # Assumes the docs are returned in sorted order of relevance
+    count += len(docs)
    for rank, doc in enumerate(docs):
      doc_str = dumps(doc)
      if doc_str not in fused_scores:
        fused_scores[doc_str] = 0
+        unique_count += 1
      previous_score = fused_scores[doc_str]
      fused_scores[doc_str] += 1 / (rank + k)
      # Uncomment for debugging
      # print(f"Change score for doc: {doc_str}, previous score: {previous_score}, updated score: {fused_scores[doc_str]} ")
+  print(f"Total number of documents in rank fusion: {count}")
+  print(f"Total number of unique documents in rank fusion: {unique_count}")
  reranked_results = [
    (loads(doc), score)
    for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
  ]

  return reranked_results
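The fused score is a sum of 1 / (rank + k) over every ranked list in which a document appears, so a document ranked well by several queries floats to the top even if no single query ranked it first. A self-contained toy run of the same formula, with strings standing in for serialized documents:

def rrf(results: list[list[str]], k: int = 60) -> list[tuple[str, float]]:
  fused: dict[str, float] = {}
  for docs in results:  # one ranked list per query
    for rank, doc in enumerate(docs):
      fused[doc] = fused.get(doc, 0.0) + 1.0 / (rank + k)
  return sorted(fused.items(), key=lambda x: x[1], reverse=True)

lists = [
  ["A", "B", "C"],  # query 1
  ["B", "A", "D"],  # query 2
  ["A", "D", "C"],  # query 3
]
for doc, score in rrf(lists):
  print(doc, round(score, 5))
# A: 1/60 + 1/61 + 1/60 ≈ 0.04973 beats B: 1/61 + 1/60 ≈ 0.03306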

def getTopContexts(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]:
@@ -1157,7 +1184,7 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit
print(f"Number of individual docs after context filtering: {filtered_count}")

# if filtered docs are between 0 to 5 (very less), use the pre-filter batch_found_docs
if 0 < filtered_count < 5:
if 0 <= filtered_count <= 5:
found_docs = self.reciprocal_rank_fusion(batch_found_docs)
else:
found_docs = self.reciprocal_rank_fusion(filtered_docs)
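The inclusive bounds are the substance of this change: under the old strict check, filtered_count == 0 skipped the fallback and rank fusion ran on an empty list. A hypothetical helper isolating the new behavior:

def pick_fusion_input(filtered_docs: list, batch_found_docs: list, threshold: int = 5) -> list:
  """Fall back to pre-filter docs when filtering leaves too little behind."""
  count = len(filtered_docs)
  # The lower bound is vacuous for a length, but mirrors the committed check.
  return batch_found_docs if 0 <= count <= threshold else filtered_docs

assert pick_fusion_input([], ["a", "b"]) == ["a", "b"]   # zero survivors: fall back
assert pick_fusion_input(["x"] * 9, ["a"]) == ["x"] * 9  # plenty survive: keep them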
