Skip to content

Commit

Permalink
modified function to only pad first 5 docs
Browse files Browse the repository at this point in the history
  • Loading branch information
star-nox committed Nov 8, 2023
1 parent 9eaf370 commit 8788533
Showing 1 changed file with 95 additions and 81 deletions.
176 changes: 95 additions & 81 deletions ai_ta_backend/vector_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from langchain.document_loaders.image import UnstructuredImageLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.load import dumps, loads
from langchain.load import loads, dumps
from langchain.schema import Document
from langchain.schema.output_parser import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
def context_padding(self, found_docs, search_query, course_name):
    """Pad top vector-search hits with neighboring context from their parent document.

    For the first 5 docs returned by Qdrant, the parent document is fetched from
    Supabase and surrounding contexts are added: by chunk index (target index
    +/- 3) when `chunk_index` is present, otherwise by whole page when
    `pagenumber` is non-empty, otherwise the raw Qdrant hit is re-shaped to the
    Supabase context schema. Docs past the first 5 are appended as-is (re-shaped,
    no padding).

    Args:
        found_docs: ranked documents from Qdrant; each has `.page_content` and a
            `.metadata` dict (keys seen here: `url`, `s3_path`, `chunk_index`,
            `pagenumber`, `readable_filename`, `base_url`).
        search_query: the user query. Currently unused in this method; kept for
            interface compatibility with callers.
        course_name: course identifier used to scope the Supabase lookup.

    Returns:
        list[dict]: context dicts shaped like Supabase context rows.
    """
    documents_table = os.environ['NEW_NEW_NEWNEW_MATERIALS_SUPABASE_TABLE']
    # parent_doc_id -> identifiers (chunk indexes / page numbers) already added.
    # NOTE: initialized with setdefault below so that two hits sharing the same
    # parent document do not reset each other's dedup list (the previous
    # unconditional `= []` reset defeated duplicate avoidance).
    retrieved_contexts_identifiers = {}
    result_contexts = []

    # Only pad the first 5 contexts; append the rest as-is.
    for i, doc in enumerate(found_docs):  # top N from QDRANT
        if i >= 5:
            # Past the padding window: just re-shape the Qdrant hit.
            print("reached > 5 docs, just appending the QDRANT context")
            result_contexts.append(self._qdrant_context_dict(doc, course_name))
            continue

        # Identify the parent document by url when present, else by s3_path.
        if 'url' in doc.metadata.keys() and doc.metadata['url']:
            id_column = 'url'
        else:
            id_column = 's3_path'
        parent_doc_id = doc.metadata[id_column]
        response = self.supabase_client.table(documents_table).select('*').eq('course_name', course_name).eq(id_column, parent_doc_id).execute()
        retrieved_contexts_identifiers.setdefault(parent_doc_id, [])

        data = response.data  # at this point, we have the origin parent document from Supabase
        if len(data) > 0:
            filename = data[0]['readable_filename']
            contexts = data[0]['contexts']
            print("no of contexts within the og doc: ", len(contexts))

            if 'chunk_index' in doc.metadata:
                # Pad by chunk index: collect neighbors within +/- 3 of the target.
                target_chunk_index = doc.metadata['chunk_index']
                print("target chunk_index: ", target_chunk_index)
                print("len of result contexts before chunk_index padding: ", len(result_contexts))
                for context in contexts:
                    curr_chunk_index = context['chunk_index']
                    if (target_chunk_index - 3 <= curr_chunk_index <= target_chunk_index + 3) and curr_chunk_index not in retrieved_contexts_identifiers[parent_doc_id]:
                        result_contexts.append(self._stamp_parent_metadata(context, data[0], filename, course_name))
                        # Record the index immediately so later hits on the same
                        # parent document do not re-add this chunk.
                        retrieved_contexts_identifiers[parent_doc_id].append(curr_chunk_index)
                print("len of result contexts after chunk_index padding: ", len(result_contexts))

            elif doc.metadata['pagenumber'] != '':
                # Pad by page number: pull every context on the same page.
                pagenumber = doc.metadata['pagenumber']
                print("target pagenumber: ", pagenumber)
                print("len of result contexts before pagenumber padding: ", len(result_contexts))
                for context in contexts:
                    if int(context['pagenumber']) == pagenumber:
                        result_contexts.append(self._stamp_parent_metadata(context, data[0], filename, course_name))
                print("len of result contexts after pagenumber padding: ", len(result_contexts))
                # Record the page number only after all of its contexts are added.
                retrieved_contexts_identifiers[parent_doc_id].append(pagenumber)

            else:
                # No chunk index or page number: re-shape the Qdrant hit to look
                # like a Supabase context row.
                print("no chunk index or page number found, just appending the QDRANT context")
                print("len of result contexts before qdrant append: ", len(result_contexts))
                result_contexts.append(self._qdrant_context_dict(doc, course_name))
                print("len of result contexts after qdrant append: ", len(result_contexts))

    print("length of final contexts: ", len(result_contexts))
    return result_contexts

def _stamp_parent_metadata(self, context, parent_row, filename, course_name):
    """Stamp parent-document metadata onto a Supabase context row.

    Mutates and returns `context` (matches the original in-place behavior).
    """
    context['readable_filename'] = filename
    context['course_name'] = course_name
    context['s3_path'] = parent_row['s3_path']
    context['url'] = parent_row['url']
    context['base_url'] = parent_row['base_url']
    return context

def _qdrant_context_dict(self, doc, course_name):
    """Re-shape a raw Qdrant document into the Supabase context-row schema."""
    context_dict = {
        'text': doc.page_content,
        'embedding': '',
        'pagenumber': doc.metadata['pagenumber'],
        'readable_filename': doc.metadata['readable_filename'],
        'course_name': course_name,
        's3_path': doc.metadata['s3_path'],
        'base_url': doc.metadata['base_url'],
    }
    # url is optional in Qdrant metadata; include it only when present.
    if 'url' in doc.metadata.keys():
        context_dict['url'] = doc.metadata['url']
    return context_dict

def reciprocal_rank_fusion(self, results: list[list], k=60):
Expand Down

0 comments on commit 8788533

Please sign in to comment.