Adding injection to ExportService __init__, and adding SQLDatabase injection to document map functions in NomicService
rohan-uiuc committed Mar 7, 2024
1 parent a14cc44 commit ad220a6
Showing 3 changed files with 20 additions and 19 deletions.
9 changes: 6 additions & 3 deletions ai_ta_backend/database/sql.py
@@ -78,6 +78,9 @@ def getAllConversationsBetweenIds(self, course_name: str, first_id: int, last_id
     return self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).gte(
         'id', first_id).lte('id', last_id).order('id', desc=False).limit(25).execute()
 
-  def getDocsForIdsGte(self, course_name: str, first_id: int):
-    return self.supabase_client.table("documents").select("*").eq("course_name", course_name).gte('id', first_id).order(
-        'id', desc=False).limit(100).execute()
+  def getDocsForIdsGte(self, course_name: str, first_id: int, fields: str = "*", limit: int = 100):
+    return self.supabase_client.table("documents").select(fields).eq("course_name", course_name).gte(
+        'id', first_id).order('id', desc=False).limit(limit).execute()
+
+  def insertProjectInfo(self, project_info):
+    return self.supabase_client.table("projects").insert(project_info).execute()
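For illustration, a hypothetical call to the widened helper. The course name and field list below are example values, and `sql` stands for an injected SQLDatabase instance; none of these appear in this commit:

    # Hypothetical usage of the generalized signature:
    response = sql.getDocsForIdsGte("example_course", 0,
                                    fields="id, s3_path, readable_filename",
                                    limit=25)
    for doc in response.data:  # supabase-py responses expose rows via .data
      print(doc["id"], doc["readable_filename"])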
4 changes: 3 additions & 1 deletion ai_ta_backend/service/export_service.py
@@ -7,6 +7,7 @@
 
 import pandas as pd
 import requests
+from injector import inject
 
 from ai_ta_backend.database.aws import AWSStorage
 from ai_ta_backend.database.sql import SQLDatabase
@@ -16,6 +17,7 @@
 
 class ExportService:
 
+  @inject
   def __init__(self, sql: SQLDatabase, s3: AWSStorage, sentry=SentryService):
     self.sql = sql
     self.s3 = s3
@@ -257,7 +259,7 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''):
         return {"response": (zip_file_path, zip_filename, os.getcwd())}
       except Exception as e:
         print(e)
-        sentry_sdk.capture_exception(e)
+        self.sentry.capture_exception(e)
         return {"response": "Error downloading file!"}
     else:
       return {"response": "No data found between the given dates."}
26 changes: 11 additions & 15 deletions ai_ta_backend/service/nomic_service.py
@@ -10,6 +10,7 @@
 from langchain.embeddings import OpenAIEmbeddings
 from nomic import AtlasProject, atlas
 
+from ai_ta_backend.database.sql import SQLDatabase
 from ai_ta_backend.service.sentry_service import SentryService
 
 LOCK_EXCEPTIONS = [
@@ -59,9 +60,10 @@ def backoff_strategy():
 class NomicService():
 
   @inject
-  def __init__(self, sentry: SentryService):
+  def __init__(self, sentry: SentryService, sql: SQLDatabase):
     nomic.login(os.getenv('NOMIC_API_KEY'))
     self.sentry = sentry
+    self.sql = sql
 
   @backoff.on_exception(backoff_strategy,
                         Exception,
@@ -424,22 +426,16 @@ def create_document_map(self, course_name: str):
     # nomic.login(os.getenv('NOMIC_API_KEY'))
     NOMIC_MAP_NAME_PREFIX = 'Document Map for '
 
-    # initialize supabase
-    supabase_client = supabase.create_client(  # type: ignore
-        supabase_url=os.getenv('SUPABASE_URL'),  # type: ignore
-        supabase_key=os.getenv('SUPABASE_API_KEY'))  # type: ignore
-
     try:
       # check if map exists
-      response = supabase_client.table("projects").select("doc_map_id").eq("course_name", course_name).execute()
+      response = self.sql.getProjectsMapForCourse(course_name)
       if response.data:
         return "Map already exists for this course."
 
       # fetch relevant document data from Supabase
-      response = supabase_client.table("documents").select("id",
-                                                           count="exact").eq("course_name",
-                                                                             course_name).order('id',
-                                                                                                desc=False).execute()
+      response = self.sql.getDocumentsBetweenDates(course_name, '', '', "documents")
 
       if not response.count:
         return "No documents found for this course."
@@ -458,9 +454,9 @@ def create_document_map(self, course_name: str):
       # iteratively query in batches of 25
       while curr_total_doc_count < total_doc_count:
 
-        response = supabase_client.table("documents").select(
-            "id, created_at, s3_path, url, readable_filename, contexts").eq("course_name", course_name).gte(
-                'id', first_id).order('id', desc=False).limit(25).execute()
+        response = self.sql.getDocsForIdsGte(course_name, first_id,
+                                             "id, created_at, s3_path, url, readable_filename, contexts", 25)
 
         df = pd.DataFrame(response.data)
         combined_dfs.append(df)  # list of dfs
@@ -519,7 +515,7 @@ def create_document_map(self, course_name: str):
         project_id = project.id
         project.rebuild_maps()
         project_info = {'course_name': course_name, 'doc_map_id': project_id}
-        response = supabase_client.table("projects").insert(project_info).execute()
+        response = self.sql.insertProjectInfo(project_info)
         print("Response from supabase: ", response)
         return "success"
       else:
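A condensed sketch of the batched fetch loop after this change. The id-cursor advance and the empty-response guard are assumptions, since that part of create_document_map is collapsed in this view; `svc` stands for a NomicService instance:

    import pandas as pd

    def fetch_all_docs(svc, course_name: str, total_doc_count: int):
      combined_dfs, first_id, fetched = [], 0, 0
      while fetched < total_doc_count:
        response = svc.sql.getDocsForIdsGte(
            course_name, first_id,
            "id, created_at, s3_path, url, readable_filename, contexts", 25)
        if not response.data:
          break  # no more rows; avoid looping forever on a short count
        combined_dfs.append(pd.DataFrame(response.data))
        fetched += len(response.data)
        first_id = response.data[-1]['id'] + 1  # assumed: advance past the last id seen
      return combined_dfs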
