Adding parent document retrieval in default RAG pipeline #233

Open · wants to merge 14 commits into main
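Parent document retrieval: after the vector store returns small, highly-ranked child chunks, the pipeline goes back to each chunk's parent document and pulls in the neighbouring chunks around the hit, so the model receives coherent surrounding context instead of isolated snippets. A minimal sketch of the idea, assuming hypothetical vector_store and doc_store objects (similarity_search, get_chunks, and the metadata fields are placeholder names, not this repository's API):

def retrieve_with_parent_padding(query: str, vector_store, doc_store, window: int = 3):
    # 1. Standard similarity search over small child chunks.
    child_chunks = vector_store.similarity_search(query, k=10)

    padded = []
    for chunk in child_chunks:
        parent_id = chunk.metadata["parent_doc_id"]   # hypothetical metadata fields
        chunk_index = chunk.metadata["chunk_index"]
        # 2. Fetch the neighbouring chunks (chunk_index +/- window) from the
        #    parent document so the hit is returned with surrounding context.
        padded.extend(doc_store.get_chunks(parent_id,
                                           start=chunk_index - window,
                                           end=chunk_index + window))
    return padded

In this PR the same idea is implemented against Qdrant (child chunks) and Supabase (parent documents), with a +/-3 chunk window.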
4 changes: 2 additions & 2 deletions ai_ta_backend/database/sql.py
@@ -18,11 +18,11 @@ def getAllMaterialsForCourse(self, course_name: str):
'course_name', course_name).execute()

def getMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str):
- return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq(
+ return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, readable_filename, base_url, url, contexts").eq(
's3_path', s3_path).eq('course_name', course_name).execute()

def getMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str):
- return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq(
+ return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, readable_filename, base_url, url, contexts").eq(
key, value).eq('course_name', course_name).execute()

def deleteMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str):
7 changes: 7 additions & 0 deletions ai_ta_backend/executors/process_pool_executor.py
@@ -29,3 +29,10 @@ def submit(self, fn, *args, **kwargs):

def map(self, fn, *iterables, timeout=None, chunksize=1):
return self.executor.map(fn, *iterables, timeout=timeout, chunksize=chunksize)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self.executor.shutdown(wait=True)
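The new __enter__/__exit__ methods make the adapter usable as a context manager. A minimal usage sketch, assuming the adapter's constructor creates its own process pool (the constructor is not shown in this diff); because __exit__ calls shutdown(wait=True), the wrapped pool is torn down when the block exits and the same adapter instance should not be reused afterwards:

from ai_ta_backend.executors.process_pool_executor import ProcessPoolExecutorAdapter

def square(x: int) -> int:
    return x * x

if __name__ == "__main__":
    # Assumption: the default constructor builds its own ProcessPoolExecutor.
    adapter = ProcessPoolExecutorAdapter()
    with adapter as executor:
        # map delegates to ProcessPoolExecutor.map; the callable must be picklable (module-level).
        results = list(executor.map(square, range(8)))
    print(results)
    # __exit__ has already called shutdown(wait=True) here, so this adapter cannot accept more work.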

1 change: 1 addition & 0 deletions ai_ta_backend/service/export_service.py
@@ -34,6 +34,7 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''):
"""

response = self.sql.getDocumentsBetweenDates(course_name, from_date, to_date, 'documents')
print("response count: ", response.count)
# add a condition to route to direct download or s3 download
if response.count > 500:
# call background task to upload to s3
255 changes: 249 additions & 6 deletions ai_ta_backend/service/retrieval_service.py
@@ -3,6 +3,10 @@
import time
import traceback
from typing import Dict, List, Union
from functools import partial
from multiprocessing import Manager
from multiprocessing import Lock


import openai
from injector import inject
@@ -17,7 +21,10 @@
from ai_ta_backend.service.posthog_service import PosthogService
from ai_ta_backend.service.sentry_service import SentryService
from ai_ta_backend.utils.utils_tokenization import count_tokens_and_cost

#from ai_ta_backend.utils.context_parent_doc_padding import context_parent_doc_padding
from ai_ta_backend.executors.process_pool_executor import ProcessPoolExecutorAdapter
from functools import partial
from multiprocessing import Manager

class RetrievalService:
"""
@@ -26,13 +33,14 @@ class RetrievalService:

@inject
def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, posthog: PosthogService,
- sentry: SentryService, nomicService: NomicService):
+ sentry: SentryService, nomicService: NomicService, executor: ProcessPoolExecutorAdapter):
self.vdb = vdb
self.sqlDb = sqlDb
self.aws = aws
self.sentry = sentry
self.posthog = posthog
self.nomicService = nomicService
self.executor = executor

openai.api_key = os.environ["OPENAI_API_KEY"]

@@ -77,19 +85,24 @@ def getTopContexts(self,
course_name=course_name,
doc_groups=doc_groups)

# add parent doc retrieval here
print(f"Number of docs retrieved: {len(found_docs)}")
parent_docs = self.context_parent_doc_padding(found_docs, course_name)
print(f"Number of final docs after context padding: {len(parent_docs)}")

pre_prompt = "Please answer the following question. Use the context below, called your documents, only if it's helpful and don't use parts that are very irrelevant. It's good to quote from your documents directly, when you do always use Markdown footnotes for citations. Use react-markdown superscript to number the sources at the end of sentences (1, 2, 3...) and use react-markdown Footnotes to list the full document names for each number. Use ReactMarkdown aka 'react-markdown' formatting for super script citations, use semi-formal style. Feel free to say you don't know. \nHere's a few passages of the high quality documents:\n"
# count tokens at start and end, then also count each context.
token_counter, _ = count_tokens_and_cost(pre_prompt + "\n\nNow please respond to my query: " + # type: ignore
search_query)

valid_docs = []
num_tokens = 0
- for doc in found_docs:
- doc_string = f"Document: {doc.metadata['readable_filename']}{', page: ' + str(doc.metadata['pagenumber']) if doc.metadata['pagenumber'] else ''}\n{str(doc.page_content)}\n"
+ for doc in parent_docs:
+ doc_string = f"Document: {doc['readable_filename']}{', page: ' + str(doc['pagenumber']) if doc['pagenumber'] else ''}\n{str(doc['text'])}\n"
num_tokens, prompt_cost = count_tokens_and_cost(doc_string) # type: ignore

print(
f"tokens used/limit: {token_counter}/{token_limit}, tokens in chunk: {num_tokens}, total prompt cost (of these contexts): {prompt_cost}. 📄 File: {doc.metadata['readable_filename']}"
f"tokens used/limit: {token_counter}/{token_limit}, tokens in chunk: {num_tokens}, total prompt cost (of these contexts): {prompt_cost}. 📄 File: {doc['readable_filename']}"
)
if token_counter + num_tokens <= token_limit:
token_counter += num_tokens
@@ -117,7 +130,7 @@
},
)

- return self.format_for_json(valid_docs)
+ return self.format_for_json_mqr(valid_docs)
except Exception as e:
# return full traceback to front end
# err: str = f"ERROR: In /getTopContexts. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:\n{e}" # type: ignore
@@ -427,3 +440,233 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]:
]

return contexts

def context_parent_doc_padding(self, found_docs, course_name):
"""
Takes the top N contexts returned by Qdrant similarity search and pads them with neighbouring chunks from their parent documents.
"""
print("inside main context padding")
start_time = time.monotonic()

with Manager() as manager:
qdrant_contexts = manager.list()
supabase_contexts = manager.list()
partial_func1 = partial(qdrant_context_processing, course_name=course_name, result_contexts=qdrant_contexts)
partial_func2 = partial(supabase_context_padding, course_name=course_name, result_docs=supabase_contexts)

with self.executor as executor:
executor.map(partial_func1, found_docs[5:])
executor.map(partial_func2, found_docs[:5])

# with self.executor as executor:
# executor.map(lambda doc: self.qdrant_context_processing(doc, course_name=course_name, result_contexts=[]), found_docs[5:])
# executor.map(lambda doc: self.supabase_context_padding(doc, course_name=course_name, result_docs=[]), found_docs[:5])


supabase_contexts_no_duplicates = []
for context in supabase_contexts:
if context not in supabase_contexts_no_duplicates:
supabase_contexts_no_duplicates.append(context)

result_contexts = supabase_contexts_no_duplicates + list(qdrant_contexts)
#print("len of supabase contexts: ", len(supabase_contexts_no_duplicates))

print(f"⏰ Context padding runtime: {(time.monotonic() - start_time):.2f} seconds")

return result_contexts

# def qdrant_context_processing(self, doc, course_name, result_contexts):
# """
# Re-factor QDRANT objects into Supabase objects and append to result_docs
# """
# print("inside qdrant context processing")
# context_dict = {
# 'text': doc.page_content,
# 'embedding': '',
# 'pagenumber': doc.metadata['pagenumber'],
# 'readable_filename': doc.metadata['readable_filename'],
# 'course_name': course_name,
# 's3_path': doc.metadata['s3_path']
# }

# if 'url' in doc.metadata.keys():
# context_dict['url'] = doc.metadata['url']
# else:
# context_dict['url'] = ''

# if 'base_url' in doc.metadata.keys():
# context_dict['base_url'] = doc.metadata['url']
# else:
# context_dict['base_url'] = ''

# result_contexts.append(context_dict)

# def supabase_context_padding(self, doc, course_name, result_docs):
# """
# Does context padding for given doc.
# """
# print("inside supabase context padding")
# SQL_DB = SQLDatabase()

# # query by url or s3_path
# if 'url' in doc.metadata.keys() and doc.metadata['url']:
# parent_doc_id = doc.metadata['url']
# response = SQL_DB.getMaterialsForCourseAndKeyAndValue(course_name=course_name, key='url', value=parent_doc_id)
# else:
# parent_doc_id = doc.metadata['s3_path']
# response = SQL_DB.getMaterialsForCourseAndS3Path(course_name=course_name, s3_path=parent_doc_id)

# data = response.data

# if len(data) > 0:
# # do the padding
# filename = data[0]['readable_filename']
# contexts = data[0]['contexts']
# #print("no of contexts within the og doc: ", len(contexts))

# if 'chunk_index' in doc.metadata and 'chunk_index' in contexts[0].keys():
# #print("inside chunk index")
# # pad contexts by chunk index + 3 and - 3
# target_chunk_index = doc.metadata['chunk_index']
# for context in contexts:
# curr_chunk_index = context['chunk_index']
# if (target_chunk_index - 3 <= curr_chunk_index <= target_chunk_index + 3):
# context['readable_filename'] = filename
# context['course_name'] = course_name
# context['s3_path'] = data[0]['s3_path']
# context['url'] = data[0]['url']
# context['base_url'] = data[0]['base_url']
# result_docs.append(context)

# elif doc.metadata['pagenumber'] != '':
# #print("inside page number")
# # pad contexts belonging to same page number
# pagenumber = doc.metadata['pagenumber']

# for context in contexts:
# # pad contexts belonging to same page number
# if int(context['pagenumber']) == pagenumber:
# context['readable_filename'] = filename
# context['course_name'] = course_name
# context['s3_path'] = data[0]['s3_path']
# context['url'] = data[0]['url']
# context['base_url'] = data[0]['base_url']
# result_docs.append(context)

# else:
# #print("inside else")
# # refactor as a Supabase object and append
# context_dict = {
# 'text': doc.page_content,
# 'embedding': '',
# 'pagenumber': doc.metadata['pagenumber'],
# 'readable_filename': doc.metadata['readable_filename'],
# 'course_name': course_name,
# 's3_path': doc.metadata['s3_path'],
# 'base_url': doc.metadata['base_url']
# }
# if 'url' in doc.metadata.keys():
# context_dict['url'] = doc.metadata['url']
# else:
# context_dict['url'] = ''

# result_docs.append(context_dict)

def qdrant_context_processing(doc, course_name, result_contexts):
"""
Reshape Qdrant document objects into Supabase-style context dicts and append them to result_contexts.
"""
#print("inside qdrant context processing")
context_dict = {
'text': doc.page_content,
'embedding': '',
'pagenumber': doc.metadata['pagenumber'],
'readable_filename': doc.metadata['readable_filename'],
'course_name': course_name,
's3_path': doc.metadata['s3_path']
}

if 'url' in doc.metadata.keys():
context_dict['url'] = doc.metadata['url']
else:
context_dict['url'] = ''

if 'base_url' in doc.metadata.keys():
context_dict['base_url'] = doc.metadata['base_url']
else:
context_dict['base_url'] = ''

result_contexts.append(context_dict)

def supabase_context_padding(doc, course_name, result_docs):
"""
Pads the given doc with neighbouring contexts from its parent document in Supabase, by chunk index or page number.
"""
#print("inside supabase context padding")
SQL_DB = SQLDatabase()

# query by url or s3_path
if 'url' in doc.metadata.keys() and doc.metadata['url']:
parent_doc_id = doc.metadata['url']
response = SQL_DB.getMaterialsForCourseAndKeyAndValue(course_name=course_name, key='url', value=parent_doc_id)
else:
parent_doc_id = doc.metadata['s3_path']
response = SQL_DB.getMaterialsForCourseAndS3Path(course_name=course_name, s3_path=parent_doc_id)

data = response.data

if len(data) > 0:
# do the padding
filename = data[0]['readable_filename']
contexts = data[0]['contexts']
#print("no of contexts within the og doc: ", len(contexts))

if 'chunk_index' in doc.metadata and 'chunk_index' in contexts[0].keys():
#print("inside chunk index")
# pad contexts by chunk index + 3 and - 3
target_chunk_index = doc.metadata['chunk_index']
for context in contexts:
curr_chunk_index = context['chunk_index']
if (target_chunk_index - 3 <= curr_chunk_index <= target_chunk_index + 3):
context['readable_filename'] = filename
context['course_name'] = course_name
context['s3_path'] = data[0]['s3_path']
context['url'] = data[0]['url']
context['base_url'] = data[0]['base_url']
result_docs.append(context)

elif doc.metadata['pagenumber'] != '':
#print("inside page number")
# pad contexts belonging to same page number
pagenumber = doc.metadata['pagenumber']

for context in contexts:
# pad contexts belonging to same page number
if int(context['pagenumber']) == pagenumber:
context['readable_filename'] = filename
context['course_name'] = course_name
context['s3_path'] = data[0]['s3_path']
context['url'] = data[0]['url']
context['base_url'] = data[0]['base_url']
result_docs.append(context)

else:
#print("inside else")
# refactor as a Supabase object and append
context_dict = {
'text': doc.page_content,
'embedding': '',
'pagenumber': doc.metadata['pagenumber'],
'readable_filename': doc.metadata['readable_filename'],
'course_name': course_name,
's3_path': doc.metadata['s3_path'],
'base_url': doc.metadata['base_url']
}
if 'url' in doc.metadata.keys():
context_dict['url'] = doc.metadata['url']
else:
context_dict['url'] = ''

result_docs.append(context_dict)
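To make the padding rule above concrete, a small self-contained illustration of the +/-3 chunk-index window (plain Python, no project imports):

contexts = [{"chunk_index": i, "text": f"chunk {i}"} for i in range(20)]
target_chunk_index = 7  # chunk_index of the chunk returned by the similarity search

# keep every context whose chunk_index falls within +/-3 of the target
padded = [c for c in contexts
          if target_chunk_index - 3 <= c["chunk_index"] <= target_chunk_index + 3]

print([c["chunk_index"] for c in padded])  # -> [4, 5, 6, 7, 8, 9, 10]

The page-number branch works the same way, except the filter keeps every context whose pagenumber equals the hit's page instead of using a chunk-index window.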

