diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/sql.py
index caf0ac51..f5739fef 100644
--- a/ai_ta_backend/database/sql.py
+++ b/ai_ta_backend/database/sql.py
@@ -18,11 +18,11 @@ def getAllMaterialsForCourse(self, course_name: str):
         'course_name', course_name).execute()
 
   def getMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str):
-    return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq(
+    return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, readable_filename, base_url, url, contexts").eq(
         's3_path', s3_path).eq('course_name', course_name).execute()
 
   def getMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str):
-    return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq(
+    return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, readable_filename, base_url, url, contexts").eq(
         key, value).eq('course_name', course_name).execute()
 
   def deleteMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str):
@@ -110,3 +110,8 @@ def updateProjects(self, course_name: str, data: dict):
 
   def getConversation(self, course_name: str, key: str, value: str):
     return self.supabase_client.table("llm-convo-monitor").select("*").eq(key, value).eq("course_name", course_name).execute()
+  def getDocsByURLs(self, course_name: str, urls: list):
+    return self.supabase_client.table("documents").select("*").eq("course_name", course_name).in_("url", urls).execute()
+
+  def getDocsByS3Paths(self, course_name: str, s3_paths: list):
+    return self.supabase_client.table("documents").select("*").eq("course_name", course_name).in_("s3_path", s3_paths).execute()
\ No newline at end of file
diff --git a/ai_ta_backend/executors/process_pool_executor.py b/ai_ta_backend/executors/process_pool_executor.py
index 81b4860c..b981d613 100644
--- a/ai_ta_backend/executors/process_pool_executor.py
+++ b/ai_ta_backend/executors/process_pool_executor.py
@@ -29,3 +29,10 @@ def submit(self, fn, *args, **kwargs):
 
   def map(self, fn, *iterables, timeout=None, chunksize=1):
     return self.executor.map(fn, *iterables, timeout=timeout, chunksize=chunksize)
+
+  def __enter__(self):
+    return self
+
+  def __exit__(self, exc_type, exc_value, traceback):
+    self.executor.shutdown(wait=True)
+
diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py
index 5b52e7d3..612d4d85 100644
--- a/ai_ta_backend/main.py
+++ b/ai_ta_backend/main.py
@@ -125,6 +125,64 @@ def getTopContexts(service: RetrievalService) -> Response:
   return response
 
 
+@app.route('/getTopContextsv2', methods=['POST'])
+def getTopContextsv2(service: RetrievalService) -> Response:
+  """Get most relevant contexts for a given search query.
+
+  Return value
+
+  ## POST body
+  course name (optional) str
+      A json response with TBD fields.
+  search_query
+  token_limit
+  doc_groups
+
+  Returns
+  -------
+  JSON
+      A json response with TBD fields.
+      Metadata fields
+      * pagenumber_or_timestamp
+      * readable_filename
+      * s3_pdf_path
+
+      Example:
+      [
+        {
+          'readable_filename': 'Lumetta_notes',
+          'pagenumber_or_timestamp': 'pg. 19',
+          's3_pdf_path': '/courses//Lumetta_notes.pdf',
+          'text': 'In FSM, we do this...'
+        },
+      ]
+
+  Raises
+  ------
+  Exception
+      Testing how exceptions are handled.
+  """
+  data = request.get_json()
+  search_query: str = data.get('search_query', '')
+  course_name: str = data.get('course_name', '')
+  token_limit: int = data.get('token_limit', 3000)
+  doc_groups: List[str] = data.get('doc_groups', [])
+
+  if search_query == '' or course_name == '':
+    # proper web error "400 Bad request"
+    abort(
+        400,
+        description=
+        f"Missing one or more required parameters: 'search_query' and 'course_name' must be provided. Search query: `{search_query}`, Course name: `{course_name}`"
+    )
+
+  found_documents = service.getTopContextsv2(search_query, course_name, token_limit, doc_groups)
+
+  response = jsonify(found_documents)
+  response.headers.add('Access-Control-Allow-Origin', '*')
+  return response
+
+
 @app.route('/getAll', methods=['GET'])
 def getAll(service: RetrievalService) -> Response:
   """Get all course materials based on the course_name
diff --git a/ai_ta_backend/service/export_service.py b/ai_ta_backend/service/export_service.py
index 85f01118..51b65953 100644
--- a/ai_ta_backend/service/export_service.py
+++ b/ai_ta_backend/service/export_service.py
@@ -34,6 +34,7 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''):
     """
 
     response = self.sql.getDocumentsBetweenDates(course_name, from_date, to_date, 'documents')
+    print("response count: ", response.count)
     # add a condition to route to direct download or s3 download
     if response.count > 500:
       # call background task to upload to s3
diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py
index c53bcefb..4c3d1549 100644
--- a/ai_ta_backend/service/retrieval_service.py
+++ b/ai_ta_backend/service/retrieval_service.py
@@ -3,6 +3,10 @@
 import time
 import traceback
 from typing import Dict, List, Union
+from functools import partial
+from multiprocessing import Manager
+from multiprocessing import Lock
+
 
 import openai
 from injector import inject
@@ -17,7 +21,9 @@
 from ai_ta_backend.service.posthog_service import PosthogService
 from ai_ta_backend.service.sentry_service import SentryService
 from ai_ta_backend.utils.utils_tokenization import count_tokens_and_cost
-
+from ai_ta_backend.executors.process_pool_executor import ProcessPoolExecutorAdapter
+from functools import partial
+from multiprocessing import Manager
 
 class RetrievalService:
   """
@@ -26,13 +32,14 @@ class RetrievalService:
 
   @inject
   def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, posthog: PosthogService,
-               sentry: SentryService, nomicService: NomicService):
+               sentry: SentryService, nomicService: NomicService, executor: ProcessPoolExecutorAdapter):
     self.vdb = vdb
     self.sqlDb = sqlDb
     self.aws = aws
     self.sentry = sentry
     self.posthog = posthog
     self.nomicService = nomicService
+    self.executor = executor
 
     openai.api_key = os.environ["OPENAI_API_KEY"]
 
@@ -51,8 +58,8 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, pos
         openai_api_key=os.environ["AZURE_OPENAI_KEY"],
         openai_api_version=os.environ["OPENAI_API_VERSION"],
         openai_api_type=os.environ['OPENAI_API_TYPE'],
-    )
-
+    )  # type: ignore
+
   def getTopContexts(self,
                      search_query: str,
                      course_name: str,
@@ -127,6 +134,86 @@ def getTopContexts(self,
       self.sentry.capture_exception(e)
       return err
 
+
+  def getTopContextsv2(self,
+                       search_query: str,
+                       course_name: str,
+                       token_limit: int = 4_000,
+                       doc_groups: List[str] | None = None) -> Union[List[Dict], str]:
+    """Here's a summary of the work.
+
+    /GET arguments
+      course name (optional) str: A json response with TBD fields.
+
+    Returns
+      JSON: A json response with TBD fields. See main.py:getTopContexts docs.
+      or
+      String: An error message with traceback.
+    """
+    if doc_groups is None:
+      doc_groups = []
+    try:
+      start_time_overall = time.monotonic()
+
+      found_docs: list[Document] = self.vector_search(search_query=search_query,
+                                                      course_name=course_name,
+                                                      doc_groups=doc_groups)
+
+      # add parent doc retrieval here
+      print(f"Number of docs retrieved: {len(found_docs)}")
+      parent_docs = self.context_parent_doc_padding(found_docs, course_name)
+      print(f"Number of final docs after context padding: {len(parent_docs)}")
+
+      pre_prompt = "Please answer the following question. Use the context below, called your documents, only if it's helpful and don't use parts that are very irrelevant. It's good to quote from your documents directly, when you do always use Markdown footnotes for citations. Use react-markdown superscript to number the sources at the end of sentences (1, 2, 3...) and use react-markdown Footnotes to list the full document names for each number. Use ReactMarkdown aka 'react-markdown' formatting for super script citations, use semi-formal style. Feel free to say you don't know. \nHere's a few passages of the high quality documents:\n"
+      # count tokens at start and end, then also count each context.
+      token_counter, _ = count_tokens_and_cost(pre_prompt + "\n\nNow please respond to my query: " +  # type: ignore
+                                               search_query)
+
+      valid_docs = []
+      num_tokens = 0
+      for doc in parent_docs:
+        doc_string = f"Document: {doc['readable_filename']}{', page: ' + str(doc['pagenumber']) if doc['pagenumber'] else ''}\n{str(doc['text'])}\n"
+        num_tokens, prompt_cost = count_tokens_and_cost(doc_string)  # type: ignore
+
+        print(
+            f"tokens used/limit: {token_counter}/{token_limit}, tokens in chunk: {num_tokens}, total prompt cost (of these contexts): {prompt_cost}. 📄 File: {doc['readable_filename']}"
+        )
+        if token_counter + num_tokens <= token_limit:
+          token_counter += num_tokens
+          valid_docs.append(doc)
+        else:
+          # filled our token size, time to return
+          break
+
+      print(f"Total tokens used: {token_counter}. Docs used: {len(valid_docs)} of {len(found_docs)} docs retrieved")
+      print(f"Course: {course_name} ||| search_query: {search_query}")
+      print(f"⏰ ^^ Runtime of getTopContexts: {(time.monotonic() - start_time_overall):.2f} seconds")
+      if len(valid_docs) == 0:
+        return []
+
+      self.posthog.capture(
+          event_name="getTopContexts_success_DI",
+          properties={
+              "user_query": search_query,
+              "course_name": course_name,
+              "token_limit": token_limit,
+              "total_tokens_used": token_counter,
+              "total_contexts_used": len(valid_docs),
+              "total_unique_docs_retrieved": len(found_docs),
+              "getTopContext_total_latency_sec": time.monotonic() - start_time_overall,
+          },
+      )
+
+      return valid_docs
+    except Exception as e:
+      # return full traceback to front end
+      # err: str = f"ERROR: In /getTopContexts. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:\n{e}"  # type: ignore
+      err: str = f"ERROR: In /getTopContexts. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.format_exc()}\n{e}"  # type: ignore
+      traceback.print_exc()
+      print(err)
+      self.sentry.capture_exception(e)
+      return err
+
   def getAll(
       self,
       course_name: str,
@@ -298,28 +385,28 @@ def getTopContextsWithMQR(self,
     #   sentry_sdk.capture_exception(e)
     #   return err
 
-  def format_for_json_mqr(self, found_docs) -> List[Dict]:
-    """
-    Same as format_for_json, but for the new MQR pipeline.
-    """
-    for found_doc in found_docs:
-      if "pagenumber" not in found_doc.keys():
-        print("found no pagenumber")
-        found_doc['pagenumber'] = found_doc['pagenumber_or_timestamp']
-
-    contexts = [
-        {
-            'text': doc['text'],
-            'readable_filename': doc['readable_filename'],
-            'course_name ': doc['course_name'],
-            's3_path': doc['s3_path'],
-            'pagenumber': doc['pagenumber'],
-            'url': doc['url'],  # wouldn't this error out?
-            'base_url': doc['base_url'],
-        } for doc in found_docs
-    ]
-
-    return contexts
+  # def format_for_json_mqr(self, found_docs) -> List[Dict]:
+  #   """
+  #   Same as format_for_json, but for the new MQR pipeline.
+  #   """
+  #   for found_doc in found_docs:
+  #     if "pagenumber" not in found_doc.keys():
+  #       print("found no pagenumber")
+  #       found_doc['pagenumber'] = found_doc['pagenumber_or_timestamp']
+
+  #   contexts = [
+  #       {
+  #           'text': doc['text'],
+  #           'readable_filename': doc['readable_filename'],
+  #           'course_name ': doc['course_name'],
+  #           's3_path': doc['s3_path'],
+  #           'pagenumber': doc['pagenumber'],
+  #           'url': doc['url'],  # wouldn't this error out?
+  #           'base_url': doc['base_url'],
+  #       } for doc in found_docs
+  #   ]
+
+  #   return contexts
 
   def delete_from_nomic_and_supabase(self, course_name: str, identifier_key: str, identifier_value: str):
     try:
@@ -466,3 +553,113 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]:
     ]
 
     return contexts
+
+  def context_parent_doc_padding(self, found_docs, course_name):
+    """
+    Takes top N contexts acquired from QRANT similarity search and pads them
+    """
+    print("inside context_parent_doc_padding()")
+    start_time = time.monotonic()
+
+    # form a list of urls and s3_paths
+    urls = []
+    s3_paths = []
+    for doc in found_docs[:5]:
+      if 'url' in doc.metadata.keys():
+        urls.append(doc.metadata['url'])
+      else:
+        s3_paths.append(doc.metadata['s3_path'])
+
+    # query Supabase
+    supabase_url_content = self.sqlDb.getDocsByURLs(course_name, urls).data if urls else []
+    supabase_s3_content = self.sqlDb.getDocsByS3Paths(course_name, s3_paths).data if s3_paths else []
+    supabase_data = supabase_url_content + supabase_s3_content
+
+    with Manager() as manager:
+      qdrant_contexts = manager.list()
+      supabase_contexts = manager.list()
+      partial_func1 = partial(qdrant_context_processing, course_name=course_name, result_contexts=qdrant_contexts)
+      partial_func2 = partial(supabase_context_padding, course_name=course_name, sql_data=supabase_data, result_docs=supabase_contexts)
+
+      with self.executor as executor:
+        executor.map(partial_func1, found_docs[5:])
+        executor.map(partial_func2, found_docs[:5])
+
+      unique_contexts = list(set(tuple(item.items()) for item in list(supabase_contexts) + list(qdrant_contexts)))
+      unique_contexts = [dict(item) for item in unique_contexts]
+
+    print(f"⏰ Context padding runtime: {(time.monotonic() - start_time):.2f} seconds")
+    return unique_contexts
+
+
+def qdrant_context_processing(doc, course_name, result_contexts):
+  """
+  Re-factor QDRANT objects into Supabase objects and append to result_docs
+  """
+  #print("inside qdrant context processing")
+  context_dict = {
+      'text': doc.page_content,
+      'pagenumber': doc.metadata.get('pagenumber', ''),
+      'readable_filename': doc.metadata.get('readable_filename', ''),
+      'course_name': course_name,
+      's3_path': doc.metadata.get('s3_path', ''),
+      'url': doc.metadata.get('url', ''),
+      'base_url': doc.metadata.get('base_url', '')
+  }
+  result_contexts.append(context_dict)
+
+
+def supabase_context_padding(doc, course_name, sql_data, result_docs):
+  """
+  Does context padding for given doc.
+  """
+  # search the document in sql_data
+  url_match = next((item for item in sql_data if item.get('url') == doc.metadata.get('url')), None)
+  supabase_doc = url_match or next((item for item in sql_data if item.get('s3_path') == doc.metadata.get('s3_path')), None)
+
+  # create a dictionary
+  if supabase_doc:
+    contexts = supabase_doc['contexts']
+    filename = supabase_doc['readable_filename']
+
+    if 'chunk_index' in doc.metadata and 'chunk_index' in contexts[0].keys():
+      # pad contexts by chunk index + 3 and - 3
+      target_chunk_index = doc.metadata['chunk_index']
+      for context in contexts:
+        curr_chunk_index = context['chunk_index']
+        if (target_chunk_index - 3 <= curr_chunk_index <= target_chunk_index + 3):
+          context['readable_filename'] = filename
+          context['course_name'] = course_name
+          context['s3_path'] = supabase_doc['s3_path']
+          context['url'] = supabase_doc['url']
+          context['base_url'] = supabase_doc['base_url']
+          context.pop('embedding', None)
+
+          result_docs.append(context)
+
+    elif doc.metadata['pagenumber'] != '':
+      # pad contexts belonging to same page number
+      pagenumber = doc.metadata['pagenumber']
+
+      for context in contexts:
+        # pad contexts belonging to same page number
+        if int(context['pagenumber']) == pagenumber:
+          context['readable_filename'] = filename
+          context['course_name'] = course_name
+          context['s3_path'] = supabase_doc['s3_path']
+          context['url'] = supabase_doc['url']
+          context['base_url'] = supabase_doc['base_url']
+          context.pop('embedding', None)
+          result_docs.append(context)
+
+    else:
+      # refactor as a Supabase object and append
+      context_dict = {
+          'text': doc.page_content,
+          'pagenumber': doc.metadata.get('pagenumber', ''),
+          'readable_filename': doc.metadata.get('readable_filename', ''),
+          'course_name': course_name,
+          's3_path': doc.metadata.get('s3_path', ''),
+          'base_url': doc.metadata.get('base_url', ''),
+          'url': doc.metadata.get('url', '')
+      }
+      result_docs.append(context_dict)
+
diff --git a/ai_ta_backend/service/sentry_service.py b/ai_ta_backend/service/sentry_service.py
index 53b780b0..03c25a4c 100644
--- a/ai_ta_backend/service/sentry_service.py
+++ b/ai_ta_backend/service/sentry_service.py
@@ -16,7 +16,9 @@ def __init__(self, dsn: str):
         # Set profiles_sample_rate to 1.0 to profile 100% of sampled transactions.
         # We recommend adjusting this value in production.
         profiles_sample_rate=1.0,
+        environment="development",  # 'production', 'staging', 'development', 'testing'
         enable_tracing=True)
 
   def capture_exception(self, exception: Exception):
     sentry_sdk.capture_exception(exception)
+
diff --git a/ai_ta_backend/utils/context_parent_doc_padding.py b/ai_ta_backend/utils/context_parent_doc_padding.py
index fc0ba19c..e8b45183 100644
--- a/ai_ta_backend/utils/context_parent_doc_padding.py
+++ b/ai_ta_backend/utils/context_parent_doc_padding.py
@@ -1,28 +1,36 @@
-import os
+#import os
 import time
-from concurrent.futures import ProcessPoolExecutor
+#from concurrent.futures import ProcessPoolExecutor
 from functools import partial
 from multiprocessing import Manager
+from ai_ta_backend.database.sql import SQLDatabase
+from ai_ta_backend.executors.process_pool_executor import ProcessPoolExecutorAdapter
 
-DOCUMENTS_TABLE = os.environ['SUPABASE_DOCUMENTS_TABLE']
+
+# DOCUMENTS_TABLE = os.environ['SUPABASE_DOCUMENTS_TABLE']
 # SUPABASE_CLIENT = supabase.create_client(supabase_url=os.environ['SUPABASE_URL'],
 #                                          supabase_key=os.environ['SUPABASE_API_KEY'])  # type: ignore
 
 
-def context_parent_doc_padding(found_docs, search_query, course_name):
+def context_parent_doc_padding(found_docs, course_name):
   """
   Takes top N contexts acquired from QRANT similarity search and pads them
   """
   print("inside main context padding")
   start_time = time.monotonic()
-
+  #executor = ProcessPoolExecutorAdapter(max_workers=10)
+
   with Manager() as manager:
     qdrant_contexts = manager.list()
     supabase_contexts = manager.list()
     partial_func1 = partial(qdrant_context_processing, course_name=course_name, result_contexts=qdrant_contexts)
     partial_func2 = partial(supabase_context_padding, course_name=course_name, result_docs=supabase_contexts)
 
-    with ProcessPoolExecutor() as executor:
+    # with ProcessPoolExecutor() as executor:
+    #   executor.map(partial_func1, found_docs[5:])
+    #   executor.map(partial_func2, found_docs[:5])
+
+    with ProcessPoolExecutorAdapter() as executor:
       executor.map(partial_func1, found_docs[5:])
       executor.map(partial_func2, found_docs[:5])
 
@@ -41,50 +49,54 @@ def context_parent_doc_padding(found_docs, search_query, course_name):
 def qdrant_context_processing(doc, course_name, result_contexts):
   """
   Re-factor QDRANT objects into Supabase objects and append to result_docs
-  """
+  """
   context_dict = {
       'text': doc.page_content,
       'embedding': '',
       'pagenumber': doc.metadata['pagenumber'],
       'readable_filename': doc.metadata['readable_filename'],
       'course_name': course_name,
-      's3_path': doc.metadata['s3_path'],
-      'base_url': doc.metadata['base_url']
+      's3_path': doc.metadata['s3_path']
   }
   if 'url' in doc.metadata.keys():
     context_dict['url'] = doc.metadata['url']
  else:
     context_dict['url'] = ''
 
+  if 'base_url' in doc.metadata.keys():
+    context_dict['base_url'] = doc.metadata['base_url']
+  else:
+    context_dict['base_url'] = ''
+
+
   result_contexts.append(context_dict)
-  return result_contexts
 
 
 def supabase_context_padding(doc, course_name, result_docs):
   """
   Does context padding for given doc.
-  """
+  """
+  SQL_DB = SQLDatabase()
 
   # query by url or s3_path
   if 'url' in doc.metadata.keys() and doc.metadata['url']:
     parent_doc_id = doc.metadata['url']
-    response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name',
-                                                                     course_name).eq('url', parent_doc_id).execute()
-
+    # response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name',
+    #                                                                  course_name).eq('url', parent_doc_id).execute()
+    response = SQL_DB.getMaterialsForCourseAndKeyAndValue(course_name=course_name, key='url', value=parent_doc_id)
   else:
     parent_doc_id = doc.metadata['s3_path']
-    response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name',
-                                                                     course_name).eq('s3_path',
-                                                                                     parent_doc_id).execute()
-
+    # response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name',
+    #                                                                  course_name).eq('s3_path',
+    #                                                                                  parent_doc_id).execute()
+    response = SQL_DB.getMaterialsForCourseAndS3Path(course_name=course_name, s3_path=parent_doc_id)
   data = response.data
-
+
   if len(data) > 0:
     # do the padding
     filename = data[0]['readable_filename']
     contexts = data[0]['contexts']
     #print("no of contexts within the og doc: ", len(contexts))
-
+
     if 'chunk_index' in doc.metadata and 'chunk_index' in contexts[0].keys():
       #print("inside chunk index")
       # pad contexts by chunk index + 3 and - 3
@@ -132,3 +144,5 @@ def supabase_context_padding(doc, course_name, result_docs):
       context_dict['url'] = ''
 
     result_docs.append(context_dict)
+
+
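
A quick way to exercise the new /getTopContextsv2 route added in main.py above. This is only a sketch against a local dev server: the host/port, course name, and query are placeholder assumptions, while the field names (search_query, course_name, token_limit, doc_groups) come straight from the handler.

```python
import requests

# Assumed local Flask dev server; point this at your own deployment.
BASE_URL = "http://localhost:5000"

payload = {
    "search_query": "How does a finite state machine work?",  # example query
    "course_name": "ECE120",                                  # example course, not from this diff
    "token_limit": 3000,                                      # optional, handler defaults to 3000
    "doc_groups": [],                                         # optional, handler defaults to []
}

resp = requests.post(f"{BASE_URL}/getTopContextsv2", json=payload)
resp.raise_for_status()

contexts = resp.json()
# On success this is a list of padded context dicts; on failure the service
# returns an error string instead, so check the type before iterating.
if isinstance(contexts, list):
  for ctx in contexts:
    print(ctx["readable_filename"], ctx.get("pagenumber", ""))
else:
  print("Error from service:", contexts)
```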
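The __enter__/__exit__ methods added to ProcessPoolExecutorAdapter are what make the `with self.executor as executor:` and `with ProcessPoolExecutorAdapter() as executor:` blocks above work. A minimal sketch of that usage pattern, assuming the worker function lives at module level so it can be pickled for the child processes; the function and data here are illustrative only, not part of the diff.

```python
from functools import partial
from multiprocessing import Manager

from ai_ta_backend.executors.process_pool_executor import ProcessPoolExecutorAdapter


def collect_square(n, result_list):
  # Must be a module-level function: multiprocessing pickles it for the workers.
  result_list.append(n * n)


if __name__ == "__main__":
  with Manager() as manager:
    results = manager.list()  # proxy list that child processes can append to
    work = partial(collect_square, result_list=results)

    # __exit__ calls executor.shutdown(wait=True), so all tasks finish before we read results.
    with ProcessPoolExecutorAdapter() as executor:
      list(executor.map(work, range(10)))  # consume the iterator to surface any worker errors

    print(sorted(results))  # [0, 1, 4, ..., 81]
```

context_parent_doc_padding() uses the same shape with two partials: one for the Qdrant-only docs (found_docs[5:]) and one for the top five docs that get Supabase padding (found_docs[:5]).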
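The chunk_index branch of supabase_context_padding() keeps every stored context within three chunks of the matched one. The same rule, pulled out as a standalone sketch with made-up data so the windowing is easy to see:

```python
def pad_by_chunk_index(contexts, target_chunk_index, window=3):
  """Return the contexts whose 'chunk_index' lies in [target - window, target + window]."""
  return [
      c for c in contexts
      if target_chunk_index - window <= c['chunk_index'] <= target_chunk_index + window
  ]


chunks = [{'chunk_index': i, 'text': f'chunk {i}'} for i in range(20)]
padded = pad_by_chunk_index(chunks, target_chunk_index=10)
print([c['chunk_index'] for c in padded])  # [7, 8, 9, 10, 11, 12, 13]
```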
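The de-duplication at the end of context_parent_doc_padding() turns each context dict into a tuple of its items so it can sit in a set. A small sketch of that trick with invented data; note it only works while every value is hashable and the dicts share the same key order.

```python
contexts = [
    {'text': 'FSM intro', 's3_path': 'notes.pdf', 'pagenumber': 1},
    {'text': 'FSM intro', 's3_path': 'notes.pdf', 'pagenumber': 1},  # exact duplicate
    {'text': 'FSM details', 's3_path': 'notes.pdf', 'pagenumber': 2},
]

# Dicts are unhashable, so flatten each one to a tuple of (key, value) pairs first.
unique_contexts = [dict(t) for t in set(tuple(c.items()) for c in contexts)]
print(len(unique_contexts))  # 2
```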