Skip to content

Commit

Permalink
fix killer bug that prevented mimetypes from working
Browse files Browse the repository at this point in the history
  • Loading branch information
KastanDay committed Sep 28, 2023
1 parent 5792647 commit c879b71
Showing 1 changed file with 4 additions and 15 deletions.
19 changes: 4 additions & 15 deletions ai_ta_backend/vector_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from importlib import metadata
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import boto3
import fitz
Expand Down Expand Up @@ -161,10 +161,9 @@ def get_context_stuffed_prompt(self, user_question: str, course_name: str, top_n


def bulk_ingest(self, s3_paths: Union[List[str], str], course_name: str, **kwargs) -> Dict[str, List[str]]:
def _ingest_single(file_ingest_methods, s3_path, *args, **kwargs):
def _ingest_single(ingest_methods: Dict[str, Callable], s3_path, *args, **kwargs):
"""Handle running an arbitrary ingest function for an individual file."""
handler = file_ingest_methods.get(Path(s3_path).suffix)
print(f"Using ingest method: {handler} ||| for file: {s3_path}")
handler = ingest_methods.get(Path(s3_path).suffix)
if handler:
# RUN INGEST METHOD
ret = handler(s3_path, *args, **kwargs)
Expand Down Expand Up @@ -194,33 +193,23 @@ def _ingest_single(file_ingest_methods, s3_path, *args, **kwargs):
}
# 👆👆👆👆 ADD NEW INGEST METHODS ERE 👆👆👇�DS 👇�🎉

print(f"Top of bulk_ingest. S3 paths {s3_paths}")
print(f"Top of bulk_ingest. Course_name {course_name}")
print(f"Top of bulk_ingest. kwargs {kwargs}")

print(f"Top of ingest, Course_name {course_name}. S3 paths {s3_paths}")
success_status = {"success_ingest": [], "failure_ingest": []}
try:
if isinstance(s3_paths, str):
s3_paths = [s3_paths]


for s3_path in s3_paths:
file_extension = Path(s3_path).suffix
with NamedTemporaryFile(suffix=file_extension) as tmpfile:
self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile)
mime_type = mimetypes.guess_type(tmpfile.name, strict=False)[0]
mime_category, extension = mime_type.split('/')
# file_ext = "." + extension
print(f"Mime mime_category: {mime_category}")
print(f"Mime type: {mime_type}")
print(f"file extension: {file_extension}")

if file_extension in file_ingest_methods:
# Use specialized functions when possible, fallback to mimetype. Else raise error.
print(f"Using SPECIFIC file ingest methods")
_ingest_single(file_ingest_methods, s3_path, course_name, kwargs=kwargs)
elif mime_category in mimetype_ingest_methods:
print(f"Using GENERAL Mimetype ingest methods")
# mime type
_ingest_single(mimetype_ingest_methods, s3_path, course_name, kwargs=kwargs)
else:
Expand Down

0 comments on commit c879b71

Please sign in to comment.