NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Line breaks, blank context lines and the 2-space indentation are inferred from
the hunk line counts -- verify against the repository before applying.
ray_run now guards ray.init() with ray.is_initialized(), because calling
ray.init() a second time in the same process raises RuntimeError.

diff --git a/ai_ta_backend/filtering_contexts.py b/ai_ta_backend/filtering_contexts.py
index 8826c9df..a5da73a6 100644
--- a/ai_ta_backend/filtering_contexts.py
+++ b/ai_ta_backend/filtering_contexts.py
@@ -155,9 +155,22 @@ def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=
     if not result_futures:
       break
 
+
+def ray_run(contexts, user_query, max_time_before_return=45, max_concurrency=100):
+  """Start Ray if it is not running yet, then collect the results of `run` into a list.
+
+  All four arguments are forwarded unchanged to `run`; see `run` for their meaning.
+  """
+  # ray.init() raises if Ray is already initialized in this process, which is
+  # the case on every call after the first one in a long-lived server.
+  if not ray.is_initialized():
+    ray.init()
+  return list(run(contexts, user_query, max_time_before_return=max_time_before_return, max_concurrency=max_concurrency))
+
+
 # ! CONDA ENV: llm-serving
 if __name__ == "__main__":
-  ray.init()
+  #ray.init()
   start_time = time.monotonic()
   # print(len(CONTEXTS))
 
diff --git a/ai_ta_backend/vector_database.py b/ai_ta_backend/vector_database.py
index ada3898c..29897b60 100644
--- a/ai_ta_backend/vector_database.py
+++ b/ai_ta_backend/vector_database.py
@@ -43,7 +43,7 @@
 from ai_ta_backend.extreme_context_stuffing import OpenAIAPIProcessor
 from ai_ta_backend.utils_tokenization import count_tokens_and_cost
 from ai_ta_backend.parallel_context_processing import context_processing
-from ai_ta_backend.filtering_contexts import run
+from ai_ta_backend.filtering_contexts import run, ray_run
 
 MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation")
 
@@ -1164,8 +1164,9 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit
       # count tokens at start and end, then also count each context.
       token_counter, _ = count_tokens_and_cost(pre_prompt + '\n\nNow please respond to my query: ' + search_query)  # type: ignore
 
-      ray.init()
-      filtered_docs = list(run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100))
+      filtered_docs = ray_run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)
+
+      #filtered_docs = list(run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100))
       print(f"Number of docs after context filtering: {len(filtered_docs)}")
 
       valid_docs = []