From 4542f150e45367b720b427798056688a0d2b009f Mon Sep 17 00:00:00 2001 From: star-nox Date: Tue, 12 Dec 2023 18:41:58 -0600 Subject: [PATCH] moved filtering after MQR and modified th filtering code --- ai_ta_backend/filtering_contexts.py | 68 ++++++++++++++++++++++------- ai_ta_backend/vector_database.py | 52 ++++++++++++---------- 2 files changed, 83 insertions(+), 37 deletions(-) diff --git a/ai_ta_backend/filtering_contexts.py b/ai_ta_backend/filtering_contexts.py index 4508546f..a1eaf98c 100644 --- a/ai_ta_backend/filtering_contexts.py +++ b/ai_ta_backend/filtering_contexts.py @@ -19,37 +19,37 @@ from functools import partial from multiprocessing import Pool, Manager - from ai_ta_backend.utils_tokenization import count_tokens_and_cost load_dotenv(override=True) +LANGSMITH_PROMPT_OBJ = hub.pull("kastanday/filter-unrelated-contexts-zephyr") ## Local LLMs USAGE DOCS: https://kastanday.notion.site/LLM-Serving-on-prem-OpenAI-Clone-bb06028266d842b0872465f552684177 ## -def run_context_filtering(contexts, user_query, max_time_before_return=45, max_concurrency=100): +def list_context_filtering(contexts, user_query, max_time_before_return=45, max_concurrency=100): """ - Main function to run context filtering in parallel. + Main function for filtering contexts. Use this when dealing with a List[Dicts]. To be called after context_padding + in getTopContextsWithMQR(). It is also used with batch_context_filtering. + This function multi-processes a list of contexts. + + Args: contexts (list of dicts), user_query (str), max_time_before_return (int), max_concurrency (int) + Returns: filtered_contexts (list of dicts) """ - print("inside main context filtering") - start_time = time.monotonic() - langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr") + start_time = time.monotonic() # call filter contexts function with Manager() as manager: filtered_contexts = manager.list() - partial_func1 = partial(filter_context, user_query=user_query, langsmith_prompt_obj=langsmith_prompt_obj) + partial_func1 = partial(anyscale_completion, user_query=user_query, langsmith_prompt_obj=LANGSMITH_PROMPT_OBJ) partial_func2 = partial(select_context, result=filtered_contexts) with ProcessPoolExecutor(max_workers=30) as executor: - print("max workers: ", executor._max_workers) anyscale_responses = list(executor.map(partial_func1, contexts)) - print("len of anyscale responses: ", len(anyscale_responses)) if len(anyscale_responses) > 0: executor.map(partial_func2, anyscale_responses) else: print("LLM responses are empty.") - executor.shutdown() filtered_contexts = list(filtered_contexts) @@ -58,14 +58,39 @@ def run_context_filtering(contexts, user_query, max_time_before_return=45, max_c print("len of filtered contexts: ", len(filtered_contexts)) return filtered_contexts +def batch_context_filtering(batch_docs, user_query, max_time_before_return=45, max_concurrency=100): + """ + Main function for filtering contexts. Use this when dealing with List[List[Docs]]. To be called between + batch_vector_search() and reciprocal_ranking(). + This function multi-processes a list of list of contexts. + Args: batch_docs (list of list of docs), user_query (str), max_time_before_return (int), max_concurrency (int) + Returns: filtered_contexts (list of list of docs) + """ -def filter_context(context, user_query, langsmith_prompt_obj): + start_time = time.monotonic() + + partial_func = partial(list_context_filtering, user_query=user_query, max_time_before_return=max_time_before_return, max_concurrency=max_concurrency) + with ProcessPoolExecutor(max_workers=5) as executor: + processed_docs = list(executor.map(partial_func, batch_docs)) + + processed_docs = list(processed_docs) + print(f"⏰ Context filtering runtime: {(time.monotonic() - start_time):.2f} seconds") + + return processed_docs + + +def anyscale_completion(context, user_query, langsmith_prompt_obj): + """ + Runs the Anyscale completion API call. + Args: context (dict), user_query (str), langsmith_prompt_obj (PromptTemplate) + Returns: completion_object (dict) + """ api_start_time = time.monotonic() - - final_prompt = str(langsmith_prompt_obj.format(context=context['text'], user_query=user_query)) + # use first final_prompt when using batch_context_filtering as start point and second when using list_context_filtering as start point + final_prompt = str(langsmith_prompt_obj.format(context=context.page_content, user_query=user_query)) + #final_prompt = str(langsmith_prompt_obj.format(context=context['text'], user_query=user_query)) try: - #completion = run_anyscale(final_prompt) ret = openai.ChatCompletion.create( api_base = "https://api.endpoints.anyscale.com/v1", api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"], @@ -83,10 +108,19 @@ def filter_context(context, user_query, langsmith_prompt_obj): print(f"Error: {e}") def select_context(completion_object, result): + """ + Uses parse_result() to determine if the context should be passed to the frontend. + Args: completion_object (dict), result (list of dicts) + Returns: None + """ if parse_result(completion_object['completion']): result.append(completion_object['context']) def parse_result(result): + """ + Parses the result of the LLM completion API call. + Args: result (str) -- the completion part of Anyscale response + """ lines = result.split('\n') for line in lines: if 'Final answer' in line: @@ -94,7 +128,7 @@ def parse_result(result): return False -#----------------------- OLD CODE BELOW ----------------------------------------------------------------------------# +#----------------------- RAY CODE BELOW ----------------------------------------------------------------------------# # @ray.remote # class AsyncActor: @@ -176,6 +210,10 @@ def parse_result(result): # return False # def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=None, max_concurrency=100): +# +# # Main function for filtering contexts using RAY. Use this when dealing with a list of contexts. +# +# # langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr") # print("Num jobs to run:", len(contexts)) diff --git a/ai_ta_backend/vector_database.py b/ai_ta_backend/vector_database.py index 8c101007..c87c43f6 100644 --- a/ai_ta_backend/vector_database.py +++ b/ai_ta_backend/vector_database.py @@ -43,8 +43,9 @@ from ai_ta_backend.extreme_context_stuffing import OpenAIAPIProcessor from ai_ta_backend.utils_tokenization import count_tokens_and_cost from ai_ta_backend.parallel_context_processing import context_processing -#from ai_ta_backend.filtering_contexts import run -from ai_ta_backend.filtering_contexts import run_context_filtering +#from ai_ta_backend.filtering_contexts import ray_context_filtering +#from ai_ta_backend.filtering_contexts import run_context_filtering +from ai_ta_backend.filtering_contexts import batch_context_filtering MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation") OPENAI_API_TYPE = "azure" # "openai" or "azure" @@ -1048,8 +1049,10 @@ def reciprocal_rank_fusion(self, results: list[list], k=60): and collect the documents with the highest overall score, as scored by qdrant similarity matching. """ fused_scores = {} + count = 0 for docs in results: # Assumes the docs are returned in sorted order of relevance + count += len(docs) for rank, doc in enumerate(docs): doc_str = dumps(doc) if doc_str not in fused_scores: @@ -1063,6 +1066,7 @@ def reciprocal_rank_fusion(self, results: list[list], k=60): (loads(doc), score) for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True) ] + return reranked_results def getTopContexts(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]: @@ -1144,23 +1148,26 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit batch_found_docs: list[list[Document]] = self.batch_vector_search(search_queries=generated_queries, course_name=course_name) - # filtered_docs = run_context_filtering(contexts=batch_found_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100) - # print(f"Number of docs after context filtering: {len(filtered_docs)}") + # use the below filtering code for batch context filtering - List[List[Document]] (only use at this point in the pipeline) + filtered_docs = batch_context_filtering(batch_docs=batch_found_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100) - # print(filtered_docs[0]) - # exit() + filtered_count = 0 + for docs in filtered_docs: + filtered_count += len(docs) + print(f"Number of individual docs after context filtering: {filtered_count}") - found_docs = self.reciprocal_rank_fusion(batch_found_docs) + # if filtered docs are between 0 to 5 (very less), use the pre-filter batch_found_docs + if 0 < filtered_count < 5: + found_docs = self.reciprocal_rank_fusion(batch_found_docs) + else: + found_docs = self.reciprocal_rank_fusion(filtered_docs) + found_docs = [doc for doc, score in found_docs] - print(f"Number of docs found with multiple queries: {len(found_docs)}") + print(f"Number of docs found with MQR after rank fusion: {len(found_docs)}") if len(found_docs) == 0: return [] - print(f"⏰ Multi-query processing runtime: {(time.monotonic() - mq_start_time):.2f} seconds") - - # filtered_docs = run_context_filtering(contexts=found_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100) - # print(f"Number of docs after context filtering: {len(filtered_docs)}") - # exit() + print(f"⏰ Total multi-query processing runtime: {(time.monotonic() - mq_start_time):.2f} seconds") # 'context padding' // 'parent document retriever' final_docs = context_processing(found_docs, search_query, course_name) @@ -1170,20 +1177,21 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit # count tokens at start and end, then also count each context. token_counter, _ = count_tokens_and_cost(pre_prompt + '\n\nNow please respond to my query: ' + search_query) # type: ignore - filtered_docs = run_context_filtering(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100) + # use the below commented code for ray-based context filtering or List[dict] filtering + + #filtered_docs = list_context_filtering(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100) #filtered_docs = list(run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)) - print(f"Number of docs after context filtering: {len(filtered_docs)}") - if len(filtered_docs) > 0: - final_docs_used = filtered_docs - else: - final_docs_used = final_docs - print("No docs passed context filtering, using all docs retrieved.") - + # print(f"Number of docs after context filtering: {len(filtered_docs)}") + # if len(filtered_docs) > 0: + # final_docs_used = filtered_docs + # else: + # final_docs_used = final_docs + # print("No docs passed context filtering, using all docs retrieved.") valid_docs = [] num_tokens = 0 - for doc in final_docs_used: + for doc in final_docs: doc_string = f"Document: {doc['readable_filename']}{', page: ' + str(doc['pagenumber']) if doc['pagenumber'] else ''}\n{str(doc['text'])}\n" num_tokens, prompt_cost = count_tokens_and_cost(doc_string) # type: ignore