
Commit

moved filtering after context padding
star-nox committed Dec 7, 2023
1 parent c55f365 commit 46868aa
Showing 3 changed files with 23 additions and 240 deletions.
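In effect, getTopContextsWithMQR now applies the LLM relevance filter to the padded documents instead of to the raw retrieval hits, as the vector_database.py diff below shows. A minimal, self-contained sketch of that ordering, with stand-in bodies (only the function names come from this commit; the data and logic here are illustrative):

# Illustrative stand-ins only; not the repo's actual implementations.
def context_processing(found_docs, user_query, course_name):
    # stand-in for 'context padding' / parent-document retrieval
    return [{"text": d, "readable_filename": "demo.pdf", "pagenumber": ""} for d in found_docs]

def run(contexts, user_query, max_time_before_return=45, max_concurrency=30):
    # stand-in for the LLM relevance filter defined in filtering_contexts.py
    return [c for c in contexts if user_query.lower() in c["text"].lower()]

found_docs = ["ray actors overview", "ray serve deployment", "notes about cooking"]
final_docs = context_processing(found_docs, "ray", "demo-course")      # 1. pad first
filtered_docs = run(contexts=final_docs, user_query="ray")             # 2. then filter
print(len(found_docs), len(final_docs), len(filtered_docs))            # 3 3 2
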
9 changes: 4 additions & 5 deletions ai_ta_backend/filtering_contexts.py
@@ -123,12 +123,10 @@ def parse_result(result):
return 'yes' in line.lower()
return False

def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=None, max_concurrency=6):
langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr")

def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=None, max_concurrency=100):
langsmith_prompt_obj = hub.pull("kasantday/filter-unrelated-contexts-zephyr")
print("Num jobs to run:", len(contexts))
#print("Context: ", contexts[0])
#exit()

actor = AsyncActor.options(max_concurrency=max_concurrency).remote()
result_futures = [actor.filter_context.remote(c, user_query, langsmith_prompt_obj) for c in contexts]
@@ -178,6 +176,7 @@ def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=
ray.init()
start_time = time.monotonic()
# print(len(CONTEXTS))

final_passage_list = list(run(contexts=CONTEXTS*2, user_query=USER_QUERY, max_time_before_return=45, max_concurrency=20))

print("✅✅✅ FINAL RESULTS: \n" + '\n'.join(json.dumps(r, indent=2) for r in final_passage_list))
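For reference, run() above dispatches each context to a Ray actor created with AsyncActor.options(max_concurrency=...). The actor itself is outside this diff; the following is an assumed, minimal sketch of the pattern, with a placeholder body instead of the real LLM call:

import ray

@ray.remote
class AsyncActor:
    async def filter_context(self, context, user_query, prompt_obj):
        # Placeholder: the real method would call an LLM with prompt_obj and
        # decide whether `context` is relevant to `user_query`.
        return {"context": context, "completion": "yes"}

if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)
    actor = AsyncActor.options(max_concurrency=6).remote()
    futures = [actor.filter_context.remote(c, "demo query", None) for c in ("ctx A", "ctx B")]
    print(ray.get(futures))
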
4 changes: 4 additions & 0 deletions ai_ta_backend/parallel_context_processing.py
@@ -81,8 +81,10 @@ def supabase_context_padding(doc, course_name, result_docs):
filename = data[0]['readable_filename']
contexts = data[0]['contexts']
print("no of contexts within the og doc: ", len(contexts))


if 'chunk_index' in doc.metadata and 'chunk_index' in contexts[0].keys():
print("inside chunk index")
# pad contexts by chunk index + 3 and - 3
target_chunk_index = doc.metadata['chunk_index']
for context in contexts:
@@ -96,6 +98,7 @@ def supabase_context_padding(doc, course_name, result_docs):
result_docs.append(context)

elif doc.metadata['pagenumber'] != '':
print("inside page number")
# pad contexts belonging to same page number
pagenumber = doc.metadata['pagenumber']

@@ -110,6 +113,7 @@ def supabase_context_padding(doc, course_name, result_docs):
result_docs.append(context)

else:
print("inside else")
# refactor as a Supabase object and append
context_dict = {
'text': doc.page_content,
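The padding rule above keeps neighbouring contexts within ±3 chunks of the matched chunk and skips indexes already collected for the same parent document. A small, self-contained sketch of that window logic (field names mirror the diff; the data and helper name are made up):

def pad_by_chunk_index(contexts, target_chunk_index, already_retrieved, window=3):
    padded = []
    for context in contexts:
        idx = context["chunk_index"]
        if abs(idx - target_chunk_index) <= window and idx not in already_retrieved:
            padded.append(context)
            already_retrieved.append(idx)  # avoid duplicates across overlapping hits
    return padded

contexts = [{"chunk_index": i, "text": f"chunk {i}"} for i in range(10)]
seen = []
print([c["chunk_index"] for c in pad_by_chunk_index(contexts, 5, seen)])  # [2, 3, 4, 5, 6, 7, 8]
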
250 changes: 15 additions & 235 deletions ai_ta_backend/vector_database.py
@@ -997,7 +997,7 @@ def vector_search(self, search_query, course_name):
print("found_docs", found_docs)
return found_docs

def batch_vector_search(self, search_queries: List[str], course_name: str, top_n: int=20):
def batch_vector_search(self, search_queries: List[str], course_name: str, top_n: int=50):
from qdrant_client.http import models as rest
o = OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE)
# Prepare the filter for the course name
@@ -1042,230 +1042,6 @@ def batch_vector_search(self, search_queries: List[str], course_name: str, top_n

return found_docs

# def context_padding(self, found_docs, search_query, course_name):
# """
# Takes top N contexts acquired from QRANT similarity search and pads them
# with context from the original document from Supabase.
# 1. Use s3_path OR url as unique doc indentifier
# 2. Use s_path + chunk_index to locate chunk in the document.
# 3. Pad it with 3 contexts before and after it.
# 4. If chunk_index is not present, use page number to locate the page in the document.
# 5. Ensure no duplication takes place - top N will often have contexts belonging to the same doc.
# """
# print("inside context padding")
# documents_table = os.environ['NEW_NEW_NEWNEW_MATERIALS_SUPABASE_TABLE']
# retrieved_contexts_identifiers = {}
# result_contexts = []

# # only pad the first 5 contexts, append the rest as it is.
# for i, doc in enumerate(found_docs): # top N from QDRANT
# if i < 5:
# # if url present, query through that
# if 'url' in doc.metadata.keys() and doc.metadata['url']:
# parent_doc_id = doc.metadata['url']
# response = self.supabase_client.table(documents_table).select('*').eq('course_name', course_name).eq('url', parent_doc_id).execute()
# retrieved_contexts_identifiers[parent_doc_id] = []
# # else use s3_path
# else:
# parent_doc_id = doc.metadata['s3_path']
# response = self.supabase_client.table(documents_table).select('*').eq('course_name', course_name).eq('s3_path', parent_doc_id).execute()
# retrieved_contexts_identifiers[parent_doc_id] = []

# data = response.data # at this point, we have the origin parent document from Supabase
# if len(data) > 0:
# filename = data[0]['readable_filename']
# contexts = data[0]['contexts']
# print("no of contexts within the og doc: ", len(contexts))

# if 'chunk_index' in doc.metadata:
# # retrieve by chunk index --> pad contexts
# target_chunk_index = doc.metadata['chunk_index']
# print("target chunk_index: ", target_chunk_index)
# print("len of result contexts before chunk_index padding: ", len(result_contexts))

# for context in contexts:
# curr_chunk_index = context['chunk_index']
# # collect between range of target index - 3 and target index + 3
# if (target_chunk_index - 3 <= curr_chunk_index <= target_chunk_index + 3) and curr_chunk_index not in retrieved_contexts_identifiers[parent_doc_id]:
# context['readable_filename'] = filename
# context['course_name'] = course_name
# context['s3_path'] = data[0]['s3_path']
# context['url'] = data[0]['url']
# context['base_url'] = data[0]['base_url']

# result_contexts.append(context)
# # add current index to retrieved_contexts_identifiers after each context is retrieved to avoid duplicates
# retrieved_contexts_identifiers[parent_doc_id].append(curr_chunk_index)
# print("len of result contexts after chunk_index padding: ", len(result_contexts))

# elif doc.metadata['pagenumber'] != '':
# # retrieve by page number --> retrieve the single whole page?
# pagenumber = doc.metadata['pagenumber']
# print("target pagenumber: ", pagenumber)
# print("len of result contexts before pagenumber padding: ", len(result_contexts))
# for context in contexts:
# if int(context['pagenumber']) == pagenumber:
# context['readable_filename'] = filename
# context['course_name'] = course_name
# context['s3_path'] = data[0]['s3_path']
# context['url'] = data[0]['url']
# context['base_url'] = data[0]['base_url']
# result_contexts.append(context)

# print("len of result contexts after pagenumber padding: ", len(result_contexts))

# # add page number to retrieved_contexts_identifiers after all contexts belonging to that page number have been retrieved
# retrieved_contexts_identifiers[parent_doc_id].append(pagenumber)
# else:
# # dont pad, re-factor it to be like Supabase object
# print("no chunk index or page number found, just appending the QDRANT context")
# print("len of result contexts before qdrant append: ", len(result_contexts))
# context_dict = {'text': doc.page_content,
# 'embedding': '',
# 'pagenumber': doc.metadata['pagenumber'],
# 'readable_filename': doc.metadata['readable_filename'],
# 'course_name': course_name,
# 's3_path': doc.metadata['s3_path'],
# 'base_url':doc.metadata['base_url']
# }
# if 'url' in doc.metadata.keys():
# context_dict['url'] = doc.metadata['url']

# result_contexts.append(context_dict)
# print("len of result contexts after qdrant append: ", len(result_contexts))
# else:
# # append the rest of the docs as it is.
# print("reached > 5 docs, just appending the QDRANT context")
# context_dict = {'text': doc.page_content,
# 'embedding': '',
# 'pagenumber': doc.metadata['pagenumber'],
# 'readable_filename': doc.metadata['readable_filename'],
# 'course_name': course_name,
# 's3_path': doc.metadata['s3_path'],
# 'base_url':doc.metadata['base_url']
# }
# if 'url' in doc.metadata.keys():
# context_dict['url'] = doc.metadata['url']

# result_contexts.append(context_dict)


# print("length of final contexts: ", len(result_contexts))
# return result_contexts

def context_data_processing(self, doc, course_name, retrieved_contexts_identifiers, result_docs):
"""
Does context padding for given doc. Used with context_padding()
"""
print("in context data processing")
documents_table = os.environ['NEW_NEW_NEWNEW_MATERIALS_SUPABASE_TABLE']

# query by url or s3_path
if 'url' in doc.metadata.keys() and doc.metadata['url']:
parent_doc_id = doc.metadata['url']
response = self.supabase_client.table(documents_table).select('*').eq('course_name', course_name).eq('url', parent_doc_id).execute()
retrieved_contexts_identifiers[parent_doc_id] = []

else:
parent_doc_id = doc.metadata['s3_path']
response = self.supabase_client.table(documents_table).select('*').eq('course_name', course_name).eq('s3_path', parent_doc_id).execute()
retrieved_contexts_identifiers[parent_doc_id] = []

data = response.data

if len(data) > 0:
# do the padding
filename = data[0]['readable_filename']
contexts = data[0]['contexts']
print("no of contexts within the og doc: ", len(contexts))

if 'chunk_index' in doc.metadata and 'chunk_index' in contexts[0].keys():
# pad contexts by chunk index + 3 and - 3
target_chunk_index = doc.metadata['chunk_index']
for context in contexts:
curr_chunk_index = context['chunk_index']
if (target_chunk_index - 3 <= curr_chunk_index <= target_chunk_index + 3) and curr_chunk_index not in retrieved_contexts_identifiers[parent_doc_id]:
context['readable_filename'] = filename
context['course_name'] = course_name
context['s3_path'] = data[0]['s3_path']
context['url'] = data[0]['url']
context['base_url'] = data[0]['base_url']
result_docs.append(context)
# add current index to retrieved_contexts_identifiers after each context is retrieved to avoid duplicates
retrieved_contexts_identifiers[parent_doc_id].append(curr_chunk_index)

elif doc.metadata['pagenumber'] != '':
# pad contexts belonging to same page number
pagenumber = doc.metadata['pagenumber']

for context in contexts:
if int(context['pagenumber']) == pagenumber and pagenumber not in retrieved_contexts_identifiers[parent_doc_id]:
context['readable_filename'] = filename
context['course_name'] = course_name
context['s3_path'] = data[0]['s3_path']
context['url'] = data[0]['url']
context['base_url'] = data[0]['base_url']
result_docs.append(context)
# add page number to retrieved_contexts_identifiers after all contexts belonging to that page number have been retrieved
retrieved_contexts_identifiers[parent_doc_id].append(pagenumber)

else:
# refactor as a Supabase object and append
context_dict = {
'text': doc.page_content,
'embedding': '',
'pagenumber': doc.metadata['pagenumber'],
'readable_filename': doc.metadata['readable_filename'],
'course_name': course_name,
's3_path': doc.metadata['s3_path'],
'base_url':doc.metadata['base_url']
}
if 'url' in doc.metadata.keys():
context_dict['url'] = doc.metadata['url']
result_docs.append(context_dict)

return result_docs

# def qdrant_context_processing(doc, course_name, result_contexts):
# """
# Re-factor QDRANT objects into Supabase objects and append to result_docs
# """
# context_dict = {
# 'text': doc.page_content,
# 'embedding': '',
# 'pagenumber': doc.metadata['pagenumber'],
# 'readable_filename': doc.metadata['readable_filename'],
# 'course_name': course_name,
# 's3_path': doc.metadata['s3_path'],
# 'base_url': doc.metadata['base_url']
# }
# if 'url' in doc.metadata.keys():
# context_dict['url'] = doc.metadata['url']

# result_contexts.append(context_dict)
# return result_contexts

# def context_padding(self, found_docs, search_query, course_name):
# """
# Takes top N contexts acquired from QRANT similarity search and pads them
# """
# print("inside main context padding")
# context_ids = {}
# # for doc in found_docs[:5]:
# # self.context_data_processing(doc, course_name, context_ids, result_contexts)

# with Manager() as manager:
# result_contexts = manager.list()
# with ProcessPoolExecutor() as executor:
# partial_func = partial(qdrant_context_processing, course_name=course_name, result_contexts=result_contexts)
# results = executor.map(partial_func, found_docs[5:])

# results = list(results)
# print("RESULTS: ", results)

# return results


def reciprocal_rank_fusion(self, results: list[list], k=60):
"""
Since we have multiple queries, and n documents returned per query, we need to go through all the results
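The body of reciprocal_rank_fusion is collapsed in this view. For reference, the standard RRF formulation scores each document as the sum of 1 / (k + rank) over the per-query ranked lists; a generic sketch (assumed, and possibly differing in detail from the repo's code):

from collections import defaultdict

def rrf_sketch(results, k=60):
    fused_scores = defaultdict(float)
    for ranked_docs in results:          # one ranked list per query
        for rank, doc in enumerate(ranked_docs):
            fused_scores[doc] += 1.0 / (k + rank)
    # highest fused score first
    return sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)

print(rrf_sketch([["a", "b", "c"], ["b", "a"]]))
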
@@ -1376,6 +1152,10 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit

print(f"⏰ Multi-query processing runtime: {(time.monotonic() - mq_start_time):.2f} seconds")

# Context filtering
#filtered_docs = list(run(contexts=found_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100))
#print(f"Number of docs after context filtering: {len(filtered_docs)}")

# 'context padding' // 'parent document retriever'
final_docs = context_processing(found_docs, search_query, course_name)
print(f"Number of final docs after context padding: {len(final_docs)}")
@@ -1384,10 +1164,13 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit
# count tokens at start and end, then also count each context.
token_counter, _ = count_tokens_and_cost(pre_prompt + '\n\nNow please respond to my query: ' + search_query) # type: ignore

filtered_docs = list(run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=30))
print(f"Number of docs after context filtering: {len(filtered_docs)}")

valid_docs = []
num_tokens = 0

for doc in final_docs:
for doc in filtered_docs:

doc_string = f"Document: {doc['readable_filename']}{', page: ' + str(doc['pagenumber']) if doc['pagenumber'] else ''}\n{str(doc['text'])}\n"
num_tokens, prompt_cost = count_tokens_and_cost(doc_string) # type: ignore
@@ -1401,19 +1184,16 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit
break

print("Length of valid docs: ", len(valid_docs))

# insert filtering here and only pass relevant contexts ahead
final_filtered_docs = list(run(contexts=valid_docs, user_query=search_query, max_time_before_return=45, max_concurrency=20))

print("Length of final filtered docs: ", len(final_filtered_docs))
print("FINAL FILTERED DOCS: ", final_filtered_docs)
# Context filtering
#filtered_docs = list(run(contexts=valid_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100))
#print(f"Number of docs after context filtering: {len(filtered_docs)}")

print(f"Total tokens used: {token_counter} total docs: {len(found_docs)} num docs used: {len(final_filtered_docs)}")
print(f"Total tokens used: {token_counter} total docs: {len(found_docs)} num docs used: {len(valid_docs)}")
print(f"Course: {course_name} ||| search_query: {search_query}")
print(f"⏰ ^^ Runtime of getTopContextsWithMQR: {(time.monotonic() - start_time_overall):.2f} seconds")
if len(final_filtered_docs) == 0:
if len(valid_docs) == 0:
return []
return self.format_for_json_mqr(final_filtered_docs)
return self.format_for_json_mqr(valid_docs)
except Exception as e:
# return full traceback to front end
err: str = f"ERROR: In /getTopContextsWithMQR. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.format_exc()}❌❌ Error in {inspect.currentframe().f_code.co_name}:\n{e}" # type: ignore
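Finally, the token-budgeting loop in getTopContextsWithMQR accumulates documents until the prompt budget is exhausted, using the repo's count_tokens_and_cost helper (not shown here). A rough stand-in using tiktoken directly, purely for illustration:

import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")

def trim_docs_to_budget(docs, token_limit):
    valid_docs, used_tokens = [], 0
    for doc in docs:
        doc_string = f"Document: {doc['readable_filename']}\n{doc['text']}\n"
        doc_tokens = len(encoding.encode(doc_string))
        if used_tokens + doc_tokens > token_limit:
            break
        valid_docs.append(doc)
        used_tokens += doc_tokens
    return valid_docs

docs = [{"readable_filename": "demo.pdf", "text": "hello world " * 50}]
print(len(trim_docs_to_budget(docs, token_limit=500)))
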
