Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding parent document retrieval in default RAG pipeline #233

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
9 changes: 7 additions & 2 deletions ai_ta_backend/database/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ def getAllMaterialsForCourse(self, course_name: str):
'course_name', course_name).execute()

def getMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str):
    """Fetch the document rows of a course that are stored under one S3 path.

    The table name is read from the ``SUPABASE_DOCUMENTS_TABLE`` environment
    variable. Only the columns ``id, s3_path, readable_filename, base_url,
    url, contexts`` are selected.

    Args:
        course_name: Value matched against the ``course_name`` column.
        s3_path: Value matched against the ``s3_path`` column.

    Returns:
        Whatever the Supabase query builder's ``execute()`` returns.

    Raises:
        KeyError: If ``SUPABASE_DOCUMENTS_TABLE`` is not set.
    """
    # Post-change column list from the diff; the older, narrower select
    # ("id, s3_path, contexts") is superseded by this one.
    columns = "id, s3_path, readable_filename, base_url, url, contexts"
    query = self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select(columns)
    return query.eq('s3_path', s3_path).eq('course_name', course_name).execute()

def getMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str):
    """Fetch the document rows of a course where column ``key`` equals ``value``.

    The table name is read from the ``SUPABASE_DOCUMENTS_TABLE`` environment
    variable. Only the columns ``id, s3_path, readable_filename, base_url,
    url, contexts`` are selected.

    Args:
        course_name: Value matched against the ``course_name`` column.
        key: Name of the column to filter on.
        value: Value the ``key`` column must equal.

    Returns:
        Whatever the Supabase query builder's ``execute()`` returns.

    Raises:
        KeyError: If ``SUPABASE_DOCUMENTS_TABLE`` is not set.
    """
    # Post-change column list from the diff; the older, narrower select
    # ("id, s3_path, contexts") is superseded by this one.
    columns = "id, s3_path, readable_filename, base_url, url, contexts"
    query = self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select(columns)
    return query.eq(key, value).eq('course_name', course_name).execute()

def deleteMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str):
Expand Down Expand Up @@ -110,3 +110,8 @@ def updateProjects(self, course_name: str, data: dict):
def getConversation(self, course_name: str, key: str, value: str):
    """Look up rows in the ``llm-convo-monitor`` table for one course.

    Selects every column of the rows where column ``key`` equals ``value``
    and ``course_name`` equals the given course, and returns the result of
    the query builder's ``execute()``.
    """
    conversations = self.supabase_client.table("llm-convo-monitor").select("*")
    filtered = conversations.eq(key, value).eq("course_name", course_name)
    return filtered.execute()

def getDocsByURLs(self, course_name: str, urls: list):
    """Fetch all document rows of a course whose ``url`` is in ``urls``.

    Args:
        course_name: Value matched against the ``course_name`` column.
        urls: Candidate values for the ``url`` column.

    Returns:
        Whatever the Supabase query builder's ``execute()`` returns.
    """
    # Use the configured documents table like the other document queries in
    # this class; fall back to the previously hard-coded name when the
    # variable is unset so existing deployments behave the same.
    table_name = os.environ.get('SUPABASE_DOCUMENTS_TABLE', 'documents')
    query = self.supabase_client.table(table_name).select("*")
    return query.eq("course_name", course_name).in_("url", urls).execute()

def getDocsByS3Paths(self, course_name: str, s3_paths: list):
    """Fetch all document rows of a course whose ``s3_path`` is in ``s3_paths``.

    Args:
        course_name: Value matched against the ``course_name`` column.
        s3_paths: Candidate values for the ``s3_path`` column.

    Returns:
        Whatever the Supabase query builder's ``execute()`` returns.
    """
    # Use the configured documents table like the other document queries in
    # this class; fall back to the previously hard-coded name when the
    # variable is unset so existing deployments behave the same.
    table_name = os.environ.get('SUPABASE_DOCUMENTS_TABLE', 'documents')
    query = self.supabase_client.table(table_name).select("*")
    return query.eq("course_name", course_name).in_("s3_path", s3_paths).execute()
