diff --git a/ai_ta_backend/vector_database.py b/ai_ta_backend/vector_database.py index 075b854a..08aa37ed 100644 --- a/ai_ta_backend/vector_database.py +++ b/ai_ta_backend/vector_database.py @@ -157,62 +157,22 @@ def get_context_stuffed_prompt(self, user_question: str, course_name: str, top_n # "Please answer the following question. It's good to quote 'your documents' directly, something like 'from ABS source it says XYZ' Feel free to say you don't know. \nHere's a few passages of the high quality 'your documents':\n" return stuffed_prompt - - # def ai_summary(self, text: List[str], metadata: List[Dict[str, Any]]) -> List[str]: - # """ - # Given a textual input, return a summary of the text. - # """ - # #print("in AI SUMMARY") - # requests = [] - # for i in range(len(text)): - # dictionary = { - # "model": "gpt-3.5-turbo", - # "messages": [{ - # "role": - # "system", - # "content": - # "You are a factual summarizer of partial documents. Stick to the facts (including partial info when necessary to avoid making up potentially incorrect details), and say I don't know when necessary." - # }, { - # "role": - # "user", - # "content": - # f"Provide a descriptive summary of the given text:\n{text[i]}\nThe summary should cover all the key points, while also condensing the information into a concise format. The length of the summary should not exceed 3 sentences.", - # }], - # "n": 1, - # "max_tokens": 600, - # "metadata": metadata[i] - # } - # requests.append(dictionary) - - # oai = OpenAIAPIProcessor(input_prompts_list=requests, - # request_url='https://api.openai.com/v1/chat/completions', - # api_key=os.getenv("OPENAI_API_KEY"), - # max_requests_per_minute=1500, - # max_tokens_per_minute=90000, - # token_encoding_name='cl100k_base', - # max_attempts=5, - # logging_level=20) - - # asyncio.run(oai.process_api_requests_from_file()) - # #results: list[str] = oai.results - # #print(f"Cleaned results: {oai.cleaned_results}") - # summary = oai.cleaned_results - # return summary def bulk_ingest(self, s3_paths: Union[List[str], str], course_name: str, **kwargs) -> Dict[str, List[str]]: - success_status = {"success_ingest": [], "failure_ingest": []} - - def ingest(file_ext_mapping, s3_path, *args, **kwargs): - handler = file_ext_mapping.get(Path(s3_path).suffix) - if handler: - ret = handler(s3_path, *args, **kwargs) - if ret != "Success": - success_status['failure_ingest'].append(s3_path) - else: - success_status['success_ingest'].append(s3_path) + def _ingest_single(file_ingest_methods, s3_path, *args, **kwargs): + """Handle running an arbitrary ingest function for an individual file.""" + handler = file_ingest_methods.get(Path(s3_path).suffix) + if handler: + # RUN INGEST METHOD + ret = handler(s3_path, *args, **kwargs) + if ret != "Success": + success_status['failure_ingest'].append(s3_path) + else: + success_status['success_ingest'].append(s3_path) - file_ext_mapping = { + # πŸ‘‡πŸ‘‡πŸ‘‡πŸ‘‡ ADD NEW INGEST METHODSE E HERπŸ‘‡πŸ‘‡πŸ‘‡πŸ‘‡πŸŽ‰ + file_ingest_methods = { '.html': self._ingest_html, '.py': self._ingest_single_py, '.vtt': self._ingest_single_vtt, @@ -225,35 +185,39 @@ def ingest(file_ext_mapping, s3_path, *args, **kwargs): '.pptx': self._ingest_single_ppt, } + # Ingest methods via MIME type (more general than filetype) + mimetype_ingest_methods = { + 'video': self._ingest_single_video, + 'audio': self._ingest_single_video, + 'text': self._ingest_single_txt, + } + # πŸ‘†πŸ‘†πŸ‘†πŸ‘† ADD NEW INGEST METHODS ERE πŸ‘†πŸ‘†πŸ‘‡οΏ½DS πŸ‘‡οΏ½πŸŽ‰ + + success_status = {"success_ingest": [], "failure_ingest": []} try: - if isinstance(s3_paths, str): - s3_paths = [s3_paths] - - for s3_path in s3_paths: - with NamedTemporaryFile(suffix=Path(s3_path).suffix) as tmpfile: - self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) - mime_type = mimetypes.guess_type(tmpfile.name)[0] - category, _ = mime_type.split('/') - match_file_ext = "." + _ - - if category in ['video', 'audio']: - ret = self._ingest_single_video(s3_path, course_name) - if ret != "Success": - success_status['failure_ingest'].append(s3_path) - else: - success_status['success_ingest'].append(s3_path) - elif category == 'text' and match_file_ext not in file_ext_mapping.keys(): - print(category, match_file_ext) - ret = self._ingest_single_txt(s3_path, course_name) - if ret != "Success": - success_status['failure_ingest'].append(s3_path) - else: - success_status['success_ingest'].append(s3_path) - else: - ingest(file_ext_mapping, s3_path, course_name, kwargs=kwargs) + if isinstance(s3_paths, str): + s3_paths = [s3_paths] - return success_status + for s3_path in s3_paths: + with NamedTemporaryFile(suffix=Path(s3_path).suffix) as tmpfile: + self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) + mime_type = mimetypes.guess_type(tmpfile.name, strict=False)[0] + mime_category, extension = mime_type.split('/') + file_ext = "." + extension + + if file_ext in file_ingest_methods: + # Use specialized functions when possible, fallback to mimetype. Else raise error. + _ingest_single(file_ingest_methods, s3_path, course_name, kwargs=kwargs) + elif mime_category in mimetype_ingest_methods: + # mime type + _ingest_single(mimetype_ingest_methods, s3_path, course_name, kwargs=kwargs) + else: + # failure + success_status['failure_ingest'].append(f"File ingest not supported for Mimetype: {mime_type}, with MimeCategory: {mime_category}, with file extension: {file_ext} for s3_path: {s3_path}") + continue + + return success_status except Exception as e: success_status['failure_ingest'].append(f"MAJOR ERROR IN /bulk_ingest: Error: {str(e)}") return success_status