Skip to content

Commit

Permalink
moved filtering after MQR and modified the filtering code
Browse files Browse the repository at this point in the history
  • Loading branch information
star-nox committed Dec 13, 2023
1 parent 159b9b7 commit 4542f15
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 37 deletions.
68 changes: 53 additions & 15 deletions ai_ta_backend/filtering_contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,37 @@
from functools import partial
from multiprocessing import Pool, Manager


from ai_ta_backend.utils_tokenization import count_tokens_and_cost

load_dotenv(override=True)
LANGSMITH_PROMPT_OBJ = hub.pull("kastanday/filter-unrelated-contexts-zephyr")

## Local LLMs USAGE DOCS: https://kastanday.notion.site/LLM-Serving-on-prem-OpenAI-Clone-bb06028266d842b0872465f552684177 ##

def run_context_filtering(contexts, user_query, max_time_before_return=45, max_concurrency=100):
def list_context_filtering(contexts, user_query, max_time_before_return=45, max_concurrency=100):
"""
Main function to run context filtering in parallel.
Main function for filtering contexts. Use this when dealing with a List[Dicts]. To be called after context_padding
in getTopContextsWithMQR(). It is also used with batch_context_filtering.
This function multi-processes a list of contexts.
Args: contexts (list of dicts), user_query (str), max_time_before_return (int), max_concurrency (int)
Returns: filtered_contexts (list of dicts)
"""
print("inside main context filtering")
start_time = time.monotonic()
langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr")

start_time = time.monotonic()

# call filter contexts function
with Manager() as manager:
filtered_contexts = manager.list()
partial_func1 = partial(filter_context, user_query=user_query, langsmith_prompt_obj=langsmith_prompt_obj)
partial_func1 = partial(anyscale_completion, user_query=user_query, langsmith_prompt_obj=LANGSMITH_PROMPT_OBJ)
partial_func2 = partial(select_context, result=filtered_contexts)

with ProcessPoolExecutor(max_workers=30) as executor:
print("max workers: ", executor._max_workers)
anyscale_responses = list(executor.map(partial_func1, contexts))
print("len of anyscale responses: ", len(anyscale_responses))
if len(anyscale_responses) > 0:
executor.map(partial_func2, anyscale_responses)
else:
print("LLM responses are empty.")

executor.shutdown()

