From 8dba0b1a3d6eec5b2f1453f8cc0510586f4be961 Mon Sep 17 00:00:00 2001
From: star-nox
Date: Wed, 13 Mar 2024 17:43:47 -0500
Subject: [PATCH] added context padding in getTopContexts()

---
 ai_ta_backend/database/sql.py              |  6 +--
 ai_ta_backend/service/retrieval_service.py |  7 ++--
 .../utils/context_parent_doc_padding.py    | 38 +++++++++++--------
 3 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/sql.py
index a9819657..b0ffb455 100644
--- a/ai_ta_backend/database/sql.py
+++ b/ai_ta_backend/database/sql.py
@@ -7,7 +7,7 @@ class SQLDatabase:
 
   @inject
-  def __init__(self, db_url: str):
+  def __init__(self):
     # Create a Supabase client
     self.supabase_client = supabase.create_client(  # type: ignore
         supabase_url=os.environ['SUPABASE_URL'], supabase_key=os.environ['SUPABASE_API_KEY'])
@@ -18,11 +18,11 @@ def getAllMaterialsForCourse(self, course_name: str):
         'course_name', course_name).execute()
 
   def getMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str):
-    return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq(
+    return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, readable_filename, base_url, url, contexts").eq(
         's3_path', s3_path).eq('course_name', course_name).execute()
 
   def getMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str):
-    return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq(
+    return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, readable_filename, base_url, url, contexts").eq(
         key, value).eq('course_name', course_name).execute()
 
   def deleteMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str):
diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py
index 266f4a01..95d8d6aa 100644
--- a/ai_ta_backend/service/retrieval_service.py
+++ b/ai_ta_backend/service/retrieval_service.py
@@ -71,6 +71,7 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int =
       found_docs: list[Document] = self.vector_search(search_query=search_query, course_name=course_name)
 
       # add parent doc retrieval here
+      print(f"Number of docs retrieved: {len(found_docs)}")
       parent_docs = context_parent_doc_padding(found_docs, search_query, course_name)
       print(f"Number of final docs after context padding: {len(parent_docs)}")
 
@@ -82,11 +83,11 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int =
       valid_docs = []
       num_tokens = 0
       for doc in parent_docs:
-        doc_string = f"Document: {doc.metadata['readable_filename']}{', page: ' + str(doc.metadata['pagenumber']) if doc.metadata['pagenumber'] else ''}\n{str(doc.page_content)}\n"
+        doc_string = f"Document: {doc['readable_filename']}{', page: ' + str(doc['pagenumber']) if doc['pagenumber'] else ''}\n{str(doc['text'])}\n"
         num_tokens, prompt_cost = count_tokens_and_cost(doc_string)  # type: ignore
 
         print(
-            f"tokens used/limit: {token_counter}/{token_limit}, tokens in chunk: {num_tokens}, total prompt cost (of these contexts): {prompt_cost}. 📄 File: {doc.metadata['readable_filename']}"
+            f"tokens used/limit: {token_counter}/{token_limit}, tokens in chunk: {num_tokens}, total prompt cost (of these contexts): {prompt_cost}. 📄 File: {doc['readable_filename']}"
         )
         if token_counter + num_tokens <= token_limit:
           token_counter += num_tokens
@@ -114,7 +115,7 @@ def getTopContexts(self, search_query: str, course_name: str, token_limit: int =
           },
       )
 
-      return self.format_for_json(valid_docs)
+      return self.format_for_json_mqr(valid_docs)
     except Exception as e:
       # return full traceback to front end
       # err: str = f"ERROR: In /getTopContexts. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:\n{e}"  # type: ignore
diff --git a/ai_ta_backend/utils/context_parent_doc_padding.py b/ai_ta_backend/utils/context_parent_doc_padding.py
index 1ce559d5..892cddc9 100644
--- a/ai_ta_backend/utils/context_parent_doc_padding.py
+++ b/ai_ta_backend/utils/context_parent_doc_padding.py
@@ -10,7 +10,7 @@
 # SUPABASE_CLIENT = supabase.create_client(supabase_url=os.environ['SUPABASE_URL'],
 #                                          supabase_key=os.environ['SUPABASE_API_KEY'])  # type: ignore
-SQL_DB = SQLDatabase
+SQL_DB = SQLDatabase()
 
 
 def context_parent_doc_padding(found_docs, search_query, course_name):
   """
@@ -18,7 +18,7 @@ def context_parent_doc_padding(found_docs, search_query, course_name):
   """
   print("inside main context padding")
   start_time = time.monotonic()
-
+
   with Manager() as manager:
     qdrant_contexts = manager.list()
     supabase_contexts = manager.list()
@@ -44,50 +44,54 @@ def context_parent_doc_padding(found_docs, search_query, course_name):
 def qdrant_context_processing(doc, course_name, result_contexts):
   """
   Re-factor QDRANT objects into Supabase objects and append to result_docs
-  """
+  """
   context_dict = {
       'text': doc.page_content,
       'embedding': '',
       'pagenumber': doc.metadata['pagenumber'],
       'readable_filename': doc.metadata['readable_filename'],
       'course_name': course_name,
-      's3_path': doc.metadata['s3_path'],
-      'base_url': doc.metadata['base_url']
+      's3_path': doc.metadata['s3_path']
   }
   if 'url' in doc.metadata.keys():
     context_dict['url'] = doc.metadata['url']
   else:
     context_dict['url'] = ''
 
+  if 'base_url' in doc.metadata.keys():
+    context_dict['base_url'] = doc.metadata['url']
+  else:
+    context_dict['base_url'] = ''
+
   result_contexts.append(context_dict)
 
-  return result_contexts
+  #return result_contexts
 
 
 def supabase_context_padding(doc, course_name, result_docs):
   """
   Does context padding for given doc.
   """
-
   # query by url or s3_path
   if 'url' in doc.metadata.keys() and doc.metadata['url']:
     parent_doc_id = doc.metadata['url']
-    response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name',
-                                                                     course_name).eq('url', parent_doc_id).execute()
-
+    # response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name',
+    #                                                                  course_name).eq('url', parent_doc_id).execute()
+    response = SQL_DB.getMaterialsForCourseAndKeyAndValue(course_name=course_name, key='url', value=parent_doc_id)
 
   else:
     parent_doc_id = doc.metadata['s3_path']
-    response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name',
-                                                                     course_name).eq('s3_path',
-                                                                                     parent_doc_id).execute()
-
+    # response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name',
+    #                                                                  course_name).eq('s3_path',
+    #                                                                                  parent_doc_id).execute()
+    response = SQL_DB.getMaterialsForCourseAndS3Path(course_name=course_name, s3_path=parent_doc_id)
 
   data = response.data
-
+
   if len(data) > 0:
     # do the padding
     filename = data[0]['readable_filename']
     contexts = data[0]['contexts']
     #print("no of contexts within the og doc: ", len(contexts))
-
+
     if 'chunk_index' in doc.metadata and 'chunk_index' in contexts[0].keys():
       #print("inside chunk index")
       # pad contexts by chunk index + 3 and - 3
@@ -135,3 +139,5 @@ def supabase_context_padding(doc, course_name, result_docs):
        context_dict['url'] = ''
 
     result_docs.append(context_dict)
+
+