Skip to content

Commit

Permalink
Merge branch 'info-retrieval' into multi-query-retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
star-nox authored Nov 8, 2023
2 parents 88004f3 + 6781977 commit e080d0a
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions ai_ta_backend/vector_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,8 +947,7 @@ def getAll(
return distinct_dicts

def vector_search(self, search_query, course_name):
#top_n = 80
top_n = 5
top_n = 80
o = OpenAIEmbeddings() # type: ignore
user_query_embedding = o.embed_query(search_query)
myfilter = models.Filter(
Expand Down Expand Up @@ -1051,7 +1050,7 @@ def context_padding(self, found_docs, search_query, course_name):
retrieved_contexts_identifiers = {}
result_contexts = []
for doc in found_docs: # top N from QDRANT

# if url present, query through that
if 'url' in doc.metadata.keys() and doc.metadata['url']:
parent_doc_id = doc.metadata['url']
Expand All @@ -1076,6 +1075,7 @@ def context_padding(self, found_docs, search_query, course_name):
target_chunk_index = doc.metadata['chunk_index']
print("target chunk_index: ", target_chunk_index)
print("len of result contexts before chunk_index padding: ", len(result_contexts))

for context in contexts:
curr_chunk_index = context['chunk_index']
# collect between range of target index - 3 and target index + 3
Expand All @@ -1089,12 +1089,14 @@ def context_padding(self, found_docs, search_query, course_name):
result_contexts.append(context)
# add current index to retrieved_contexts_identifiers after each context is retrieved to avoid duplicates
retrieved_contexts_identifiers[parent_doc_id].append(curr_chunk_index)

print("len of result contexts after chunk_index padding: ", len(result_contexts))

elif doc.metadata['pagenumber'] != '':
# retrieve by page number --> retrieve the single whole page?
pagenumber = doc.metadata['pagenumber']
print("target pagenumber: ", pagenumber)

print("len of result contexts: ", len(result_contexts))
for context in contexts:
if int(context['pagenumber']) == pagenumber:
Expand All @@ -1104,12 +1106,15 @@ def context_padding(self, found_docs, search_query, course_name):
context['url'] = data[0]['url']
context['base_url'] = data[0]['base_url']
result_contexts.append(context)

print("len of result contexts after pagenumber padding: ", len(result_contexts))

# add page number to retrieved_contexts_identifiers after all contexts belonging to that page number have been retrieved
retrieved_contexts_identifiers[parent_doc_id].append(pagenumber)
else:
# dont pad, re-factor it to be like Supabase object
print("no chunk index or page number, just appending the QDRANT context")

print("len of result contexts in else condn: ", len(result_contexts))
context_dict = {'text': doc.page_content,
'embedding': '',
Expand All @@ -1124,7 +1129,7 @@ def context_padding(self, found_docs, search_query, course_name):

result_contexts.append(context_dict)
print("len of result contexts after qdrant append: ", len(result_contexts))

print("length of final contexts: ", len(result_contexts))

return result_contexts
Expand Down Expand Up @@ -1369,9 +1374,9 @@ def get_stuffed_prompt(self, search_query: str, course_name: str, token_limit: i
token_counter, _ = count_tokens_and_cost(pre_prompt + '\n\nNow please respond to my query: ' + search_query) # type: ignore
valid_docs = []
for d in found_docs:
if "pagenumber" not in d.payload["metadata"].keys(): # type: ignore
d.payload["metadata"]["pagenumber"] = d.payload["metadata"]["pagenumber_or_timestamp"] # type: ignore
doc_string = f"---\nDocument: {d.payload['metadata']['readable_filename']}{', page: ' + str(d.payload['metadata']['pagenumber']) if d.payload['metadata']['pagenumber'] else ''}\n{d.payload.get('page_content')}\n" # type: ignore
if "pagenumber" not in d.payload.keys(): # type: ignore
d.payload["pagenumber"] = d.payload["pagenumber_or_timestamp"] # type: ignore
doc_string = f"---\nDocument: {d.payload['readable_filename']}{', page: ' + str(d.payload['pagenumber']) if d.payload['pagenumber'] else ''}\n{d.payload.get('page_content')}\n" # type: ignore
num_tokens, prompt_cost = count_tokens_and_cost(doc_string) # type: ignore

print(f"Page: {d.payload.get('page_content')[:100]}...") # type: ignore
Expand Down

0 comments on commit e080d0a

Please sign in to comment.