filtered_contexts = list(filtered_contexts)
Expand All @@ -58,14 +58,39 @@ def run_context_filtering(contexts, user_query, max_time_before_return=45, max_c
print("len of filtered contexts: ", len(filtered_contexts))
return filtered_contexts

def batch_context_filtering(batch_docs, user_query, max_time_before_return=45, max_concurrency=100):
"""
Main function for filtering contexts. Use this when dealing with List[List[Docs]]. To be called between
batch_vector_search() and reciprocal_ranking().
This function multi-processes a list of list of contexts.
Args: batch_docs (list of list of docs), user_query (str), max_time_before_return (int), max_concurrency (int)
Returns: filtered_contexts (list of list of docs)
"""

def filter_context(context, user_query, langsmith_prompt_obj):
start_time = time.monotonic()

partial_func = partial(list_context_filtering, user_query=user_query, max_time_before_return=max_time_before_return, max_concurrency=max_concurrency)
with ProcessPoolExecutor(max_workers=5) as executor:
processed_docs = list(executor.map(partial_func, batch_docs))

processed_docs = list(processed_docs)
print(f"⏰ Context filtering runtime: {(time.monotonic() - start_time):.2f} seconds")

return processed_docs


def anyscale_completion(context, user_query, langsmith_prompt_obj):
"""
Runs the Anyscale completion API call.
Args: context (dict), user_query (str), langsmith_prompt_obj (PromptTemplate)
Returns: completion_object (dict)
"""
api_start_time = time.monotonic()

final_prompt = str(langsmith_prompt_obj.format(context=context['text'], user_query=user_query))
# use first final_prompt when using batch_context_filtering as start point and second when using list_context_filtering as start point
final_prompt = str(langsmith_prompt_obj.format(context=context.page_content, user_query=user_query))
#final_prompt = str(langsmith_prompt_obj.format(context=context['text'], user_query=user_query))
try:
#completion = run_anyscale(final_prompt)
ret = openai.ChatCompletion.create(
api_base = "https://api.endpoints.anyscale.com/v1",
api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"],
Expand All @@ -83,18 +108,27 @@ def filter_context(context, user_query, langsmith_prompt_obj):
print(f"Error: {e}")

def select_context(completion_object, result):
    """
    Decide whether a context survives filtering and, if so, collect it.

    Uses parse_result() on the LLM completion text to judge relevance.
    Args:
        completion_object (dict): holds 'completion' (LLM output text) and
            'context' (the original context dict being judged).
        result (list): shared list (e.g. a multiprocessing Manager list)
            that accepted contexts are appended to.
    Returns: None -- mutates *result* in place.
    """
    keep = parse_result(completion_object['completion'])
    if keep:
        result.append(completion_object['context'])

def parse_result(result):
    """
    Parse the completion text of the LLM API call into a keep/drop decision.

    Scans the completion line by line for the first line mentioning
    'Final answer' and reports whether that line contains 'yes'
    (case-insensitive on the answer).
    Args: result (str) -- the completion part of the Anyscale response
    Returns: bool -- True when the first 'Final answer' line says yes,
        False when it says no or when no such line exists.
    """
    answer_lines = (ln for ln in result.split('\n') if 'Final answer' in ln)
    verdict = next(answer_lines, None)
    return verdict is not None and 'yes' in verdict.lower()


#----------------------- OLD CODE BELOW ----------------------------------------------------------------------------#
#----------------------- RAY CODE BELOW ----------------------------------------------------------------------------#

# @ray.remote
# class AsyncActor:
Expand Down Expand Up @@ -176,6 +210,10 @@ def parse_result(result):
# return False

# def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=None, max_concurrency=100):
#
# # Main function for filtering contexts using RAY. Use this when dealing with a list of contexts.
#
#
# langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr")

# print("Num jobs to run:", len(contexts))
Expand Down
52 changes: 30 additions & 22 deletions ai_ta_backend/vector_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@
from ai_ta_backend.extreme_context_stuffing import OpenAIAPIProcessor
from ai_ta_backend.utils_tokenization import count_tokens_and_cost
from ai_ta_backend.parallel_context_processing import context_processing
#from ai_ta_backend.filtering_contexts import run
from ai_ta_backend.filtering_contexts import run_context_filtering
#from ai_ta_backend.filtering_contexts import ray_context_filtering
#from ai_ta_backend.filtering_contexts import run_context_filtering
from ai_ta_backend.filtering_contexts import batch_context_filtering

MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation")
OPENAI_API_TYPE = "azure" # "openai" or "azure"
Expand Down Expand Up @@ -1048,8 +1049,10 @@ def reciprocal_rank_fusion(self, results: list[list], k=60):
and collect the documents with the highest overall score, as scored by qdrant similarity matching.
"""
fused_scores = {}
count = 0
for docs in results:
# Assumes the docs are returned in sorted order of relevance
count += len(docs)
for rank, doc in enumerate(docs):
doc_str = dumps(doc)
if doc_str not in fused_scores:
Expand All @@ -1063,6 +1066,7 @@ def reciprocal_rank_fusion(self, results: list[list], k=60):
(loads(doc), score)
for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
]

return reranked_results

def getTopContexts(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]:
Expand Down Expand Up @@ -1144,23 +1148,26 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit

batch_found_docs: list[list[Document]] = self.batch_vector_search(search_queries=generated_queries, course_name=course_name)

# filtered_docs = run_context_filtering(contexts=batch_found_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)
# print(f"Number of docs after context filtering: {len(filtered_docs)}")
# use the below filtering code for batch context filtering - List[List[Document]] (only use at this point in the pipeline)
filtered_docs = batch_context_filtering(batch_docs=batch_found_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)

# print(filtered_docs[0])
# exit()
filtered_count = 0
for docs in filtered_docs:
filtered_count += len(docs)
print(f"Number of individual docs after context filtering: {filtered_count}")

found_docs = self.reciprocal_rank_fusion(batch_found_docs)
# if filtered docs are between 0 to 5 (very less), use the pre-filter batch_found_docs
if 0 < filtered_count < 5:
found_docs = self.reciprocal_rank_fusion(batch_found_docs)
else:
found_docs = self.reciprocal_rank_fusion(filtered_docs)

found_docs = [doc for doc, score in found_docs]
print(f"Number of docs found with multiple queries: {len(found_docs)}")
print(f"Number of docs found with MQR after rank fusion: {len(found_docs)}")
if len(found_docs) == 0:
return []

print(f"⏰ Multi-query processing runtime: {(time.monotonic() - mq_start_time):.2f} seconds")

# filtered_docs = run_context_filtering(contexts=found_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)
# print(f"Number of docs after context filtering: {len(filtered_docs)}")
# exit()
print(f"⏰ Total multi-query processing runtime: {(time.monotonic() - mq_start_time):.2f} seconds")

# 'context padding' // 'parent document retriever'
final_docs = context_processing(found_docs, search_query, course_name)
Expand All @@ -1170,20 +1177,21 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit
# count tokens at start and end, then also count each context.
token_counter, _ = count_tokens_and_cost(pre_prompt + '\n\nNow please respond to my query: ' + search_query) # type: ignore

filtered_docs = run_context_filtering(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)
# use the below commented code for ray-based context filtering or List[dict] filtering

#filtered_docs = list_context_filtering(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)
#filtered_docs = list(run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100))
print(f"Number of docs after context filtering: {len(filtered_docs)}")
if len(filtered_docs) > 0:
final_docs_used = filtered_docs
else:
final_docs_used = final_docs
print("No docs passed context filtering, using all docs retrieved.")

# print(f"Number of docs after context filtering: {len(filtered_docs)}")
# if len(filtered_docs) > 0:
# final_docs_used = filtered_docs
# else:
# final_docs_used = final_docs
# print("No docs passed context filtering, using all docs retrieved.")

valid_docs = []
num_tokens = 0

for doc in final_docs_used:
for doc in final_docs:

doc_string = f"Document: {doc['readable_filename']}{', page: ' + str(doc['pagenumber']) if doc['pagenumber'] else ''}\n{str(doc['text'])}\n"
num_tokens, prompt_cost = count_tokens_and_cost(doc_string) # type: ignore
Expand Down

0 comments on commit 4542f15

Please sign in to comment.