modified the code to use thread pool processor
star-nox committed Dec 11, 2023
1 parent 1721bc2 commit 0c0fe7a
Showing 3 changed files with 165 additions and 107 deletions.
256 changes: 157 additions & 99 deletions ai_ta_backend/filtering_contexts.py
@@ -15,36 +15,99 @@
from langchain.prompts import PromptTemplate
#from openai import OpenAI

+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
+from functools import partial
+from multiprocessing import Manager


from ai_ta_backend.utils_tokenization import count_tokens_and_cost

load_dotenv(override=True)


## Local LLMs USAGE DOCS: https://kastanday.notion.site/LLM-Serving-on-prem-OpenAI-Clone-bb06028266d842b0872465f552684177 ##

-USER_QUERY = "Explain how tiling helps with global memory bandwidth."
-
-CONTEXTS = []
-
-@ray.remote
-class AsyncActor:
-  def __init__(self):
-    pass
-
-  def filter_context(self, context, user_query, langsmith_prompt_obj):
-    final_prompt = str(langsmith_prompt_obj.format(context=context['text'], user_query=user_query))
-    print(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^")
-    try:
-      # completion = run_model(final_prompt)
-      #completion = run_replicate(final_prompt)
-      completion = run_anyscale(final_prompt)
-      clean_text = context['text'].replace('\n', '')
-      print("Context: ", clean_text)
-      print("Completion: ", completion)
-
-      return {"completion": completion, "context": context}
-    except Exception as e:
-      print(f"Error: {e}")

+def run_context_filtering(contexts, user_query, max_time_before_return=45, max_concurrency=100):
+  """
+  Main function to run context filtering in parallel.
+  """
+  print("inside main context filtering")
+  start_time = time.monotonic()
+  langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr")
+
+  # call filter contexts function
+  with Manager() as manager:
+    filtered_contexts = manager.list()
+    partial_func1 = partial(filter_context, user_query=user_query, langsmith_prompt_obj=langsmith_prompt_obj)
+    partial_func2 = partial(select_context, result=filtered_contexts)
+
+    with ThreadPoolExecutor(max_workers=200) as executor1:
+      results1 = list(executor1.map(partial_func1, contexts[:10]))
+
+    print(f"⏰ ThreadPool runtime: {(time.monotonic() - start_time):.2f} seconds")
+
+    with ProcessPoolExecutor(max_workers=200) as executor:
+      executor.map(partial_func2, results1)
+
+    print(f"⏰ Context filtering runtime: {(time.monotonic() - start_time):.2f} seconds")
+    print("len of filtered contexts: ", len(filtered_contexts))
+    exit()
+    return filtered_contexts
+
+
+def filter_context(context, user_query, langsmith_prompt_obj):
+  final_prompt = str(langsmith_prompt_obj.format(context=context['text'], user_query=user_query))
+  try:
+    #completion = run_anyscale(final_prompt)
+    ret = openai.ChatCompletion.create(
+        api_base = "https://api.endpoints.anyscale.com/v1",
+        api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"],
+        # model="meta-llama/Llama-2-70b-chat-hf",
+        #model="mistralai/Mistral-7B-Instruct-v0.1",
+        model = "HuggingFaceH4/zephyr-7b-beta",
+        messages=[{"role": "system", "content": "You are a helpful assistant."},
+                  {"role": "user", "content": final_prompt}],
+        temperature=0.3,
+        max_tokens=250,
+    )
+    completion = ret["choices"][0]["message"]["content"]
+
+    return {"completion": completion, "context": context}
+  except Exception as e:
+    print(f"Error: {e}")
+
+
+def select_context(completion_object, result):
+  if parse_result(completion_object['completion']):
+    result.append(completion_object['context'])
+
+
+def parse_result(result):
+  lines = result.split('\n')
+  for line in lines:
+    if 'Final answer' in line:
+      return 'yes' in line.lower()
+  return False


+## OLD CODE ##

+#@ray.remote
+# class AsyncActor:
+#   def __init__(self):
+#     pass

+#   def filter_context(self, context, user_query, langsmith_prompt_obj):
+#     final_prompt = str(langsmith_prompt_obj.format(context=context['text'], user_query=user_query))
+#     #print(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^")
+#     try:
+#       # completion = run_model(final_prompt)
+#       #completion = run_replicate(final_prompt)
+#       completion = run_anyscale(final_prompt)
+#       #clean_text = context['text'].replace('\n', '')
+#       #print("Context: ", clean_text)
+#       #print("Completion: ", completion)

+#       return {"completion": completion, "context": context}
+#     except Exception as e:
+#       print(f"Error: {e}")

def run_model(prompt, max_tokens=300, temp=0.3, **kwargs):
'''
@@ -84,92 +84,147 @@ def run_replicate(prompt):
  print(output)
  return output

-def run_anyscale(prompt):
-  print("in run anyscale")
-
-  ret = openai.ChatCompletion.create(
-      api_base = "https://api.endpoints.anyscale.com/v1",
-      api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"],
-      # model="meta-llama/Llama-2-70b-chat-hf",
-      #model="mistralai/Mistral-7B-Instruct-v0.1",
-      model = "HuggingFaceH4/zephyr-7b-beta",
-      messages=[{"role": "system", "content": "You are a helpful assistant."},
-                {"role": "user", "content": prompt}],
-      temperature=0.3,
-      max_tokens=250,
-  )
-  #print(ret["choices"][0]["message"]["content"])
-  return ret["choices"][0]["message"]["content"]

+# def run_anyscale(prompt):

+#   ret = openai.ChatCompletion.create(
+#       api_base = "https://api.endpoints.anyscale.com/v1",
+#       api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"],
+#       # model="meta-llama/Llama-2-70b-chat-hf",
+#       #model="mistralai/Mistral-7B-Instruct-v0.1",
+#       model = "HuggingFaceH4/zephyr-7b-beta",
+#       messages=[{"role": "system", "content": "You are a helpful assistant."},
+#                 {"role": "user", "content": prompt}],
+#       temperature=0.3,
+#       max_tokens=250,
+#   )

+#   return ret["choices"][0]["message"]["content"]


-def parse_result(result):
-  lines = result.split('\n')
-  for line in lines:
-    if 'Final answer' in line:
-      return 'yes' in line.lower()
-  return False

+# def parse_result(result):
+#   lines = result['completion'].split('\n')
+#   for line in lines:
+#     if 'Final answer' in line:
+#       return 'yes' in line.lower()
+#   return False


-def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=None, max_concurrency=100):
-  langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr")
-
-  print("Num jobs to run:", len(contexts))
-
-  actor = AsyncActor.options(max_concurrency=max_concurrency).remote()
-  result_futures = [actor.filter_context.remote(c, user_query, langsmith_prompt_obj) for c in contexts]
-  print("Num futures:", len(result_futures))
-  #print("Result futures:", result_futures)
-
-  start_time = time.time()
-  for i in range(0, len(result_futures)):
-    try:
-      ready, not_ready = ray.wait(result_futures)
-      result = ray.get(ready[0])
-
-      if result is None:
-        print("RESULT WAS NONE, llm inference probably failed")
-        continue
-
-      if parse_result(result['completion']):
-        yield result['context']
-
-      elapsed_time = (time.time() - start_time)
-      avg_task_time = elapsed_time / (i+1)
-      estimated_total_runtime = avg_task_time * len(contexts)
-
-      print(f"📌 Completed {i+1} of {len(contexts)}")
-      print(f"⏰ Running total of elapsed time: {elapsed_time:.2f} seconds\n🔮 Estimated total runtime: {estimated_total_runtime:.2f} seconds.\n")
-      print(f"⏰👻 avg_task_time (s): {avg_task_time:.2f}")
-      # print(f"📜 Passage: {result['context']['text']}")
-      # print(f"✅ Result: {result['completion']}")
-
-      if max_time_before_return is not None and elapsed_time >= max_time_before_return:
-        break
-
-    except Exception as e:
-      print("-----------❌❌❌❌------------START OF ERROR-----------❌❌❌❌------------")
-      print(f"Error in {inspect.currentframe().f_code.co_name}: {e}") # print function name in error.
-      print(f"Traceback:")
-      print(traceback.print_exc())
-    finally:
-      result_futures = not_ready
-      if not result_futures:
-        break

+# def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=None, max_concurrency=100):
+#   langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr")

+#   print("Num jobs to run:", len(contexts))

+#   actor = AsyncActor.options(max_concurrency=max_concurrency).remote()
+#   result_futures = [actor.filter_context.remote(c, user_query, langsmith_prompt_obj) for c in contexts]
+#   print("Num futures:", len(result_futures))
+#   #print("Result futures:", result_futures)

+#   start_time = time.time()
+#   for i in range(0, len(result_futures)):
+#     try:
+#       ready, not_ready = ray.wait(result_futures)
+#       result = ray.get(ready[0])

+#       if result is None:
+#         print("RESULT WAS NONE, llm inference probably failed")
+#         continue

+#       if parse_result(result['completion']):
+#         yield result['context']

+#       elapsed_time = (time.time() - start_time)
+#       avg_task_time = elapsed_time / (i+1)
+#       estimated_total_runtime = avg_task_time * len(contexts)

+#       print(f"📌 Completed {i+1} of {len(contexts)}")
+#       print(f"⏰ Running total of elapsed time: {elapsed_time:.2f} seconds\n🔮 Estimated total runtime: {estimated_total_runtime:.2f} seconds.\n")
+#       print(f"⏰👻 avg_task_time (s): {avg_task_time:.2f}")
+#       # print(f"📜 Passage: {result['context']['text']}")
+#       # print(f"✅ Result: {result['completion']}")

+#       if max_time_before_return is not None and elapsed_time >= max_time_before_return:
+#         break

+#     except Exception as e:
+#       print("-----------❌❌❌❌------------START OF ERROR-----------❌❌❌❌------------")
+#       print(f"Error in {inspect.currentframe().f_code.co_name}: {e}") # print function name in error.
+#       print(f"Traceback:")
+#       print(traceback.print_exc())
+#     finally:
+#       result_futures = not_ready
+#       if not result_futures:
+#         break


-def ray_run(contexts, user_query, max_time_before_return=45, max_concurrency=100):
-  ray.init()
-  filtered_passages = list(run(contexts, user_query, max_time_before_return=max_time_before_return, max_concurrency=max_concurrency))
-  return filtered_passages


-# ! CONDA ENV: llm-serving
-if __name__ == "__main__":
-  #ray.init()
-  start_time = time.monotonic()
-  # print(len(CONTEXTS))
-
-  final_passage_list = list(run(contexts=CONTEXTS*2, user_query=USER_QUERY, max_time_before_return=45, max_concurrency=20))
-
-  print("✅✅✅ FINAL RESULTS: \n" + '\n'.join(json.dumps(r, indent=2) for r in final_passage_list))
-  print("✅✅✅ TOTAL RETURNED: ", len(final_passage_list))
-  print(f"⏰⏰⏰ Runtime: {(time.monotonic() - start_time):.2f} seconds")

+# # ! CONDA ENV: llm-serving
+# if __name__ == "__main__":
+#   #ray.init()
+#   start_time = time.monotonic()
+#   # print(len(CONTEXTS))

+#   final_passage_list = list(run(contexts=CONTEXTS*2, user_query=USER_QUERY, max_time_before_return=45, max_concurrency=20))

+#   print("✅✅✅ FINAL RESULTS: \n" + '\n'.join(json.dumps(r, indent=2) for r in final_passage_list))
+#   print("✅✅✅ TOTAL RETURNED: ", len(final_passage_list))
+#   print(f"⏰⏰⏰ Runtime: {(time.monotonic() - start_time):.2f} seconds")
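A minimal, hypothetical usage sketch of the new per-context pipeline added above (filter_context → select_context → parse_result). It assumes ANYSCALE_ENDPOINT_TOKEN is set and that hub and openai are the modules imported at the top of the file; the sample context dict and query string below are made up for illustration and are not part of the commit.

langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr")
context = {'text': "Tiling loads reusable sub-blocks into shared memory, cutting global memory traffic."}  # hypothetical passage
out = filter_context(context, "Explain how tiling helps with global memory bandwidth.", langsmith_prompt_obj)

kept = []
if out is not None:  # filter_context returns None when the API call raised an exception
  select_context(out, kept)  # appends out['context'] only when parse_result finds a "Final answer ... yes" line
print(kept)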
8 changes: 4 additions & 4 deletions ai_ta_backend/parallel_context_processing.py
@@ -80,11 +80,11 @@ def supabase_context_padding(doc, course_name, result_docs):
    # do the padding
    filename = data[0]['readable_filename']
    contexts = data[0]['contexts']
-   print("no of contexts within the og doc: ", len(contexts))
+   #print("no of contexts within the og doc: ", len(contexts))

    if 'chunk_index' in doc.metadata and 'chunk_index' in contexts[0].keys():
-     print("inside chunk index")
+     #print("inside chunk index")
      # pad contexts by chunk index + 3 and - 3
      target_chunk_index = doc.metadata['chunk_index']
      for context in contexts:
@@ -98,7 +98,7 @@ def supabase_context_padding(doc, course_name, result_docs):
          result_docs.append(context)

    elif doc.metadata['pagenumber'] != '':
-     print("inside page number")
+     #print("inside page number")
      # pad contexts belonging to same page number
      pagenumber = doc.metadata['pagenumber']

@@ -113,7 +113,7 @@ def supabase_context_padding(doc, course_name, result_docs):
          result_docs.append(context)

    else:
-     print("inside else")
+     #print("inside else")
      # refactor as a Supabase object and append
      context_dict = {
        'text': doc.page_content,
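For reference, a hypothetical sketch of the chunk-index padding described in the comments above. The exact selection condition sits in the collapsed lines, so the ±3 window is an assumption taken from the "# pad contexts by chunk index + 3 and - 3" comment, and the sample contexts are made up.

target_chunk_index = 7  # stands in for doc.metadata['chunk_index'] of the retrieved chunk
contexts = [{'chunk_index': i, 'text': f'chunk {i}'} for i in range(20)]  # made-up contexts from the same document
padded = [c for c in contexts if abs(c['chunk_index'] - target_chunk_index) <= 3]
# keeps chunks 4 through 10: the retrieved chunk plus three neighbours on each side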
8 changes: 4 additions & 4 deletions ai_ta_backend/vector_database.py
@@ -43,7 +43,7 @@
from ai_ta_backend.extreme_context_stuffing import OpenAIAPIProcessor
from ai_ta_backend.utils_tokenization import count_tokens_and_cost
from ai_ta_backend.parallel_context_processing import context_processing
-from ai_ta_backend.filtering_contexts import run, ray_run
+from ai_ta_backend.filtering_contexts import run_context_filtering


MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation")
@@ -1164,11 +1164,11 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit
    # count tokens at start and end, then also count each context.
    token_counter, _ = count_tokens_and_cost(pre_prompt + '\n\nNow please respond to my query: ' + search_query)  # type: ignore

-   filtered_docs = ray_run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)
+   filtered_docs = run_context_filtering(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)

    #filtered_docs = list(run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100))
-   print(f"Number of docs after context filtering: {len(filtered_docs)}")
+   #print(f"Number of docs after context filtering: {len(filtered_docs)}")

    valid_docs = []
    num_tokens = 0
