diff --git a/ai_ta_backend/filtering_contexts.py b/ai_ta_backend/filtering_contexts.py index c70ff29d..9d53e3db 100644 --- a/ai_ta_backend/filtering_contexts.py +++ b/ai_ta_backend/filtering_contexts.py @@ -69,13 +69,13 @@ def batch_context_filtering(batch_docs, user_query, max_time_before_return=45, m """ start_time = time.monotonic() - + partial_func = partial(list_context_filtering, user_query=user_query, max_time_before_return=max_time_before_return, max_concurrency=max_concurrency) with ProcessPoolExecutor(max_workers=5) as executor: processed_docs = list(executor.map(partial_func, batch_docs)) processed_docs = list(processed_docs) - print(f"⏰ Context filtering runtime: {(time.monotonic() - start_time):.2f} seconds") + print(f"⏰ Batch context filtering runtime: {(time.monotonic() - start_time):.2f} seconds") return processed_docs diff --git a/ai_ta_backend/vector_database.py b/ai_ta_backend/vector_database.py index ce7c89d7..a26fad8e 100644 --- a/ai_ta_backend/vector_database.py +++ b/ai_ta_backend/vector_database.py @@ -1148,11 +1148,12 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int = def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]: """ - New info-retrieval pipeline that uses multi-query retrieval + reciprocal rank fusion + context padding. + New info-retrieval pipeline that uses multi-query retrieval + filtering + reciprocal rank fusion + context padding. 1. Generate multiple queries based on the input search query. 2. Retrieve relevant docs for each query. - 3. Rank the docs based on the relevance score. - 4. Pad the top 5 docs with context from the original document. + 3. Filter the relevant docs based on the user query and pass them to the rank fusion step. + 4. Rank the docs based on the relevance score. + 5. Pad the top 5 docs with context from the original document. """ try: top_n = 80 # HARD CODE TO ENSURE WE HIT THE MAX TOKENS @@ -1175,7 +1176,7 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit batch_found_docs: list[list[Document]] = self.batch_vector_search(search_queries=generated_queries, course_name=course_name) - # use the below filtering code for batch context filtering - List[List[Document]] (only use at this point in the pipeline) + # use the below filtering code for batch context filtering - List[List[Document]] (only use between batch search and rank fusion) filtered_docs = batch_context_filtering(batch_docs=batch_found_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100) filtered_count = 0 @@ -1209,11 +1210,11 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit #filtered_docs = list_context_filtering(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100) #filtered_docs = list(run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)) # print(f"Number of docs after context filtering: {len(filtered_docs)}") - # if len(filtered_docs) > 0: - # final_docs_used = filtered_docs - # else: + # if 0 <= len(filtered_docs) <= 5: # final_docs_used = final_docs # print("No docs passed context filtering, using all docs retrieved.") + # else: + # final_docs_used = filtered_docs valid_docs = [] num_tokens = 0 @@ -1232,13 +1233,11 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit break print("Length of valid docs: ", len(valid_docs)) - # Context filtering - #filtered_docs = list(run(contexts=valid_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)) - #print(f"Number of docs after context filtering: {len(filtered_docs)}") print(f"Total tokens used: {token_counter} total docs: {len(found_docs)} num docs used: {len(valid_docs)}") print(f"Course: {course_name} ||| search_query: {search_query}") print(f"⏰ ^^ Runtime of getTopContextsWithMQR: {(time.monotonic() - start_time_overall):.2f} seconds") + if len(valid_docs) == 0: return [] return self.format_for_json_mqr(valid_docs)