Skip to content

Commit

Permalink
added context padding in getTopContexts()
Browse files Browse the repository at this point in the history
  • Loading branch information
star-nox committed Mar 13, 2024
1 parent 46d5a5e commit 8dba0b1
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 22 deletions.
6 changes: 3 additions & 3 deletions ai_ta_backend/database/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class SQLDatabase:

@inject
def __init__(self, db_url: str):
def __init__(self):
# Create a Supabase client
self.supabase_client = supabase.create_client( # type: ignore
supabase_url=os.environ['SUPABASE_URL'], supabase_key=os.environ['SUPABASE_API_KEY'])
Expand All @@ -18,11 +18,11 @@ def getAllMaterialsForCourse(self, course_name: str):
'course_name', course_name).execute()

def getMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str):
return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq(
return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, readable_filename, base_url, url, contexts").eq(
's3_path', s3_path).eq('course_name', course_name).execute()

def getMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str):
return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq(
return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, readable_filename, base_url, url, contexts").eq(
key, value).eq('course_name', course_name).execute()

def deleteMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str):
Expand Down
7 changes: 4 additions & 3 deletions ai_ta_backend/service/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int =
found_docs: list[Document] = self.vector_search(search_query=search_query, course_name=course_name)

# add parent doc retrieval here
print(f"Number of docs retrieved: {len(found_docs)}")
parent_docs = context_parent_doc_padding(found_docs, search_query, course_name)
print(f"Number of final docs after context padding: {len(parent_docs)}")

Expand All @@ -82,11 +83,11 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int =
valid_docs = []
num_tokens = 0
for doc in parent_docs:
doc_string = f"Document: {doc.metadata['readable_filename']}{', page: ' + str(doc.metadata['pagenumber']) if doc.metadata['pagenumber'] else ''}\n{str(doc.page_content)}\n"
doc_string = f"Document: {doc['readable_filename']}{', page: ' + str(doc['pagenumber']) if doc['pagenumber'] else ''}\n{str(doc['text'])}\n"
num_tokens, prompt_cost = count_tokens_and_cost(doc_string) # type: ignore

print(
f"tokens used/limit: {token_counter}/{token_limit}, tokens in chunk: {num_tokens}, total prompt cost (of these contexts): {prompt_cost}. 📄 File: {doc.metadata['readable_filename']}"
f"tokens used/limit: {token_counter}/{token_limit}, tokens in chunk: {num_tokens}, total prompt cost (of these contexts): {prompt_cost}. 📄 File: {doc['readable_filename']}"
)
if token_counter + num_tokens <= token_limit:
token_counter += num_tokens
Expand Down Expand Up @@ -114,7 +115,7 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int =
},
)

return self.format_for_json(valid_docs)
return self.format_for_json_mqr(valid_docs)
except Exception as e:
# return full traceback to front end
# err: str = f"ERROR: In /getTopContexts. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:\n{e}" # type: ignore
Expand Down
38 changes: 22 additions & 16 deletions ai_ta_backend/utils/context_parent_doc_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
# SUPABASE_CLIENT = supabase.create_client(supabase_url=os.environ['SUPABASE_URL'],
# supabase_key=os.environ['SUPABASE_API_KEY']) # type: ignore

SQL_DB = SQLDatabase
SQL_DB = SQLDatabase()

def context_parent_doc_padding(found_docs, search_query, course_name):
"""
Takes top N contexts acquired from Qdrant similarity search and pads them
"""
print("inside main context padding")
start_time = time.monotonic()

with Manager() as manager:
qdrant_contexts = manager.list()
supabase_contexts = manager.list()
Expand All @@ -44,50 +44,54 @@ def context_parent_doc_padding(found_docs, search_query, course_name):
def qdrant_context_processing(doc, course_name, result_contexts):
    """
    Convert a Qdrant search hit into a Supabase-style context dict.

    The dict is appended to ``result_contexts`` in place so that a shared
    list can collect results from several calls.

    Args:
        doc: Object exposing ``page_content`` and a ``metadata`` mapping.
            ``metadata`` must contain ``pagenumber``, ``readable_filename``
            and ``s3_path``; ``url`` and ``base_url`` are optional.
        course_name: Course the document belongs to; copied into the dict.
        result_contexts: List-like collector that receives the new dict.

    Returns:
        The same ``result_contexts`` object that was passed in.

    Raises:
        KeyError: If a required metadata key is missing.
    """
    metadata = doc.metadata
    context_dict = {
        'text': doc.page_content,
        'embedding': '',
        'pagenumber': metadata['pagenumber'],
        'readable_filename': metadata['readable_filename'],
        'course_name': course_name,
        's3_path': metadata['s3_path'],
        # Optional keys fall back to an empty string when absent.
        'url': metadata.get('url', ''),
        # base_url is read from its own key; reading it from 'url' would
        # store the page URL and fail when only 'base_url' is present.
        'base_url': metadata.get('base_url', ''),
    }

    result_contexts.append(context_dict)
    return result_contexts


def supabase_context_padding(doc, course_name, result_docs):
"""
Does context padding for given doc.
"""

# query by url or s3_path
if 'url' in doc.metadata.keys() and doc.metadata['url']:
parent_doc_id = doc.metadata['url']
response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name',
course_name).eq('url', parent_doc_id).execute()

# response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name',
# course_name).eq('url', parent_doc_id).execute()
response = SQL_DB.getMaterialsForCourseAndKeyAndValue(course_name=course_name, key='url', value=parent_doc_id)
else:
parent_doc_id = doc.metadata['s3_path']
response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name',
course_name).eq('s3_path',
parent_doc_id).execute()

# response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name',
# course_name).eq('s3_path',
# parent_doc_id).execute()
response = SQL_DB.getMaterialsForCourseAndS3Path(course_name=course_name, s3_path=parent_doc_id)
data = response.data

if len(data) > 0:
# do the padding
filename = data[0]['readable_filename']
contexts = data[0]['contexts']
#print("no of contexts within the og doc: ", len(contexts))

if 'chunk_index' in doc.metadata and 'chunk_index' in contexts[0].keys():
#print("inside chunk index")
# pad contexts by chunk index + 3 and - 3
Expand Down Expand Up @@ -135,3 +139,5 @@ def supabase_context_padding(doc, course_name, result_docs):
context_dict['url'] = ''

result_docs.append(context_dict)


0 comments on commit 8dba0b1

Please sign in to comment.