diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/sql.py index 89b58d64..bf750aa3 100644 --- a/ai_ta_backend/database/sql.py +++ b/ai_ta_backend/database/sql.py @@ -78,6 +78,9 @@ def getAllConversationsBetweenIds(self, course_name: str, first_id: int, last_id return self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).gte( 'id', first_id).lte('id', last_id).order('id', desc=False).limit(25).execute() - def getDocsForIdsGte(self, course_name: str, first_id: int): - return self.supabase_client.table("documents").select("*").eq("course_name", course_name).gte('id', first_id).order( - 'id', desc=False).limit(100).execute() + def getDocsForIdsGte(self, course_name: str, first_id: int, fields: str = "*", limit: int = 100): + return self.supabase_client.table("documents").select(fields).eq("course_name", course_name).gte( + 'id', first_id).order('id', desc=False).limit(limit).execute() + + def insertProjectInfo(self, project_info): + return self.supabase_client.table("projects").insert(project_info).execute() diff --git a/ai_ta_backend/service/export_service.py b/ai_ta_backend/service/export_service.py index 0c3718e0..79bf9d74 100644 --- a/ai_ta_backend/service/export_service.py +++ b/ai_ta_backend/service/export_service.py @@ -7,6 +7,7 @@ import pandas as pd import requests +from injector import inject from ai_ta_backend.database.aws import AWSStorage from ai_ta_backend.database.sql import SQLDatabase @@ -16,6 +17,7 @@ class ExportService: + @inject def __init__(self, sql: SQLDatabase, s3: AWSStorage, sentry=SentryService): self.sql = sql self.s3 = s3 @@ -257,7 +259,7 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): return {"response": (zip_file_path, zip_filename, os.getcwd())} except Exception as e: print(e) - sentry_sdk.capture_exception(e) + self.sentry.capture_exception(e) return {"response": "Error downloading file!"} else: return {"response": "No data found between the given dates."} diff --git a/ai_ta_backend/service/nomic_service.py b/ai_ta_backend/service/nomic_service.py index 973f17f5..5b6c4e38 100644 --- a/ai_ta_backend/service/nomic_service.py +++ b/ai_ta_backend/service/nomic_service.py @@ -10,6 +10,7 @@ from langchain.embeddings import OpenAIEmbeddings from nomic import AtlasProject, atlas +from ai_ta_backend.database.sql import SQLDatabase from ai_ta_backend.service.sentry_service import SentryService LOCK_EXCEPTIONS = [ @@ -59,9 +60,10 @@ def backoff_strategy(): class NomicService(): @inject - def __init__(self, sentry: SentryService): + def __init__(self, sentry: SentryService, sql: SQLDatabase): nomic.login(os.getenv('NOMIC_API_KEY')) self.sentry = sentry + self.sql = sql @backoff.on_exception(backoff_strategy, Exception, @@ -424,22 +426,16 @@ def create_document_map(self, course_name: str): # nomic.login(os.getenv('NOMIC_API_KEY')) NOMIC_MAP_NAME_PREFIX = 'Document Map for ' - # initialize supabase - supabase_client = supabase.create_client( # type: ignore - supabase_url=os.getenv('SUPABASE_URL'), # type: ignore - supabase_key=os.getenv('SUPABASE_API_KEY')) # type: ignore - try: # check if map exists - response = supabase_client.table("projects").select("doc_map_id").eq("course_name", course_name).execute() + + response = self.sql.getProjectsMapForCourse(course_name) if response.data: return "Map already exists for this course." # fetch relevant document data from Supabase - response = supabase_client.table("documents").select("id", - count="exact").eq("course_name", - course_name).order('id', - desc=False).execute() + response = self.sql.getDocumentsBetweenDates(course_name, '', '', "documents") + if not response.count: return "No documents found for this course." @@ -458,9 +454,9 @@ def create_document_map(self, course_name: str): # iteratively query in batches of 25 while curr_total_doc_count < total_doc_count: - response = supabase_client.table("documents").select( - "id, created_at, s3_path, url, readable_filename, contexts").eq("course_name", course_name).gte( - 'id', first_id).order('id', desc=False).limit(25).execute() + response = self.sql.getDocsForIdsGte(course_name, first_id, + "id, created_at, s3_path, url, readable_filename, contexts", 25) + df = pd.DataFrame(response.data) combined_dfs.append(df) # list of dfs @@ -519,7 +515,7 @@ def create_document_map(self, course_name: str): project_id = project.id project.rebuild_maps() project_info = {'course_name': course_name, 'doc_map_id': project_id} - response = supabase_client.table("projects").insert(project_info).execute() + response = self.sql.insertProjectInfo(project_info) print("Response from supabase: ", response) return "success" else: