From 0c0fe7a102d9ffefb7dda2e385cac2433533cdfa Mon Sep 17 00:00:00 2001
From: star-nox
Date: Mon, 11 Dec 2023 15:48:18 -0600
Subject: [PATCH] Modified context filtering to use thread and process pool
 executors

---
 ai_ta_backend/filtering_contexts.py          | 256 ++++++++++++-------
 ai_ta_backend/parallel_context_processing.py |   8 +-
 ai_ta_backend/vector_database.py             |   8 +-
 3 files changed, 165 insertions(+), 107 deletions(-)

diff --git a/ai_ta_backend/filtering_contexts.py b/ai_ta_backend/filtering_contexts.py
index a5da73a6..0c49d921 100644
--- a/ai_ta_backend/filtering_contexts.py
+++ b/ai_ta_backend/filtering_contexts.py
@@ -15,36 +15,99 @@ from langchain.prompts import PromptTemplate
 #from openai import OpenAI
 
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
+from functools import partial
+from multiprocessing import Manager
+
+
 from ai_ta_backend.utils_tokenization import count_tokens_and_cost
 
 load_dotenv(override=True)
-
 ## Local LLMs USAGE DOCS: https://kastanday.notion.site/LLM-Serving-on-prem-OpenAI-Clone-bb06028266d842b0872465f552684177 ##
 
-USER_QUERY = "Explain how tiling helps with global memory bandwidth."
+def run_context_filtering(contexts, user_query, max_time_before_return=45, max_concurrency=100):
+  """
+  Main entry point: filter contexts in parallel, keeping only those relevant to user_query.
+  NOTE: max_time_before_return and max_concurrency are accepted but not yet wired into the executors below.
+  """
+  print("inside main context filtering")
+  start_time = time.monotonic()
+  langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr")
+
+  # Fan out the per-context LLM calls, then collect the contexts that pass the relevance check.
+  with Manager() as manager:
+    filtered_contexts = manager.list()
+    partial_func1 = partial(filter_context, user_query=user_query, langsmith_prompt_obj=langsmith_prompt_obj)
+    partial_func2 = partial(select_context, result=filtered_contexts)
 
-CONTEXTS = []
+    with ThreadPoolExecutor(max_workers=200) as executor1:
+      # NOTE: currently capped to the first 10 contexts.
+      results1 = list(executor1.map(partial_func1, contexts[:10]))
 
-@ray.remote
-class AsyncActor:
-  def __init__(self):
-    pass
+    print(f"ā° ThreadPool runtime: {(time.monotonic() - start_time):.2f} seconds")
 
-  def filter_context(self, context, user_query, langsmith_prompt_obj):
-    final_prompt = str(langsmith_prompt_obj.format(context=context['text'], user_query=user_query))
-    print(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^")
-    try:
-      # completion = run_model(final_prompt)
-      #completion = run_replicate(final_prompt)
-      completion = run_anyscale(final_prompt)
-      clean_text = context['text'].replace('\n', '')
-      print("Context: ", clean_text)
-      print("Completion: ", completion)
+    with ProcessPoolExecutor(max_workers=200) as executor:
+      # Force evaluation so worker exceptions surface here instead of being dropped.
+      list(executor.map(partial_func2, results1))
 
-      return {"completion": completion, "context": context}
-    except Exception as e:
-      print(f"Error: {e}")
+    print(f"ā° Context filtering runtime: {(time.monotonic() - start_time):.2f} seconds")
+    print("len of filtered contexts: ", len(filtered_contexts))
+    # Copy out of the manager proxy before the Manager shuts down at the end of the with-block.
+    return list(filtered_contexts)
+
+
+def filter_context(context, user_query, langsmith_prompt_obj):
+  final_prompt = str(langsmith_prompt_obj.format(context=context['text'], user_query=user_query))
+  try:
+    #completion = run_anyscale(final_prompt)
+    ret = openai.ChatCompletion.create(
+      api_base = "https://api.endpoints.anyscale.com/v1",
+      api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"],
+      # model="meta-llama/Llama-2-70b-chat-hf",
+      #model="mistralai/Mistral-7B-Instruct-v0.1",
+      model = "HuggingFaceH4/zephyr-7b-beta",
+      messages=[{"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": final_prompt}],
+      temperature=0.3,
+      max_tokens=250,
+    )
+    completion = ret["choices"][0]["message"]["content"]
+
+    return {"completion": completion, "context": context}
+  except Exception as e:
+    # Returns None on failure; select_context guards against that.
+    print(f"Error: {e}")
+
+
+def select_context(completion_object, result):
+  # completion_object is None when the LLM call in filter_context failed.
+  if completion_object is not None and parse_result(completion_object['completion']):
+    result.append(completion_object['context'])
+
+
+def parse_result(result):
+  # The prompt asks the model to end with a 'Final answer: yes/no' line.
+  lines = result.split('\n')
+  for line in lines:
+    if 'Final answer' in line:
+      return 'yes' in line.lower()
+  return False
+
+
+## OLD CODE ##
+
+#@ray.remote
+# class AsyncActor:
+#   def __init__(self):
+#     pass

+#   def filter_context(self, context, user_query, langsmith_prompt_obj):
+#     final_prompt = str(langsmith_prompt_obj.format(context=context['text'], user_query=user_query))
+#     #print(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^")
+#     try:
+#       # completion = run_model(final_prompt)
+#       #completion = run_replicate(final_prompt)
+#       completion = run_anyscale(final_prompt)
+#       #clean_text = context['text'].replace('\n', '')
+#       #print("Context: ", clean_text)
+#       #print("Completion: ", completion)

+#       return {"completion": completion, "context": context}
+#     except Exception as e:
+#       print(f"Error: {e}")
 
 def run_model(prompt, max_tokens=300, temp=0.3, **kwargs):
   '''
ret["choices"][0]["message"]["content"] + + return {"completion": completion, "context": context} + except Exception as e: + print(f"Error: {e}") + +def select_context(completion_object, result): + if parse_result(completion_object['completion']): + result.append(completion_object['context']) + + +def parse_result(result): + lines = result.split('\n') + for line in lines: + if 'Final answer' in line: + return 'yes' in line.lower() + return False + + +## OLD CODE ## + +#@ray.remote +# class AsyncActor: +# def __init__(self): +# pass + +# def filter_context(self, context, user_query, langsmith_prompt_obj): +# final_prompt = str(langsmith_prompt_obj.format(context=context['text'], user_query=user_query)) +# #print(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^") +# try: +# # completion = run_model(final_prompt) +# #completion = run_replicate(final_prompt) +# completion = run_anyscale(final_prompt) +# #clean_text = context['text'].replace('\n', '') +# #print("Context: ", clean_text) +# #print("Completion: ", completion) + +# return {"completion": completion, "context": context} +# except Exception as e: +# print(f"Error: {e}") def run_model(prompt, max_tokens=300, temp=0.3, **kwargs): ''' @@ -84,92 +147,87 @@ def run_replicate(prompt): print(output) return output -def run_anyscale(prompt): - print("in run anyscale") +# def run_anyscale(prompt): - ret = openai.ChatCompletion.create( - api_base = "https://api.endpoints.anyscale.com/v1", - api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"], - # model="meta-llama/Llama-2-70b-chat-hf", - #model="mistralai/Mistral-7B-Instruct-v0.1", - model = "HuggingFaceH4/zephyr-7b-beta", - messages=[{"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}], - temperature=0.3, - max_tokens=250, - ) - #print(ret["choices"][0]["message"]["content"]) - return ret["choices"][0]["message"]["content"] - - -def parse_result(result): - lines = result.split('\n') - for line in lines: - if 'Final answer' in line: - return 'yes' in line.lower() - return False - -def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=None, max_concurrency=100): - langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr") +# ret = openai.ChatCompletion.create( +# api_base = "https://api.endpoints.anyscale.com/v1", +# api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"], +# # model="meta-llama/Llama-2-70b-chat-hf", +# #model="mistralai/Mistral-7B-Instruct-v0.1", +# model = "HuggingFaceH4/zephyr-7b-beta", +# messages=[{"role": "system", "content": "You are a helpful assistant."}, +# {"role": "user", "content": prompt}], +# temperature=0.3, +# max_tokens=250, +# ) + +# return ret["choices"][0]["message"]["content"] + + +# def parse_result(result): +# lines = result['completion'].split('\n') +# for line in lines: +# if 'Final answer' in line: +# return 'yes' in line.lower() +# return False + +# def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=None, max_concurrency=100): +# langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr") - print("Num jobs to run:", len(contexts)) +# print("Num jobs to run:", len(contexts)) - actor = AsyncActor.options(max_concurrency=max_concurrency).remote() - result_futures = [actor.filter_context.remote(c, user_query, langsmith_prompt_obj) for c in contexts] - print("Num futures:", len(result_futures)) - #print("Result futures:", result_futures) +# actor = AsyncActor.options(max_concurrency=max_concurrency).remote() +# result_futures = 
diff --git a/ai_ta_backend/parallel_context_processing.py b/ai_ta_backend/parallel_context_processing.py
index edafa0d3..c27db33e 100644
--- a/ai_ta_backend/parallel_context_processing.py
+++ b/ai_ta_backend/parallel_context_processing.py
@@ -80,11 +80,11 @@ def supabase_context_padding(doc, course_name, result_docs):
     # do the padding
     filename = data[0]['readable_filename']
     contexts = data[0]['contexts']
-    print("no of contexts within the og doc: ", len(contexts))
+    #print("no of contexts within the og doc: ", len(contexts))
 
     if 'chunk_index' in doc.metadata and 'chunk_index' in contexts[0].keys():
-      print("inside chunk index")
+      #print("inside chunk index")
       # pad contexts by chunk index + 3 and - 3
       target_chunk_index = doc.metadata['chunk_index']
       for context in contexts:
@@ -98,7 +98,7 @@ def supabase_context_padding(doc, course_name, result_docs):
         result_docs.append(context)
 
     elif doc.metadata['pagenumber'] != '':
-      print("inside page number")
+      #print("inside page number")
       # pad contexts belonging to same page number
       pagenumber = doc.metadata['pagenumber']
 
@@ -113,7 +113,7 @@ def supabase_context_padding(doc, course_name, result_docs):
         result_docs.append(context)
 
     else:
-      print("inside else")
+      #print("inside else")
       # refactor as a Supabase object and append
       context_dict = {
         'text': doc.page_content,
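
Reviewer note: the hunks above only silence debug prints, but the surrounding padding logic is easier to check against a toy version: for a matched chunk, supabase_context_padding pulls in neighbouring chunks whose chunk_index lies within plus or minus 3 of the target. A rough sketch under that reading follows; pad_by_chunk_index and the sample data are hypothetical, and only the chunk_index field name comes from the hunk.

    def pad_by_chunk_index(target_chunk_index, contexts, window=3):
        # Keep contexts whose chunk_index falls within +/- window of the target chunk.
        padded = []
        for context in contexts:
            if abs(int(context['chunk_index']) - target_chunk_index) <= window:
                padded.append(context)
        return padded

    contexts = [{'chunk_index': i, 'text': f'chunk {i}'} for i in range(20)]
    print([c['chunk_index'] for c in pad_by_chunk_index(10, contexts)])
    # -> [7, 8, 9, 10, 11, 12, 13]
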
diff --git a/ai_ta_backend/vector_database.py b/ai_ta_backend/vector_database.py
index 29897b60..86495ae3 100644
--- a/ai_ta_backend/vector_database.py
+++ b/ai_ta_backend/vector_database.py
@@ -43,7 +43,7 @@
 from ai_ta_backend.extreme_context_stuffing import OpenAIAPIProcessor
 from ai_ta_backend.utils_tokenization import count_tokens_and_cost
 from ai_ta_backend.parallel_context_processing import context_processing
-from ai_ta_backend.filtering_contexts import run, ray_run
+from ai_ta_backend.filtering_contexts import run_context_filtering
 
 MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation")
 
@@ -1164,11 +1164,11 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit
     # count tokens at start and end, then also count each context.
     token_counter, _ = count_tokens_and_cost(pre_prompt + '\n\nNow please respond to my query: ' + search_query)  # type: ignore
 
-    filtered_docs = ray_run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)
+    filtered_docs = run_context_filtering(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)
     #filtered_docs = list(run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100))
-    print(f"Number of docs after context filtering: {len(filtered_docs)}")
-
+    #print(f"Number of docs after context filtering: {len(filtered_docs)}")
+
     valid_docs = []
     num_tokens = 0
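
Reviewer note: the vector_database.py hunk is cut off right as the token-budget step begins (valid_docs = [], num_tokens = 0). The usual shape of such a loop, greedily keeping filtered docs until the running token count would exceed the limit, is sketched below for orientation. approx_tokens is a crude 4-characters-per-token stand-in, not the repo's count_tokens_and_cost, and pack_docs_to_budget is hypothetical, not the code that actually follows in the file.

    def approx_tokens(text):
        # Crude tokenizer stand-in: roughly 4 characters per token.
        return max(1, len(text) // 4)

    def pack_docs_to_budget(filtered_docs, base_tokens, token_limit):
        # Greedily accept docs until the next one would exceed the token limit.
        valid_docs = []
        num_tokens = base_tokens
        for doc in filtered_docs:
            doc_tokens = approx_tokens(doc['text'])
            if num_tokens + doc_tokens > token_limit:
                break
            valid_docs.append(doc)
            num_tokens += doc_tokens
        return valid_docs, num_tokens

    docs = [{'text': 'x' * 400}, {'text': 'y' * 400}, {'text': 'z' * 4000}]
    kept, used = pack_docs_to_budget(docs, base_tokens=50, token_limit=300)
    print(len(kept), used)  # -> 2 250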