diff --git a/ai_ta_backend/filtering_contexts.py b/ai_ta_backend/filtering_contexts.py index cfc8c95b..c70ff29d 100644 --- a/ai_ta_backend/filtering_contexts.py +++ b/ai_ta_backend/filtering_contexts.py @@ -69,7 +69,7 @@ def batch_context_filtering(batch_docs, user_query, max_time_before_return=45, m """ start_time = time.monotonic() - + partial_func = partial(list_context_filtering, user_query=user_query, max_time_before_return=max_time_before_return, max_concurrency=max_concurrency) with ProcessPoolExecutor(max_workers=5) as executor: processed_docs = list(executor.map(partial_func, batch_docs)) diff --git a/ai_ta_backend/vector_database.py b/ai_ta_backend/vector_database.py index c87c43f6..ce7c89d7 100644 --- a/ai_ta_backend/vector_database.py +++ b/ai_ta_backend/vector_database.py @@ -12,6 +12,9 @@ from pathlib import Path from tempfile import NamedTemporaryFile from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from collections import OrderedDict + + import ray import boto3 @@ -999,6 +1002,9 @@ def vector_search(self, search_query, course_name): return found_docs def batch_vector_search(self, search_queries: List[str], course_name: str, top_n: int=50): + """ + Perform a similarity search for all the generated queries at once. + """ from qdrant_client.http import models as rest o = OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE) # Prepare the filter for the course name @@ -1023,26 +1029,45 @@ def batch_vector_search(self, search_queries: List[str], course_name: str, top_n collection_name=os.environ['QDRANT_COLLECTION_NAME'], requests=search_requests ) - # Process the search results + # process search results found_docs: list[list[Document]] = [] for result in search_results: - docs = [] - for d in result: - try: - metadata = d.payload - page_content = metadata['page_content'] - del metadata['page_content'] - if "pagenumber" not in metadata.keys() and "pagenumber_or_timestamp" in metadata.keys(): # type: ignore - # aiding in the database migration... - metadata["pagenumber"] = metadata["pagenumber_or_timestamp"] # type: ignore - docs.append(Document(page_content=page_content, metadata=metadata)) # type: ignore - except Exception as e: - # print(f"Error in batch_vector_search(), for course: `{course_name}`. Error: {e}") - print(traceback.print_exc()) - found_docs.append(docs) + docs = [] + for doc in result: + try: + metadata = doc.payload + page_content = metadata['page_content'] + del metadata['page_content'] + + if "pagenumber" not in metadata.keys() and "pagenumber_or_timestamp" in metadata.keys(): + metadata["pagenumber"] = metadata["pagenumber_or_timestamp"] + docs.append(Document(page_content=page_content, metadata=metadata)) + except Exception as e: + print(traceback.print_exc()) + found_docs.append(docs) return found_docs + + # Process the search results - remove duplicates and update pagenumber in metadata + # unique_nested_list: list[list[Document]] = [] + # for sublist in search_results: + # unique_sublist = OrderedDict() + # for doc in sublist: + # metadata = doc.payload + # page_content = metadata['page_content'] + # del metadata['page_content'] + # if "pagenumber" not in metadata.keys() and "pagenumber_or_timestamp" in metadata.keys(): + # metadata["pagenumber"] = metadata["pagenumber_or_timestamp"] + # doc_key = (metadata.get('readable_filename'), metadata.get('pagenumber')) + # processed_doc = Document(page_content=page_content, metadata=metadata) + # unique_sublist[doc_key] = processed_doc + + # unique_nested_list.append(list(unique_sublist.values())) + + # return unique_nested_list + + def reciprocal_rank_fusion(self, results: list[list], k=60): """ Since we have multiple queries, and n documents returned per query, we need to go through all the results @@ -1050,6 +1075,7 @@ def reciprocal_rank_fusion(self, results: list[list], k=60): """ fused_scores = {} count = 0 + unique_count = 0 for docs in results: # Assumes the docs are returned in sorted order of relevance count += len(docs) @@ -1057,16 +1083,17 @@ def reciprocal_rank_fusion(self, results: list[list], k=60): doc_str = dumps(doc) if doc_str not in fused_scores: fused_scores[doc_str] = 0 + unique_count += 1 previous_score = fused_scores[doc_str] fused_scores[doc_str] += 1 / (rank + k) # Uncomment for debugging - # print(f"Change score for doc: {doc_str}, previous score: {previous_score}, updated score: {fused_scores[doc_str]} ") - + #print(f"Change score for doc: {doc_str}, previous score: {previous_score}, updated score: {fused_scores[doc_str]} ") + print(f"Total number of documents in rank fusion: {count}") + print(f"Total number of unique documents in rank fusion: {unique_count}") reranked_results = [ (loads(doc), score) for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True) ] - return reranked_results def getTopContexts(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]: @@ -1157,7 +1184,7 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit print(f"Number of individual docs after context filtering: {filtered_count}") # if filtered docs are between 0 to 5 (very less), use the pre-filter batch_found_docs - if 0 < filtered_count < 5: + if 0 <= filtered_count <= 5: found_docs = self.reciprocal_rank_fusion(batch_found_docs) else: found_docs = self.reciprocal_rank_fusion(filtered_docs)