7 changes: 7 additions & 0 deletions ai_ta_backend/executors/process_pool_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,10 @@ def submit(self, fn, *args, **kwargs):

def map(self, fn, *iterables, timeout=None, chunksize=1):
    """Forward a ``map`` call to the wrapped executor.

    ``fn`` and the positional ``iterables`` are passed through unchanged,
    together with ``timeout`` and ``chunksize`` as keyword arguments; the
    wrapped executor's return value is handed back as-is.
    """
    forwarded_options = {"timeout": timeout, "chunksize": chunksize}
    return self.executor.map(fn, *iterables, **forwarded_options)

def __enter__(self):
    """Enter a ``with`` block; the wrapper itself is the bound value."""
    return self

def __exit__(self, exc_type, exc_value, traceback):
    """Leave a ``with`` block by shutting down the wrapped executor.

    Calls ``shutdown(wait=True)`` on ``self.executor``. The exception
    arguments are not inspected, and the method returns ``None``, so an
    exception raised inside the ``with`` body is not suppressed.
    """
    wrapped_pool = self.executor
    wrapped_pool.shutdown(wait=True)
    return None

58 changes: 58 additions & 0 deletions ai_ta_backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,64 @@ def getTopContexts(service: RetrievalService) -> Response:
return response


@app.route('/getTopContextsv2', methods=['POST'])
def getTopContextsv2(service: RetrievalService) -> Response:
    """Get the most relevant contexts for a given search query (v2 endpoint).

    ## POST body (JSON object)
    search_query : str
        Required. The query to retrieve contexts for.
    course_name : str
        Required. The course whose documents are searched.
    token_limit : int
        Optional, defaults to 3000.
    doc_groups : List[str]
        Optional, defaults to an empty list.

    Returns
    -------
    JSON
      The value returned by ``service.getTopContextsv2`` serialised with
      ``jsonify``, sent with ``Access-Control-Allow-Origin: *``. The exact
      fields are still TBD; the earlier documentation lists these metadata
      fields (confirm against the retrieval service):
        * pagenumber_or_timestamp
        * readable_filename
        * s3_pdf_path

      Example:
        [
          {
            'readable_filename': 'Lumetta_notes',
            'pagenumber_or_timestamp': 'pg. 19',
            's3_pdf_path': '/courses/<course>/Lumetta_notes.pdf',
            'text': 'In FSM, we do this...'
          },
        ]

    Raises
    ------
    HTTP 400 (via ``abort``)
      If the body is not a JSON object, or if ``search_query`` or
      ``course_name`` is missing or empty.
    """
    data = request.get_json()
    if not isinstance(data, dict):
        # A missing body or a JSON array/scalar has no `.get`; answer with a
        # client error instead of crashing with AttributeError (HTTP 500).
        abort(400, description="Request body must be a JSON object.")

    search_query: str = data.get('search_query', '')
    course_name: str = data.get('course_name', '')
    token_limit: int = data.get('token_limit', 3000)
    doc_groups: List[str] = data.get('doc_groups', [])

    # Truthiness also rejects an explicit JSON null, which `== ''` let through.
    if not search_query or not course_name:
        # proper web error "400 Bad request"
        abort(
            400,
            description=
            f"Missing one or more required parameters: 'search_query' and 'course_name' must be provided. Search query: `{search_query}`, Course name: `{course_name}`"
        )

    found_documents = service.getTopContextsv2(search_query, course_name, token_limit, doc_groups)

    response = jsonify(found_documents)
    response.headers.add('Access-Control-Allow-Origin', '*')
    return response


@app.route('/getAll', methods=['GET'])
def getAll(service: RetrievalService) -> Response:
"""Get all course materials based on the course_name
Expand Down
1 change: 1 addition & 0 deletions ai_ta_backend/service/export_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''):
"""

response = self.sql.getDocumentsBetweenDates(course_name, from_date, to_date, 'documents')
print("response count: ", response.count)
# add a condition to route to direct download or s3 download
if response.count > 500:
# call background task to upload to s3
Expand Down
Loading