From 144a37baea6f74eb53bb44235ffb600be774c19a Mon Sep 17 00:00:00 2001 From: rohanmarwaha Date: Fri, 17 May 2024 11:39:41 -0500 Subject: [PATCH 01/11] WIP first draft for revamp, need to introduce models/types --- ai_ta_backend/database/base_sql.py | 69 +++++++++++ ai_ta_backend/database/base_storage.py | 15 +++ ai_ta_backend/database/base_vector.py | 12 ++ .../database/database_impl/__init__.py | 0 .../{sql.py => database_impl/sql/supabase.py} | 4 +- .../{ => database_impl/storage}/aws.py | 4 +- .../vector/qdrant.py} | 4 +- ai_ta_backend/main.py | 54 ++++++-- ai_ta_backend/service/export_service.py | 9 +- ai_ta_backend/service/nomic_service.py | 2 +- ai_ta_backend/service/retrieval_service.py | 116 ++++++++++-------- ai_ta_backend/service/workflow_service.py | 2 +- 12 files changed, 217 insertions(+), 74 deletions(-) create mode 100644 ai_ta_backend/database/base_sql.py create mode 100644 ai_ta_backend/database/base_storage.py create mode 100644 ai_ta_backend/database/base_vector.py create mode 100644 ai_ta_backend/database/database_impl/__init__.py rename ai_ta_backend/database/{sql.py => database_impl/sql/supabase.py} (98%) rename ai_ta_backend/database/{ => database_impl/storage}/aws.py (91%) rename ai_ta_backend/database/{vector.py => database_impl/vector/qdrant.py} (96%) diff --git a/ai_ta_backend/database/base_sql.py b/ai_ta_backend/database/base_sql.py new file mode 100644 index 00000000..ce808e1a --- /dev/null +++ b/ai_ta_backend/database/base_sql.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Tuple +from postgrest.base_request_builder import APIResponse + +class BaseSQLDatabase(ABC): + + @abstractmethod + def getAllMaterialsForCourse(self, course_name: str) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def getMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def getMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def deleteMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): + pass + @abstractmethod + def deleteMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): + pass + @abstractmethod + def getProjectsMapForCourse(self, course_name: str) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def getDocumentsBetweenDates(self, course_name: str, from_date: str, to_date: str, table_name: str) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def getAllFromTableForDownloadType(self, course_name: str, download_type: str, first_id: int) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def getAllConversationsBetweenIds(self, course_name: str, first_id: int, last_id: int, limit: int = 50) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def getDocsForIdsGte(self, course_name: str, first_id: int, fields: str = "*", limit: int = 100) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def insertProjectInfo(self, project_info) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def getAllFromLLMConvoMonitor(self, course_name: str) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def getCountFromLLMConvoMonitor(self, course_name: str, last_id: int) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def getDocMapFromProjects(self, 
course_name: str) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def getConvoMapFromProjects(self, course_name: str) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def updateProjects(self, course_name: str, data: dict) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def getLatestWorkflowId(self) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def lockWorkflow(self, id: str) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def deleteLatestWorkflowId(self, id: str) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def unlockWorkflow(self, id: str) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass + @abstractmethod + def getConversation(self, course_name: str, key: str, value: str) -> APIResponse[Tuple[Dict[str, Any], int]]: + pass \ No newline at end of file diff --git a/ai_ta_backend/database/base_storage.py b/ai_ta_backend/database/base_storage.py new file mode 100644 index 00000000..733b387d --- /dev/null +++ b/ai_ta_backend/database/base_storage.py @@ -0,0 +1,15 @@ +from abc import ABC, abstractmethod + +class BaseStorageDatabase(ABC): + @abstractmethod + def upload_file(self, file_path: str, bucket_name: str, object_name: str): + pass + @abstractmethod + def download_file(self, object_name: str, bucket_name: str, file_path: str): + pass + @abstractmethod + def delete_file(self, bucket_name: str, s3_path: str): + pass + @abstractmethod + def generatePresignedUrl(self, object: str, bucket_name: str, s3_path: str, expiration: int = 3600): + pass \ No newline at end of file diff --git a/ai_ta_backend/database/base_vector.py b/ai_ta_backend/database/base_vector.py new file mode 100644 index 00000000..551fd328 --- /dev/null +++ b/ai_ta_backend/database/base_vector.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod +from typing import List + +class BaseVectorDatabase(ABC): + + @abstractmethod + def vector_search(self, search_query, course_name, doc_groups: List[str], user_query_embedding, top_n): + pass + + @abstractmethod + def delete_data(self, collection_name: str, key: str, value: str): + pass diff --git a/ai_ta_backend/database/database_impl/__init__.py b/ai_ta_backend/database/database_impl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/database_impl/sql/supabase.py similarity index 98% rename from ai_ta_backend/database/sql.py rename to ai_ta_backend/database/database_impl/sql/supabase.py index 6f7ae01d..a20ff756 100644 --- a/ai_ta_backend/database/sql.py +++ b/ai_ta_backend/database/database_impl/sql/supabase.py @@ -3,8 +3,10 @@ import supabase from injector import inject +from ai_ta_backend.database.base_sql import BaseSQLDatabase -class SQLDatabase: + +class SQLDatabase(BaseSQLDatabase): @inject def __init__(self): diff --git a/ai_ta_backend/database/aws.py b/ai_ta_backend/database/database_impl/storage/aws.py similarity index 91% rename from ai_ta_backend/database/aws.py rename to ai_ta_backend/database/database_impl/storage/aws.py index 68e61b68..a7042116 100644 --- a/ai_ta_backend/database/aws.py +++ b/ai_ta_backend/database/database_impl/storage/aws.py @@ -3,8 +3,10 @@ import boto3 from injector import inject +from ai_ta_backend.database.base_storage import BaseStorageDatabase -class AWSStorage: + +class AWSStorage(BaseStorageDatabase): @inject def __init__(self): diff --git a/ai_ta_backend/database/vector.py 
b/ai_ta_backend/database/database_impl/vector/qdrant.py similarity index 96% rename from ai_ta_backend/database/vector.py rename to ai_ta_backend/database/database_impl/vector/qdrant.py index f9d002ec..b6cc6bd5 100644 --- a/ai_ta_backend/database/vector.py +++ b/ai_ta_backend/database/database_impl/vector/qdrant.py @@ -6,10 +6,12 @@ from langchain.vectorstores import Qdrant from qdrant_client import QdrantClient, models +from ai_ta_backend.database.base_vector import BaseVectorDatabase + OPENAI_API_TYPE = "azure" # "openai" or "azure" -class VectorDatabase(): +class VectorDatabase(BaseVectorDatabase): """ Contains all methods for building and using vector databases. """ diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 0085abd5..82901304 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -17,10 +17,10 @@ from flask_executor import Executor from flask_injector import FlaskInjector, RequestScope from injector import Binder, SingletonScope +from ai_ta_backend.database.base_sql import BaseSQLDatabase +from ai_ta_backend.database.base_storage import BaseStorageDatabase +from ai_ta_backend.database.base_vector import BaseVectorDatabase -from ai_ta_backend.database.aws import AWSStorage -from ai_ta_backend.database.sql import SQLDatabase -from ai_ta_backend.database.vector import VectorDatabase from ai_ta_backend.executors.flask_executor import ( ExecutorInterface, FlaskExecutorAdapter, @@ -473,20 +473,48 @@ def run_flow(service: WorkflowService) -> Response: def configure(binder: Binder) -> None: - binder.bind(RetrievalService, to=RetrievalService, scope=RequestScope) - binder.bind(PosthogService, to=PosthogService, scope=SingletonScope) - binder.bind(SentryService, to=SentryService, scope=SingletonScope) - binder.bind(NomicService, to=NomicService, scope=SingletonScope) - binder.bind(ExportService, to=ExportService, scope=SingletonScope) - binder.bind(WorkflowService, to=WorkflowService, scope=SingletonScope) - binder.bind(VectorDatabase, to=VectorDatabase, scope=SingletonScope) - binder.bind(SQLDatabase, to=SQLDatabase, scope=SingletonScope) - binder.bind(AWSStorage, to=AWSStorage, scope=SingletonScope) + vector_bound = False + sql_bound = False + storage_bound = False + + # Conditionally bind databases based on the availability of their respective secrets + if any(os.getenv(key) for key in ["QDRANT_URL", "QDRANT_API_KEY", "QDRANT_COLLECTION_NAME"]) or any(os.getenv(key) for key in ["PINECONE_API_KEY", "PINECONE_PROJECT_NAME"]): + binder.bind(BaseVectorDatabase, to=BaseVectorDatabase, scope=SingletonScope) + vector_bound = True + + if any(os.getenv(key) for key in ["SUPABASE_URL", "SUPABASE_API_KEY", "SUPABASE_DOCUMENTS_TABLE"]) or any(os.getenv(key) for key in ["SQLITE_DB_PATH", "SQLITE_DB_NAME", "SQLITE_DOCUMENTS_TABLE"]): + binder.bind(BaseSQLDatabase, to=BaseSQLDatabase, scope=SingletonScope) + sql_bound = True + + if any(os.getenv(key) for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]) or any(os.getenv(key) for key in ["MINIO_ACCESS_KEY", "MINIO_SECRET_KEY", "MINIO_URL"]): + binder.bind(BaseStorageDatabase, to=BaseStorageDatabase, scope=SingletonScope) + storage_bound = True + + + # Conditionally bind services based on the availability of their respective secrets + if os.getenv("NOMIC_API_KEY"): + binder.bind(NomicService, to=NomicService, scope=SingletonScope) + + if os.getenv("POSTHOG_API_KEY"): + binder.bind(PosthogService, to=PosthogService, scope=SingletonScope) + + if os.getenv("SENTRY_DSN"): + binder.bind(SentryService, to=SentryService, scope=SingletonScope) + + if 
os.getenv("EMAIL_SENDER"): + binder.bind(ExportService, to=ExportService, scope=SingletonScope) + + if os.getenv("N8N_URL"): + binder.bind(WorkflowService, to=WorkflowService, scope=SingletonScope) + + if vector_bound and sql_bound and storage_bound: + binder.bind(RetrievalService, to=RetrievalService, scope=RequestScope) + + # Always bind the executor and its adapters binder.bind(ExecutorInterface, to=FlaskExecutorAdapter(executor), scope=SingletonScope) binder.bind(ThreadPoolExecutorInterface, to=ThreadPoolExecutorAdapter, scope=SingletonScope) binder.bind(ProcessPoolExecutorInterface, to=ProcessPoolExecutorAdapter, scope=SingletonScope) - FlaskInjector(app=app, modules=[configure]) if __name__ == '__main__': diff --git a/ai_ta_backend/service/export_service.py b/ai_ta_backend/service/export_service.py index 85f01118..15c33ccd 100644 --- a/ai_ta_backend/service/export_service.py +++ b/ai_ta_backend/service/export_service.py @@ -8,8 +8,8 @@ import requests from injector import inject -from ai_ta_backend.database.aws import AWSStorage -from ai_ta_backend.database.sql import SQLDatabase +from ai_ta_backend.database.database_impl.storage.aws import AWSStorage +from ai_ta_backend.database.database_impl.sql.supabase import SQLDatabase from ai_ta_backend.service.sentry_service import SentryService from ai_ta_backend.utils.emails import send_email @@ -34,8 +34,9 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''): """ response = self.sql.getDocumentsBetweenDates(course_name, from_date, to_date, 'documents') + # add a condition to route to direct download or s3 download - if response.count > 500: + if response.count and response.count > 500: # call background task to upload to s3 filename = course_name + '_' + str(uuid.uuid4()) + '_documents.zip' @@ -47,7 +48,7 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''): else: # Fetch data - if response.count > 0: + if response.count and response.count > 0: # batch download total_doc_count = response.count first_id = response.data[0]['id'] diff --git a/ai_ta_backend/service/nomic_service.py b/ai_ta_backend/service/nomic_service.py index 80ca86ca..a0dedd45 100644 --- a/ai_ta_backend/service/nomic_service.py +++ b/ai_ta_backend/service/nomic_service.py @@ -11,7 +11,7 @@ from langchain.embeddings.openai import OpenAIEmbeddings from nomic import AtlasProject, atlas -from ai_ta_backend.database.sql import SQLDatabase +from ai_ta_backend.database.database_impl.sql.supabase import SQLDatabase from ai_ta_backend.service.sentry_service import SentryService LOCK_EXCEPTIONS = [ diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index c53bcefb..5ab9a0ea 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -2,17 +2,16 @@ import os import time import traceback -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union import openai from injector import inject from langchain.chat_models import AzureChatOpenAI from langchain.embeddings.openai import OpenAIEmbeddings from langchain.schema import Document - -from ai_ta_backend.database.aws import AWSStorage -from ai_ta_backend.database.sql import SQLDatabase -from ai_ta_backend.database.vector import VectorDatabase +from ai_ta_backend.database.base_sql import BaseSQLDatabase +from ai_ta_backend.database.base_storage import BaseStorageDatabase +from ai_ta_backend.database.base_vector import BaseVectorDatabase from 
ai_ta_backend.service.nomic_service import NomicService from ai_ta_backend.service.posthog_service import PosthogService from ai_ta_backend.service.sentry_service import SentryService @@ -25,8 +24,8 @@ class RetrievalService: """ @inject - def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, posthog: PosthogService, - sentry: SentryService, nomicService: NomicService): + def __init__(self, vdb: BaseVectorDatabase, sqlDb: BaseSQLDatabase, aws: BaseStorageDatabase, posthog: Optional[PosthogService], + sentry: Optional[SentryService], nomicService: Optional[NomicService]): self.vdb = vdb self.sqlDb = sqlDb self.aws = aws @@ -104,18 +103,19 @@ def getTopContexts(self, if len(valid_docs) == 0: return [] - self.posthog.capture( - event_name="getTopContexts_success_DI", - properties={ - "user_query": search_query, - "course_name": course_name, - "token_limit": token_limit, - "total_tokens_used": token_counter, - "total_contexts_used": len(valid_docs), - "total_unique_docs_retrieved": len(found_docs), - "getTopContext_total_latency_sec": time.monotonic() - start_time_overall, - }, - ) + if self.posthog is not None: + self.posthog.capture( + event_name="getTopContexts_success_DI", + properties={ + "user_query": search_query, + "course_name": course_name, + "token_limit": token_limit, + "total_tokens_used": token_counter, + "total_contexts_used": len(valid_docs), + "total_unique_docs_retrieved": len(found_docs), + "getTopContext_total_latency_sec": time.monotonic() - start_time_overall, + }, + ) return self.format_for_json(valid_docs) except Exception as e: @@ -124,7 +124,8 @@ def getTopContexts(self, err: str = f"ERROR: In /getTopContexts. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.print_exc} \n{e}" # type: ignore traceback.print_exc() print(err) - self.sentry.capture_exception(e) + if self.sentry is not None: + self.sentry.capture_exception(e) return err def getAll( @@ -179,7 +180,8 @@ def delete_data(self, course_name: str, s3_path: str, source_url: str): except Exception as e: err: str = f"ERROR IN delete_data: Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore print(err) - self.sentry.capture_exception(e) + if self.sentry is not None: + self.sentry.capture_exception(e) return err def delete_from_s3(self, bucket_name: str, s3_path: str): @@ -189,7 +191,8 @@ def delete_from_s3(self, bucket_name: str, s3_path: str): print(f"AWS response: {response}") except Exception as e: print("Error in deleting file from s3:", e) - self.sentry.capture_exception(e) + if self.sentry is not None: + self.sentry.capture_exception(e) def delete_from_qdrant(self, identifier_key: str, identifier_value: str): try: @@ -202,7 +205,8 @@ def delete_from_qdrant(self, identifier_key: str, identifier_value: str): pass else: print("Error in deleting file from Qdrant:", e) - self.sentry.capture_exception(e) + if self.sentry is not None: + self.sentry.capture_exception(e) def getTopContextsWithMQR(self, search_query: str, @@ -325,27 +329,32 @@ def delete_from_nomic_and_supabase(self, course_name: str, identifier_key: str, try: print(f"Nomic delete. 
Course: {course_name} using {identifier_key}: {identifier_value}") response = self.sqlDb.getMaterialsForCourseAndKeyAndValue(course_name, identifier_key, identifier_value) - if not response.data: + data = response.data + if not data: raise Exception(f"No materials found for {course_name} using {identifier_key}: {identifier_value}") - data = response.data[0] # single record fetched + data = data[0] # single record fetched nomic_ids_to_delete = [str(data['id']) + "_" + str(i) for i in range(1, len(data['contexts']) + 1)] # delete from Nomic response = self.sqlDb.getProjectsMapForCourse(course_name) - if not response.data: + data = response.data + if not data: raise Exception(f"No document map found for this course: {course_name}") - project_id = response.data[0]['doc_map_id'] - self.nomicService.delete_from_document_map(project_id, nomic_ids_to_delete) + project_id = data[0]['doc_map_id'] + if self.nomicService is not None: + self.nomicService.delete_from_document_map(project_id, nomic_ids_to_delete) except Exception as e: print(f"Nomic Error in deleting. {identifier_key}: {identifier_value}", e) - self.sentry.capture_exception(e) + if self.sentry is not None: + self.sentry.capture_exception(e) try: print(f"Supabase Delete. course: {course_name} using {identifier_key}: {identifier_value}") response = self.sqlDb.deleteMaterialsForCourseAndKeyAndValue(course_name, identifier_key, identifier_value) except Exception as e: print(f"Supabase Error in delete. {identifier_key}: {identifier_value}", e) - self.sentry.capture_exception(e) + if self.sentry is not None: + self.sentry.capture_exception(e) def vector_search(self, search_query, course_name, doc_groups: List[str] | None = None): """ @@ -375,14 +384,15 @@ def _embed_query_and_measure_latency(self, search_query): return user_query_embedding def _capture_search_invoked_event(self, search_query, course_name, doc_groups): - self.posthog.capture( - event_name="vector_search_invoked", - properties={ - "user_query": search_query, - "course_name": course_name, - "doc_groups": doc_groups, - }, - ) + if self.posthog is not None: + self.posthog.capture( + event_name="vector_search_invoked", + properties={ + "user_query": search_query, + "course_name": course_name, + "doc_groups": doc_groups, + }, + ) def _perform_vector_search(self, search_query, course_name, doc_groups, user_query_embedding, top_n): qdrant_start_time = time.monotonic() @@ -403,25 +413,27 @@ def _process_search_results(self, search_results, course_name): found_docs.append(Document(page_content=page_content, metadata=metadata)) except Exception as e: print(f"Error in vector_search(), for course: `{course_name}`. 
Error: {e}") - self.sentry.capture_exception(e) + if self.sentry is not None: + self.sentry.capture_exception(e) return found_docs def _capture_search_succeeded_event(self, search_query, course_name, search_results): vector_score_calc_latency_sec = time.monotonic() max_vector_score, min_vector_score, avg_vector_score = self._calculate_vector_scores(search_results) - self.posthog.capture( - event_name="vector_search_succeeded", - properties={ - "user_query": search_query, - "course_name": course_name, - "qdrant_latency_sec": self.qdrant_latency_sec, - "openai_embedding_latency_sec": self.openai_embedding_latency, - "max_vector_score": max_vector_score, - "min_vector_score": min_vector_score, - "avg_vector_score": avg_vector_score, - "vector_score_calculation_latency_sec": time.monotonic() - vector_score_calc_latency_sec, - }, - ) + if self.posthog is not None: + self.posthog.capture( + event_name="vector_search_succeeded", + properties={ + "user_query": search_query, + "course_name": course_name, + "qdrant_latency_sec": self.qdrant_latency_sec, + "openai_embedding_latency_sec": self.openai_embedding_latency, + "max_vector_score": max_vector_score, + "min_vector_score": min_vector_score, + "avg_vector_score": avg_vector_score, + "vector_score_calculation_latency_sec": time.monotonic() - vector_score_calc_latency_sec, + }, + ) def _calculate_vector_scores(self, search_results): max_vector_score = 0 diff --git a/ai_ta_backend/service/workflow_service.py b/ai_ta_backend/service/workflow_service.py index 1afaeda7..badc38ad 100644 --- a/ai_ta_backend/service/workflow_service.py +++ b/ai_ta_backend/service/workflow_service.py @@ -5,7 +5,7 @@ from urllib.parse import quote import json from injector import inject -from ai_ta_backend.database.sql import SQLDatabase +from ai_ta_backend.database.database_impl.sql.supabase import SQLDatabase class WorkflowService: From 1093fb63e6c326f2a380339d3ea1d180911ac480 Mon Sep 17 00:00:00 2001 From: rohanmarwaha Date: Fri, 24 May 2024 19:14:01 -0500 Subject: [PATCH 02/11] Major revamp adding Flask-SqlAlchemy support --- ai_ta_backend/beam/OpenaiEmbeddings.py | 1100 +++---- ai_ta_backend/beam/ingest.py | 2740 ++++++++--------- ai_ta_backend/beam/nomic_logging.py | 826 ++--- .../{database_impl/storage => }/aws.py | 4 +- ai_ta_backend/database/base_sql.py | 69 - ai_ta_backend/database/base_storage.py | 15 - ai_ta_backend/database/base_vector.py | 12 - .../database/database_impl/__init__.py | 0 .../{database_impl/vector => }/qdrant.py | 3 +- ai_ta_backend/database/sql.py | 230 ++ .../{database_impl/sql => }/supabase.py | 5 +- ai_ta_backend/extensions.py | 2 + ai_ta_backend/main.py | 71 +- ai_ta_backend/model/models.py | 89 + ai_ta_backend/model/response.py | 9 + ai_ta_backend/service/export_service.py | 49 +- ai_ta_backend/service/nomic_service.py | 24 +- ai_ta_backend/service/retrieval_service.py | 17 +- ai_ta_backend/service/workflow_service.py | 2 +- requirements.txt | 2 +- 20 files changed, 2766 insertions(+), 2503 deletions(-) rename ai_ta_backend/database/{database_impl/storage => }/aws.py (91%) delete mode 100644 ai_ta_backend/database/base_sql.py delete mode 100644 ai_ta_backend/database/base_storage.py delete mode 100644 ai_ta_backend/database/base_vector.py delete mode 100644 ai_ta_backend/database/database_impl/__init__.py rename ai_ta_backend/database/{database_impl/vector => }/qdrant.py (96%) create mode 100644 ai_ta_backend/database/sql.py rename ai_ta_backend/database/{database_impl/sql => }/supabase.py (98%) create mode 100644 
ai_ta_backend/extensions.py create mode 100644 ai_ta_backend/model/models.py create mode 100644 ai_ta_backend/model/response.py diff --git a/ai_ta_backend/beam/OpenaiEmbeddings.py b/ai_ta_backend/beam/OpenaiEmbeddings.py index 2f0f64f7..6c8239f2 100644 --- a/ai_ta_backend/beam/OpenaiEmbeddings.py +++ b/ai_ta_backend/beam/OpenaiEmbeddings.py @@ -1,550 +1,550 @@ -""" -API REQUEST PARALLEL PROCESSOR - -Using the OpenAI API to process lots of text quickly takes some care. -If you trickle in a million API requests one by one, they'll take days to complete. -If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors. -To maximize throughput, parallel requests need to be throttled to stay under rate limits. - -This script parallelizes requests to the OpenAI API while throttling to stay under rate limits. - -Features: -- Streams requests from file, to avoid running out of memory for giant jobs -- Makes requests concurrently, to maximize throughput -- Throttles request and token usage, to stay under rate limits -- Retries failed requests up to {max_attempts} times, to avoid missing data -- Logs errors, to diagnose problems with requests - -Example command to call script: -``` -python examples/api_request_parallel_processor.py \ - --requests_filepath examples/data/example_requests_to_parallel_process.jsonl \ - --save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \ - --request_url https://api.openai.com/v1/embeddings \ - --max_requests_per_minute 1500 \ - --max_tokens_per_minute 6250000 \ - --token_encoding_name cl100k_base \ - --max_attempts 5 \ - --logging_level 20 -``` - -Inputs: -- requests_filepath : str - - path to the file containing the requests to be processed - - file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field - - e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}} - - as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically) - - an example file is provided at examples/data/example_requests_to_parallel_process.jsonl - - the code to generate the example file is appended to the bottom of this script -- save_filepath : str, optional - - path to the file where the results will be saved - - file will be a jsonl file, where each line is an array with the original request plus the API response - - e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}] - - if omitted, results will be saved to {requests_filename}_results.jsonl -- request_url : str, optional - - URL of the API endpoint to call - - if omitted, will default to "https://api.openai.com/v1/embeddings" -- api_key : str, optional - - API key to use - - if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")} -- max_requests_per_minute : float, optional - - target number of requests to make per minute (will make less if limited by tokens) - - leave headroom by setting this to 50% or 75% of your limit - - if requests are limiting you, try batching multiple embeddings or completions into one request - - if omitted, will default to 1,500 -- max_tokens_per_minute : float, optional - - target number of tokens to use per minute (will use less if limited by requests) - - leave headroom by setting this to 50% or 75% of your limit - - if omitted, will default to 125,000 -- token_encoding_name : str, optional - - name of the token encoding used, as 
defined in the `tiktoken` package - - if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`) -- max_attempts : int, optional - - number of times to retry a failed request before giving up - - if omitted, will default to 5 -- logging_level : int, optional - - level of logging to use; higher numbers will log fewer messages - - 40 = ERROR; will log only when requests fail after all retries - - 30 = WARNING; will log when requests his rate limits or other errors - - 20 = INFO; will log when requests start and the status at finish - - 10 = DEBUG; will log various things as the loop runs to see when they occur - - if omitted, will default to 20 (INFO). - -The script is structured as follows: - - Imports - - Define main() - - Initialize things - - In main loop: - - Get next request if one is not already waiting for capacity - - Update available token & request capacity - - If enough capacity available, call API - - The loop pauses if a rate limit error is hit - - The loop breaks when no tasks remain - - Define dataclasses - - StatusTracker (stores script metadata counters; only one instance is created) - - APIRequest (stores API inputs, outputs, metadata; one method to call API) - - Define functions - - api_endpoint_from_url (extracts API endpoint from request URL) - - append_to_jsonl (writes to results file) - - num_tokens_consumed_from_request (bigger function to infer token usage from request) - - task_id_generator_function (yields 1, 2, 3, ...) - - Run main() -""" - -# import argparse -# import subprocess -# import tempfile -# from langchain.llms import OpenAI -import asyncio -import json -import logging - -# import os -import re -import time - -# for storing API inputs, outputs, and metadata -from dataclasses import dataclass, field -from typing import Any, List - -import aiohttp # for making API calls concurrently -import tiktoken # for counting tokens - -# from langchain.embeddings.openai import OpenAIEmbeddings -# from langchain.vectorstores import Qdrant -# from qdrant_client import QdrantClient, models - - -class OpenAIAPIProcessor: - - def __init__(self, input_prompts_list, request_url, api_key, max_requests_per_minute, max_tokens_per_minute, - token_encoding_name, max_attempts, logging_level): - self.request_url = request_url - self.api_key = api_key - self.max_requests_per_minute = max_requests_per_minute - self.max_tokens_per_minute = max_tokens_per_minute - self.token_encoding_name = token_encoding_name - self.max_attempts = max_attempts - self.logging_level = logging_level - self.input_prompts_list: List[dict] = input_prompts_list - self.results = [] - self.cleaned_results: List[str] = [] - - async def process_api_requests_from_file(self): - """Processes API requests in parallel, throttling to stay under rate limits.""" - # constants - seconds_to_pause_after_rate_limit_error = 15 - seconds_to_sleep_each_loop = 0.001 # 1 ms limits max throughput to 1,000 requests per second - - # initialize logging - logging.basicConfig(level=self.logging_level) - logging.debug(f"Logging initialized at level {self.logging_level}") - - # infer API endpoint and construct request header - api_endpoint = api_endpoint_from_url(self.request_url) - request_header = {"Authorization": f"Bearer {self.api_key}"} - - # initialize trackers - queue_of_requests_to_retry = asyncio.Queue() - task_id_generator = task_id_generator_function() # generates integer IDs of 1, 2, 3, ... 
- status_tracker = StatusTracker() # single instance to track a collection of variables - next_request = None # variable to hold the next request to call - - # initialize available capacity counts - available_request_capacity = self.max_requests_per_minute - available_token_capacity = self.max_tokens_per_minute - last_update_time = time.time() - - # initialize flags - file_not_finished = True # after file is empty, we'll skip reading it - logging.debug("Initialization complete.") - - requests = self.input_prompts_list.__iter__() - - logging.debug("File opened. Entering main loop") - - task_list = [] - - while True: - # get next request (if one is not already waiting for capacity) - if next_request is None: - if not queue_of_requests_to_retry.empty(): - next_request = queue_of_requests_to_retry.get_nowait() - logging.debug(f"Retrying request {next_request.task_id}: {next_request}") - elif file_not_finished: - try: - # get new request - # request_json = json.loads(next(requests)) - request_json = next(requests) - - next_request = APIRequest(task_id=next(task_id_generator), - request_json=request_json, - token_consumption=num_tokens_consumed_from_request( - request_json, api_endpoint, self.token_encoding_name), - attempts_left=self.max_attempts, - metadata=request_json.pop("metadata", None)) - status_tracker.num_tasks_started += 1 - status_tracker.num_tasks_in_progress += 1 - logging.debug(f"Reading request {next_request.task_id}: {next_request}") - except StopIteration: - # if file runs out, set flag to stop reading it - logging.debug("Read file exhausted") - file_not_finished = False - - # update available capacity - current_time = time.time() - seconds_since_update = current_time - last_update_time - available_request_capacity = min( - available_request_capacity + self.max_requests_per_minute * seconds_since_update / 60.0, - self.max_requests_per_minute, - ) - available_token_capacity = min( - available_token_capacity + self.max_tokens_per_minute * seconds_since_update / 60.0, - self.max_tokens_per_minute, - ) - last_update_time = current_time - - # if enough capacity available, call API - if next_request: - next_request_tokens = next_request.token_consumption - if (available_request_capacity >= 1 and available_token_capacity >= next_request_tokens): - # update counters - available_request_capacity -= 1 - available_token_capacity -= next_request_tokens - next_request.attempts_left -= 1 - - # call API - # TODO: NOT SURE RESPONSE WILL WORK HERE - task = asyncio.create_task( - next_request.call_api( - request_url=self.request_url, - request_header=request_header, - retry_queue=queue_of_requests_to_retry, - status_tracker=status_tracker, - )) - task_list.append(task) - next_request = None # reset next_request to empty - - # print("status_tracker.num_tasks_in_progress", status_tracker.num_tasks_in_progress) - # one_task_result = task.result() - # print("one_task_result", one_task_result) - - # if all tasks are finished, break - if status_tracker.num_tasks_in_progress == 0: - break - - # main loop sleeps briefly so concurrent tasks can run - await asyncio.sleep(seconds_to_sleep_each_loop) - - # if a rate limit error was hit recently, pause to cool down - seconds_since_rate_limit_error = (time.time() - status_tracker.time_of_last_rate_limit_error) - if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error: - remaining_seconds_to_pause = (seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error) - await asyncio.sleep(remaining_seconds_to_pause) - # ^e.g., if 
pause is 15 seconds and final limit was hit 5 seconds ago - logging.warn( - f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}" - ) - - # after finishing, log final status - logging.info("""Parallel processing complete. About to return.""") - if status_tracker.num_tasks_failed > 0: - logging.warning(f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed.") - if status_tracker.num_rate_limit_errors > 0: - logging.warning( - f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate.") - - # asyncio wait for task_list - await asyncio.wait(task_list) - - for task in task_list: - openai_completion = task.result() - self.results.append(openai_completion) - - self.cleaned_results: List[str] = extract_context_from_results(self.results) - - -def extract_context_from_results(results: List[Any]) -> List[str]: - assistant_contents = [] - total_prompt_tokens = 0 - total_completion_tokens = 0 - - for element in results: - if element is not None: - for item in element: - if 'choices' in item: - for choice in item['choices']: - if choice['message']['role'] == 'assistant': - assistant_contents.append(choice['message']['content']) - total_prompt_tokens += item['usage']['prompt_tokens'] - total_completion_tokens += item['usage']['completion_tokens'] - # Note: I don't think the prompt_tokens or completion_tokens is working quite right... - - return assistant_contents - - -# dataclasses - - -@dataclass -class StatusTracker: - """Stores metadata about the script's progress. Only one instance is created.""" - - num_tasks_started: int = 0 - num_tasks_in_progress: int = 0 # script ends when this reaches 0 - num_tasks_succeeded: int = 0 - num_tasks_failed: int = 0 - num_rate_limit_errors: int = 0 - num_api_errors: int = 0 # excluding rate limit errors, counted above - num_other_errors: int = 0 - time_of_last_rate_limit_error: float = 0 # used to cool off after hitting rate limits - - -@dataclass -class APIRequest: - """Stores an API request's inputs, outputs, and other metadata. 
Contains a method to make an API call.""" - - task_id: int - request_json: dict - token_consumption: int - attempts_left: int - metadata: dict - result: list = field(default_factory=list) - - async def call_api( - self, - request_url: str, - request_header: dict, - retry_queue: asyncio.Queue, - status_tracker: StatusTracker, - ): - """Calls the OpenAI API and saves results.""" - # logging.info(f"Starting request #{self.task_id}") - error = None - try: - async with aiohttp.ClientSession() as session: - async with session.post(url=request_url, headers=request_header, json=self.request_json) as response: - response = await response.json() - if "error" in response: - logging.warning(f"Request {self.task_id} failed with error {response['error']}") - status_tracker.num_api_errors += 1 - error = response - if "Rate limit" in response["error"].get("message", ""): - status_tracker.time_of_last_rate_limit_error = time.time() - status_tracker.num_rate_limit_errors += 1 - status_tracker.num_api_errors -= 1 # rate limit errors are counted separately - - except Exception as e: # catching naked exceptions is bad practice, but in this case we'll log & save them - logging.warning(f"Request {self.task_id} failed with Exception {e}") - status_tracker.num_other_errors += 1 - error = e - if error: - self.result.append(error) - if self.attempts_left: - retry_queue.put_nowait(self) - else: - logging.error(f"Request {self.request_json} failed after all attempts. Saving errors: {self.result}") - data = ([self.request_json, [str(e) for e in self.result], self.metadata] - if self.metadata else [self.request_json, [str(e) for e in self.result]]) - #append_to_jsonl(data, save_filepath) - status_tracker.num_tasks_in_progress -= 1 - status_tracker.num_tasks_failed += 1 - return data - else: - data = ([self.request_json, response, self.metadata] if self.metadata else [self.request_json, response] - ) # type: ignore - #append_to_jsonl(data, save_filepath) - status_tracker.num_tasks_in_progress -= 1 - status_tracker.num_tasks_succeeded += 1 - # logging.debug(f"Request {self.task_id} saved to {save_filepath}") - - return data - - -# functions - - -def api_endpoint_from_url(request_url: str): - """Extract the API endpoint from the request URL.""" - if 'text-embedding-ada-002' in request_url: - return 'embeddings' - else: - match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url) - return match[1] # type: ignore - - -def append_to_jsonl(data, filename: str) -> None: - """Append a json payload to the end of a jsonl file.""" - json_string = json.dumps(data) - with open(filename, "a") as f: - f.write(json_string + "\n") - - -def num_tokens_consumed_from_request( - request_json: dict, - api_endpoint: str, - token_encoding_name: str, -): - """Count the number of tokens in the request. 
Only supports completion and embedding requests.""" - encoding = tiktoken.get_encoding(token_encoding_name) - # if completions request, tokens = prompt + n * max_tokens - if api_endpoint.endswith("completions"): - max_tokens = request_json.get("max_tokens", 15) - n = request_json.get("n", 1) - completion_tokens = n * max_tokens - - # chat completions - if api_endpoint.startswith("chat/"): - num_tokens = 0 - for message in request_json["messages"]: - num_tokens += 4 # every message follows {role/name}\n{content}\n - for key, value in message.items(): - num_tokens += len(encoding.encode(value)) - if key == "name": # if there's a name, the role is omitted - num_tokens -= 1 # role is always required and always 1 token - num_tokens += 2 # every reply is primed with assistant - return num_tokens + completion_tokens - # normal completions - else: - prompt = request_json["prompt"] - if isinstance(prompt, str): # single prompt - prompt_tokens = len(encoding.encode(prompt)) - num_tokens = prompt_tokens + completion_tokens - return num_tokens - elif isinstance(prompt, list): # multiple prompts - prompt_tokens = sum([len(encoding.encode(p)) for p in prompt]) - num_tokens = prompt_tokens + completion_tokens * len(prompt) - return num_tokens - else: - raise TypeError('Expecting either string or list of strings for "prompt" field in completion request') - # if embeddings request, tokens = input tokens - elif api_endpoint == "embeddings": - input = request_json["input"] - if isinstance(input, str): # single input - num_tokens = len(encoding.encode(input)) - return num_tokens - elif isinstance(input, list): # multiple inputs - num_tokens = sum([len(encoding.encode(i)) for i in input]) - return num_tokens - else: - raise TypeError('Expecting either string or list of strings for "inputs" field in embedding request') - # more logic needed to support other API calls (e.g., edits, inserts, DALL-E) - else: - raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script') - - -def task_id_generator_function(): - """Generate integers 0, 1, 2, and so on.""" - task_id = 0 - while True: - yield task_id - task_id += 1 - - -if __name__ == '__main__': - pass - - # run script - # if __name__ == "__main__": - # qdrant_client = QdrantClient( - # url=os.getenv('QDRANT_URL'), - # api_key=os.getenv('QDRANT_API_KEY'), - # ) - # vectorstore = Qdrant( - # client=qdrant_client, - # collection_name=os.getenv('QDRANT_COLLECTION_NAME'), # type: ignore - # embeddings=OpenAIEmbeddings()) # type: ignore - - # user_question = "What is the significance of Six Sigma?" - # k = 4 - # fetch_k = 200 - # found_docs = vectorstore.max_marginal_relevance_search(user_question, k=k, fetch_k=200) - - # requests = [] - # for i, doc in enumerate(found_docs): - # dictionary = { - # "model": "gpt-3.5-turbo-0613", # 4k context - # "messages": [{ - # "role": "system", - # "content": "You are a factual summarizer of partial documents. Stick to the facts (including partial info when necessary to avoid making up potentially incorrect details), and say I don't know when necessary." - # }, { - # "role": - # "user", - # "content": - # f"What is a comprehensive summary of the given text, based on the question:\n{doc.page_content}\nQuestion: {user_question}\nThe summary should cover all the key points only relevant to the question, while also condensing the information into a concise and easy-to-understand format. 
Please ensure that the summary includes relevant details and examples that support the main ideas, while avoiding any unnecessary information or repetition. Feel free to include references, sentence fragments, keywords, or anything that could help someone learn about it, only as it relates to the given question. The length of the summary should be as short as possible, without losing relevant information.\n" - # }], - # "n": 1, - # "max_tokens": 500, - # "metadata": doc.metadata - # } - # requests.append(dictionary) - - # oai = OpenAIAPIProcessor( - # input_prompts_list=requests, - # request_url='https://api.openai.com/v1/chat/completions', - # api_key=os.getenv("OPENAI_API_KEY"), - # max_requests_per_minute=1500, - # max_tokens_per_minute=90000, - # token_encoding_name='cl100k_base', - # max_attempts=5, - # logging_level=20, - # ) - # # run script - # asyncio.run(oai.process_api_requests_from_file()) - - # assistant_contents = [] - # total_prompt_tokens = 0 - # total_completion_tokens = 0 - - # print("Results, end of main: ", oai.results) - # print("-"*50) - - # # jsonObject = json.loads(oai.results) - # for element in oai.results: - # for item in element: - # if 'choices' in item: - # for choice in item['choices']: - # if choice['message']['role'] == 'assistant': - # assistant_contents.append(choice['message']['content']) - # total_prompt_tokens += item['usage']['prompt_tokens'] - # total_completion_tokens += item['usage']['completion_tokens'] - - # print("Assistant Contents:", assistant_contents) - # print("Total Prompt Tokens:", total_prompt_tokens) - # print("Total Completion Tokens:", total_completion_tokens) - # turbo_total_cost = (total_prompt_tokens * 0.0015) + (total_completion_tokens * 0.002) - # print("Total cost (3.5-turbo):", (total_prompt_tokens * 0.0015), " + Completions: ", (total_completion_tokens * 0.002), " = ", turbo_total_cost) - - # gpt4_total_cost = (total_prompt_tokens * 0.03) + (total_completion_tokens * 0.06) - # print("Hypothetical cost for GPT-4:", (total_prompt_tokens * 0.03), " + Completions: ", (total_completion_tokens * 0.06), " = ", gpt4_total_cost) - # print("GPT-4 cost premium: ", (gpt4_total_cost / turbo_total_cost), "x") - ''' - Pricing: - GPT4: - * $0.03 prompt - * $0.06 completions - 3.5-turbo: - * $0.0015 prompt - * $0.002 completions - ''' -""" -APPENDIX - -The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002. - -It was generated with the following code: - -```python -import json - -filename = "data/example_requests_to_parallel_process.jsonl" -n_requests = 10_000 -jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)] -with open(filename, "w") as f: - for job in jobs: - json_string = json.dumps(job) - f.write(json_string + "\n") -``` - -As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically). -""" +# """ +# API REQUEST PARALLEL PROCESSOR + +# Using the OpenAI API to process lots of text quickly takes some care. +# If you trickle in a million API requests one by one, they'll take days to complete. +# If you flood a million API requests in parallel, they'll exceed the rate limits and fail with errors. +# To maximize throughput, parallel requests need to be throttled to stay under rate limits. + +# This script parallelizes requests to the OpenAI API while throttling to stay under rate limits. 
+ +# Features: +# - Streams requests from file, to avoid running out of memory for giant jobs +# - Makes requests concurrently, to maximize throughput +# - Throttles request and token usage, to stay under rate limits +# - Retries failed requests up to {max_attempts} times, to avoid missing data +# - Logs errors, to diagnose problems with requests + +# Example command to call script: +# ``` +# python examples/api_request_parallel_processor.py \ +# --requests_filepath examples/data/example_requests_to_parallel_process.jsonl \ +# --save_filepath examples/data/example_requests_to_parallel_process_results.jsonl \ +# --request_url https://api.openai.com/v1/embeddings \ +# --max_requests_per_minute 1500 \ +# --max_tokens_per_minute 6250000 \ +# --token_encoding_name cl100k_base \ +# --max_attempts 5 \ +# --logging_level 20 +# ``` + +# Inputs: +# - requests_filepath : str +# - path to the file containing the requests to be processed +# - file should be a jsonl file, where each line is a json object with API parameters and an optional metadata field +# - e.g., {"model": "text-embedding-ada-002", "input": "embed me", "metadata": {"row_id": 1}} +# - as with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically) +# - an example file is provided at examples/data/example_requests_to_parallel_process.jsonl +# - the code to generate the example file is appended to the bottom of this script +# - save_filepath : str, optional +# - path to the file where the results will be saved +# - file will be a jsonl file, where each line is an array with the original request plus the API response +# - e.g., [{"model": "text-embedding-ada-002", "input": "embed me"}, {...}] +# - if omitted, results will be saved to {requests_filename}_results.jsonl +# - request_url : str, optional +# - URL of the API endpoint to call +# - if omitted, will default to "https://api.openai.com/v1/embeddings" +# - api_key : str, optional +# - API key to use +# - if omitted, the script will attempt to read it from an environment variable {os.getenv("OPENAI_API_KEY")} +# - max_requests_per_minute : float, optional +# - target number of requests to make per minute (will make less if limited by tokens) +# - leave headroom by setting this to 50% or 75% of your limit +# - if requests are limiting you, try batching multiple embeddings or completions into one request +# - if omitted, will default to 1,500 +# - max_tokens_per_minute : float, optional +# - target number of tokens to use per minute (will use less if limited by requests) +# - leave headroom by setting this to 50% or 75% of your limit +# - if omitted, will default to 125,000 +# - token_encoding_name : str, optional +# - name of the token encoding used, as defined in the `tiktoken` package +# - if omitted, will default to "cl100k_base" (used by `text-embedding-ada-002`) +# - max_attempts : int, optional +# - number of times to retry a failed request before giving up +# - if omitted, will default to 5 +# - logging_level : int, optional +# - level of logging to use; higher numbers will log fewer messages +# - 40 = ERROR; will log only when requests fail after all retries +# - 30 = WARNING; will log when requests his rate limits or other errors +# - 20 = INFO; will log when requests start and the status at finish +# - 10 = DEBUG; will log various things as the loop runs to see when they occur +# - if omitted, will default to 20 (INFO). 
+ +# The script is structured as follows: +# - Imports +# - Define main() +# - Initialize things +# - In main loop: +# - Get next request if one is not already waiting for capacity +# - Update available token & request capacity +# - If enough capacity available, call API +# - The loop pauses if a rate limit error is hit +# - The loop breaks when no tasks remain +# - Define dataclasses +# - StatusTracker (stores script metadata counters; only one instance is created) +# - APIRequest (stores API inputs, outputs, metadata; one method to call API) +# - Define functions +# - api_endpoint_from_url (extracts API endpoint from request URL) +# - append_to_jsonl (writes to results file) +# - num_tokens_consumed_from_request (bigger function to infer token usage from request) +# - task_id_generator_function (yields 1, 2, 3, ...) +# - Run main() +# """ + +# # import argparse +# # import subprocess +# # import tempfile +# # from langchain.llms import OpenAI +# import asyncio +# import json +# import logging + +# # import os +# import re +# import time + +# # for storing API inputs, outputs, and metadata +# from dataclasses import dataclass, field +# from typing import Any, List + +# import aiohttp # for making API calls concurrently +# import tiktoken # for counting tokens + +# # from langchain.embeddings.openai import OpenAIEmbeddings +# # from langchain.vectorstores import Qdrant +# # from qdrant_client import QdrantClient, models + + +# class OpenAIAPIProcessor: + +# def __init__(self, input_prompts_list, request_url, api_key, max_requests_per_minute, max_tokens_per_minute, +# token_encoding_name, max_attempts, logging_level): +# self.request_url = request_url +# self.api_key = api_key +# self.max_requests_per_minute = max_requests_per_minute +# self.max_tokens_per_minute = max_tokens_per_minute +# self.token_encoding_name = token_encoding_name +# self.max_attempts = max_attempts +# self.logging_level = logging_level +# self.input_prompts_list: List[dict] = input_prompts_list +# self.results = [] +# self.cleaned_results: List[str] = [] + +# async def process_api_requests_from_file(self): +# """Processes API requests in parallel, throttling to stay under rate limits.""" +# # constants +# seconds_to_pause_after_rate_limit_error = 15 +# seconds_to_sleep_each_loop = 0.001 # 1 ms limits max throughput to 1,000 requests per second + +# # initialize logging +# logging.basicConfig(level=self.logging_level) +# logging.debug(f"Logging initialized at level {self.logging_level}") + +# # infer API endpoint and construct request header +# api_endpoint = api_endpoint_from_url(self.request_url) +# request_header = {"Authorization": f"Bearer {self.api_key}"} + +# # initialize trackers +# queue_of_requests_to_retry = asyncio.Queue() +# task_id_generator = task_id_generator_function() # generates integer IDs of 1, 2, 3, ... +# status_tracker = StatusTracker() # single instance to track a collection of variables +# next_request = None # variable to hold the next request to call + +# # initialize available capacity counts +# available_request_capacity = self.max_requests_per_minute +# available_token_capacity = self.max_tokens_per_minute +# last_update_time = time.time() + +# # initialize flags +# file_not_finished = True # after file is empty, we'll skip reading it +# logging.debug("Initialization complete.") + +# requests = self.input_prompts_list.__iter__() + +# logging.debug("File opened. 
Entering main loop") + +# task_list = [] + +# while True: +# # get next request (if one is not already waiting for capacity) +# if next_request is None: +# if not queue_of_requests_to_retry.empty(): +# next_request = queue_of_requests_to_retry.get_nowait() +# logging.debug(f"Retrying request {next_request.task_id}: {next_request}") +# elif file_not_finished: +# try: +# # get new request +# # request_json = json.loads(next(requests)) +# request_json = next(requests) + +# next_request = APIRequest(task_id=next(task_id_generator), +# request_json=request_json, +# token_consumption=num_tokens_consumed_from_request( +# request_json, api_endpoint, self.token_encoding_name), +# attempts_left=self.max_attempts, +# metadata=request_json.pop("metadata", None)) +# status_tracker.num_tasks_started += 1 +# status_tracker.num_tasks_in_progress += 1 +# logging.debug(f"Reading request {next_request.task_id}: {next_request}") +# except StopIteration: +# # if file runs out, set flag to stop reading it +# logging.debug("Read file exhausted") +# file_not_finished = False + +# # update available capacity +# current_time = time.time() +# seconds_since_update = current_time - last_update_time +# available_request_capacity = min( +# available_request_capacity + self.max_requests_per_minute * seconds_since_update / 60.0, +# self.max_requests_per_minute, +# ) +# available_token_capacity = min( +# available_token_capacity + self.max_tokens_per_minute * seconds_since_update / 60.0, +# self.max_tokens_per_minute, +# ) +# last_update_time = current_time + +# # if enough capacity available, call API +# if next_request: +# next_request_tokens = next_request.token_consumption +# if (available_request_capacity >= 1 and available_token_capacity >= next_request_tokens): +# # update counters +# available_request_capacity -= 1 +# available_token_capacity -= next_request_tokens +# next_request.attempts_left -= 1 + +# # call API +# # TODO: NOT SURE RESPONSE WILL WORK HERE +# task = asyncio.create_task( +# next_request.call_api( +# request_url=self.request_url, +# request_header=request_header, +# retry_queue=queue_of_requests_to_retry, +# status_tracker=status_tracker, +# )) +# task_list.append(task) +# next_request = None # reset next_request to empty + +# # print("status_tracker.num_tasks_in_progress", status_tracker.num_tasks_in_progress) +# # one_task_result = task.result() +# # print("one_task_result", one_task_result) + +# # if all tasks are finished, break +# if status_tracker.num_tasks_in_progress == 0: +# break + +# # main loop sleeps briefly so concurrent tasks can run +# await asyncio.sleep(seconds_to_sleep_each_loop) + +# # if a rate limit error was hit recently, pause to cool down +# seconds_since_rate_limit_error = (time.time() - status_tracker.time_of_last_rate_limit_error) +# if seconds_since_rate_limit_error < seconds_to_pause_after_rate_limit_error: +# remaining_seconds_to_pause = (seconds_to_pause_after_rate_limit_error - seconds_since_rate_limit_error) +# await asyncio.sleep(remaining_seconds_to_pause) +# # ^e.g., if pause is 15 seconds and final limit was hit 5 seconds ago +# logging.warn( +# f"Pausing to cool down until {time.ctime(status_tracker.time_of_last_rate_limit_error + seconds_to_pause_after_rate_limit_error)}" +# ) + +# # after finishing, log final status +# logging.info("""Parallel processing complete. 
About to return.""") +# if status_tracker.num_tasks_failed > 0: +# logging.warning(f"{status_tracker.num_tasks_failed} / {status_tracker.num_tasks_started} requests failed.") +# if status_tracker.num_rate_limit_errors > 0: +# logging.warning( +# f"{status_tracker.num_rate_limit_errors} rate limit errors received. Consider running at a lower rate.") + +# # asyncio wait for task_list +# await asyncio.wait(task_list) + +# for task in task_list: +# openai_completion = task.result() +# self.results.append(openai_completion) + +# self.cleaned_results: List[str] = extract_context_from_results(self.results) + + +# def extract_context_from_results(results: List[Any]) -> List[str]: +# assistant_contents = [] +# total_prompt_tokens = 0 +# total_completion_tokens = 0 + +# for element in results: +# if element is not None: +# for item in element: +# if 'choices' in item: +# for choice in item['choices']: +# if choice['message']['role'] == 'assistant': +# assistant_contents.append(choice['message']['content']) +# total_prompt_tokens += item['usage']['prompt_tokens'] +# total_completion_tokens += item['usage']['completion_tokens'] +# # Note: I don't think the prompt_tokens or completion_tokens is working quite right... + +# return assistant_contents + + +# # dataclasses + + +# @dataclass +# class StatusTracker: +# """Stores metadata about the script's progress. Only one instance is created.""" + +# num_tasks_started: int = 0 +# num_tasks_in_progress: int = 0 # script ends when this reaches 0 +# num_tasks_succeeded: int = 0 +# num_tasks_failed: int = 0 +# num_rate_limit_errors: int = 0 +# num_api_errors: int = 0 # excluding rate limit errors, counted above +# num_other_errors: int = 0 +# time_of_last_rate_limit_error: float = 0 # used to cool off after hitting rate limits + + +# @dataclass +# class APIRequest: +# """Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API call.""" + +# task_id: int +# request_json: dict +# token_consumption: int +# attempts_left: int +# metadata: dict +# result: list = field(default_factory=list) + +# async def call_api( +# self, +# request_url: str, +# request_header: dict, +# retry_queue: asyncio.Queue, +# status_tracker: StatusTracker, +# ): +# """Calls the OpenAI API and saves results.""" +# # logging.info(f"Starting request #{self.task_id}") +# error = None +# try: +# async with aiohttp.ClientSession() as session: +# async with session.post(url=request_url, headers=request_header, json=self.request_json) as response: +# response = await response.json() +# if "error" in response: +# logging.warning(f"Request {self.task_id} failed with error {response['error']}") +# status_tracker.num_api_errors += 1 +# error = response +# if "Rate limit" in response["error"].get("message", ""): +# status_tracker.time_of_last_rate_limit_error = time.time() +# status_tracker.num_rate_limit_errors += 1 +# status_tracker.num_api_errors -= 1 # rate limit errors are counted separately + +# except Exception as e: # catching naked exceptions is bad practice, but in this case we'll log & save them +# logging.warning(f"Request {self.task_id} failed with Exception {e}") +# status_tracker.num_other_errors += 1 +# error = e +# if error: +# self.result.append(error) +# if self.attempts_left: +# retry_queue.put_nowait(self) +# else: +# logging.error(f"Request {self.request_json} failed after all attempts. 
Saving errors: {self.result}") +# data = ([self.request_json, [str(e) for e in self.result], self.metadata] +# if self.metadata else [self.request_json, [str(e) for e in self.result]]) +# #append_to_jsonl(data, save_filepath) +# status_tracker.num_tasks_in_progress -= 1 +# status_tracker.num_tasks_failed += 1 +# return data +# else: +# data = ([self.request_json, response, self.metadata] if self.metadata else [self.request_json, response] +# ) # type: ignore +# #append_to_jsonl(data, save_filepath) +# status_tracker.num_tasks_in_progress -= 1 +# status_tracker.num_tasks_succeeded += 1 +# # logging.debug(f"Request {self.task_id} saved to {save_filepath}") + +# return data + + +# # functions + + +# def api_endpoint_from_url(request_url: str): +# """Extract the API endpoint from the request URL.""" +# if 'text-embedding-ada-002' in request_url: +# return 'embeddings' +# else: +# match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url) +# return match[1] # type: ignore + + +# def append_to_jsonl(data, filename: str) -> None: +# """Append a json payload to the end of a jsonl file.""" +# json_string = json.dumps(data) +# with open(filename, "a") as f: +# f.write(json_string + "\n") + + +# def num_tokens_consumed_from_request( +# request_json: dict, +# api_endpoint: str, +# token_encoding_name: str, +# ): +# """Count the number of tokens in the request. Only supports completion and embedding requests.""" +# encoding = tiktoken.get_encoding(token_encoding_name) +# # if completions request, tokens = prompt + n * max_tokens +# if api_endpoint.endswith("completions"): +# max_tokens = request_json.get("max_tokens", 15) +# n = request_json.get("n", 1) +# completion_tokens = n * max_tokens + +# # chat completions +# if api_endpoint.startswith("chat/"): +# num_tokens = 0 +# for message in request_json["messages"]: +# num_tokens += 4 # every message follows {role/name}\n{content}\n +# for key, value in message.items(): +# num_tokens += len(encoding.encode(value)) +# if key == "name": # if there's a name, the role is omitted +# num_tokens -= 1 # role is always required and always 1 token +# num_tokens += 2 # every reply is primed with assistant +# return num_tokens + completion_tokens +# # normal completions +# else: +# prompt = request_json["prompt"] +# if isinstance(prompt, str): # single prompt +# prompt_tokens = len(encoding.encode(prompt)) +# num_tokens = prompt_tokens + completion_tokens +# return num_tokens +# elif isinstance(prompt, list): # multiple prompts +# prompt_tokens = sum([len(encoding.encode(p)) for p in prompt]) +# num_tokens = prompt_tokens + completion_tokens * len(prompt) +# return num_tokens +# else: +# raise TypeError('Expecting either string or list of strings for "prompt" field in completion request') +# # if embeddings request, tokens = input tokens +# elif api_endpoint == "embeddings": +# input = request_json["input"] +# if isinstance(input, str): # single input +# num_tokens = len(encoding.encode(input)) +# return num_tokens +# elif isinstance(input, list): # multiple inputs +# num_tokens = sum([len(encoding.encode(i)) for i in input]) +# return num_tokens +# else: +# raise TypeError('Expecting either string or list of strings for "inputs" field in embedding request') +# # more logic needed to support other API calls (e.g., edits, inserts, DALL-E) +# else: +# raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script') + + +# def task_id_generator_function(): +# """Generate integers 0, 1, 2, and so on.""" +# task_id = 0 +# while True: +# yield 
task_id +# task_id += 1 + + +# if __name__ == '__main__': +# pass + +# # run script +# # if __name__ == "__main__": +# # qdrant_client = QdrantClient( +# # url=os.getenv('QDRANT_URL'), +# # api_key=os.getenv('QDRANT_API_KEY'), +# # ) +# # vectorstore = Qdrant( +# # client=qdrant_client, +# # collection_name=os.getenv('QDRANT_COLLECTION_NAME'), # type: ignore +# # embeddings=OpenAIEmbeddings()) # type: ignore + +# # user_question = "What is the significance of Six Sigma?" +# # k = 4 +# # fetch_k = 200 +# # found_docs = vectorstore.max_marginal_relevance_search(user_question, k=k, fetch_k=200) + +# # requests = [] +# # for i, doc in enumerate(found_docs): +# # dictionary = { +# # "model": "gpt-3.5-turbo-0613", # 4k context +# # "messages": [{ +# # "role": "system", +# # "content": "You are a factual summarizer of partial documents. Stick to the facts (including partial info when necessary to avoid making up potentially incorrect details), and say I don't know when necessary." +# # }, { +# # "role": +# # "user", +# # "content": +# # f"What is a comprehensive summary of the given text, based on the question:\n{doc.page_content}\nQuestion: {user_question}\nThe summary should cover all the key points only relevant to the question, while also condensing the information into a concise and easy-to-understand format. Please ensure that the summary includes relevant details and examples that support the main ideas, while avoiding any unnecessary information or repetition. Feel free to include references, sentence fragments, keywords, or anything that could help someone learn about it, only as it relates to the given question. The length of the summary should be as short as possible, without losing relevant information.\n" +# # }], +# # "n": 1, +# # "max_tokens": 500, +# # "metadata": doc.metadata +# # } +# # requests.append(dictionary) + +# # oai = OpenAIAPIProcessor( +# # input_prompts_list=requests, +# # request_url='https://api.openai.com/v1/chat/completions', +# # api_key=os.getenv("OPENAI_API_KEY"), +# # max_requests_per_minute=1500, +# # max_tokens_per_minute=90000, +# # token_encoding_name='cl100k_base', +# # max_attempts=5, +# # logging_level=20, +# # ) +# # # run script +# # asyncio.run(oai.process_api_requests_from_file()) + +# # assistant_contents = [] +# # total_prompt_tokens = 0 +# # total_completion_tokens = 0 + +# # print("Results, end of main: ", oai.results) +# # print("-"*50) + +# # # jsonObject = json.loads(oai.results) +# # for element in oai.results: +# # for item in element: +# # if 'choices' in item: +# # for choice in item['choices']: +# # if choice['message']['role'] == 'assistant': +# # assistant_contents.append(choice['message']['content']) +# # total_prompt_tokens += item['usage']['prompt_tokens'] +# # total_completion_tokens += item['usage']['completion_tokens'] + +# # print("Assistant Contents:", assistant_contents) +# # print("Total Prompt Tokens:", total_prompt_tokens) +# # print("Total Completion Tokens:", total_completion_tokens) +# # turbo_total_cost = (total_prompt_tokens * 0.0015) + (total_completion_tokens * 0.002) +# # print("Total cost (3.5-turbo):", (total_prompt_tokens * 0.0015), " + Completions: ", (total_completion_tokens * 0.002), " = ", turbo_total_cost) + +# # gpt4_total_cost = (total_prompt_tokens * 0.03) + (total_completion_tokens * 0.06) +# # print("Hypothetical cost for GPT-4:", (total_prompt_tokens * 0.03), " + Completions: ", (total_completion_tokens * 0.06), " = ", gpt4_total_cost) +# # print("GPT-4 cost premium: ", (gpt4_total_cost / 
turbo_total_cost), "x") +# ''' +# Pricing: +# GPT4: +# * $0.03 prompt +# * $0.06 completions +# 3.5-turbo: +# * $0.0015 prompt +# * $0.002 completions +# ''' +# """ +# APPENDIX + +# The example requests file at openai-cookbook/examples/data/example_requests_to_parallel_process.jsonl contains 10,000 requests to text-embedding-ada-002. + +# It was generated with the following code: + +# ```python +# import json + +# filename = "data/example_requests_to_parallel_process.jsonl" +# n_requests = 10_000 +# jobs = [{"model": "text-embedding-ada-002", "input": str(x) + "\n"} for x in range(n_requests)] +# with open(filename, "w") as f: +# for job in jobs: +# json_string = json.dumps(job) +# f.write(json_string + "\n") +# ``` + +# As with all jsonl files, take care that newlines in the content are properly escaped (json.dumps does this automatically). +# """ diff --git a/ai_ta_backend/beam/ingest.py b/ai_ta_backend/beam/ingest.py index f292f204..b31bab87 100644 --- a/ai_ta_backend/beam/ingest.py +++ b/ai_ta_backend/beam/ingest.py @@ -1,1371 +1,1371 @@ -""" -To deploy: beam deploy ingest.py --profile caii-ncsa -Use CAII gmail to auth. -""" -import asyncio -import inspect -import json -import logging -import mimetypes -import os -import re -import shutil -import time -import traceback -import uuid -from pathlib import Path -from tempfile import NamedTemporaryFile -from typing import Any, Callable, Dict, List, Optional, Union - -import beam -import boto3 -import fitz -import openai -import pytesseract -import pdfplumber -import sentry_sdk -import supabase -from beam import App, QueueDepthAutoscaler, Runtime # RequestLatencyAutoscaler, -from bs4 import BeautifulSoup -from git.repo import Repo -from langchain.document_loaders import ( - Docx2txtLoader, - GitLoader, - PythonLoader, - TextLoader, - UnstructuredExcelLoader, - UnstructuredPowerPointLoader, -) -from langchain.document_loaders.csv_loader import CSVLoader -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.schema import Document -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain.vectorstores import Qdrant -from nomic_logging import delete_from_document_map, log_to_document_map, rebuild_map -from OpenaiEmbeddings import OpenAIAPIProcessor -from PIL import Image -from posthog import Posthog -from pydub import AudioSegment -from qdrant_client import QdrantClient, models -from qdrant_client.models import PointStruct - -# from langchain.schema.output_parser import StrOutputParser -# from langchain.chat_models import AzureChatOpenAI - -requirements = [ - "openai<1.0", - "supabase==2.0.2", - "tiktoken==0.5.1", - "boto3==1.28.79", - "qdrant-client==1.7.3", - "langchain==0.0.331", - "posthog==3.1.0", - "pysrt==1.1.2", - "docx2txt==0.8", - "pydub==0.25.1", - "ffmpeg-python==0.2.0", - "ffprobe==0.5", - "ffmpeg==1.4", - "PyMuPDF==1.23.6", - "pytesseract==0.3.10", # image OCR" - "openpyxl==3.1.2", # excel" - "networkx==3.2.1", # unused part of excel partitioning :(" - "python-pptx==0.6.23", - "unstructured==0.10.29", - "GitPython==3.1.40", - "beautifulsoup4==4.12.2", - "sentry-sdk==1.39.1", - "nomic==2.0.14", - "pdfplumber==0.11.0", # PDF OCR, better performance than Fitz/PyMuPDF in my Gies PDF testing. -] - -# TODO: consider adding workers. 
They share CPU and memory https://docs.beam.cloud/deployment/autoscaling#worker-use-cases -app = App("ingest", - runtime=Runtime( - cpu=1, - memory="3Gi", - image=beam.Image( - python_version="python3.10", - python_packages=requirements, - commands=["apt-get update && apt-get install -y ffmpeg tesseract-ocr"], - ), - )) - -# MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation") -OPENAI_API_TYPE = "azure" # "openai" or "azure" - - -def loader(): - """ - The loader function will run once for each worker that starts up. https://docs.beam.cloud/deployment/loaders - """ - openai.api_key = os.getenv("VLADS_OPENAI_KEY") - - # vector DB - qdrant_client = QdrantClient( - url=os.getenv('QDRANT_URL'), - api_key=os.getenv('QDRANT_API_KEY'), - ) - - vectorstore = Qdrant(client=qdrant_client, - collection_name=os.environ['QDRANT_COLLECTION_NAME'], - embeddings=OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE, - openai_api_key=os.getenv('VLADS_OPENAI_KEY'))) - - # S3 - s3_client = boto3.client( - 's3', - aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'), - aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'), - ) - - # Create a Supabase client - supabase_client = supabase.create_client( # type: ignore - supabase_url=os.environ['SUPABASE_URL'], supabase_key=os.environ['SUPABASE_API_KEY']) - - # llm = AzureChatOpenAI( - # temperature=0, - # deployment_name=os.getenv('AZURE_OPENAI_ENGINE'), #type:ignore - # openai_api_base=os.getenv('AZURE_OPENAI_ENDPOINT'), #type:ignore - # openai_api_key=os.getenv('AZURE_OPENAI_KEY'), #type:ignore - # openai_api_version=os.getenv('OPENAI_API_VERSION'), #type:ignore - # openai_api_type=OPENAI_API_TYPE) - - posthog = Posthog(sync_mode=True, project_api_key=os.environ['POSTHOG_API_KEY'], host='https://app.posthog.com') - sentry_sdk.init( - dsn="https://examplePublicKey@o0.ingest.sentry.io/0", - - # Enable performance monitoring - enable_tracing=True, - ) - - return qdrant_client, vectorstore, s3_client, supabase_client, posthog - - -# autoscaler = RequestLatencyAutoscaler(desired_latency=30, max_replicas=2) -autoscaler = QueueDepthAutoscaler(max_tasks_per_replica=300, max_replicas=3) - - -# Triggers determine how your app is deployed -# @app.rest_api( -@app.task_queue( - workers=4, - callback_url='https://uiuc-chat-git-ingestprogresstracking-kastanday.vercel.app/api/UIUC-api/ingestTaskCallback', - max_pending_tasks=15_000, - max_retries=3, - timeout=-1, - loader=loader, - autoscaler=autoscaler) -def ingest(**inputs: Dict[str, Any]): - - qdrant_client, vectorstore, s3_client, supabase_client, posthog = inputs["context"] - - course_name: List[str] | str = inputs.get('course_name', '') - s3_paths: List[str] | str = inputs.get('s3_paths', '') - url: List[str] | str | None = inputs.get('url', None) - base_url: List[str] | str | None = inputs.get('base_url', None) - readable_filename: List[str] | str = inputs.get('readable_filename', '') - content: str | None = inputs.get('content', None) # is webtext if content exists - - print( - f"In top of /ingest route. 
course: {course_name}, s3paths: {s3_paths}, readable_filename: {readable_filename}, base_url: {base_url}, url: {url}, content: {content}" - ) - - ingester = Ingest(qdrant_client, vectorstore, s3_client, supabase_client, posthog) - - def run_ingest(course_name, s3_paths, base_url, url, readable_filename, content): - if content: - return ingester.ingest_single_web_text(course_name, base_url, url, content, readable_filename) - elif readable_filename == '': - return ingester.bulk_ingest(course_name, s3_paths, base_url=base_url, url=url) - else: - return ingester.bulk_ingest(course_name, - s3_paths, - readable_filename=readable_filename, - base_url=base_url, - url=url) - - # First try - success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) - - # retries - num_retires = 5 - for retry_num in range(1, num_retires): - if isinstance(success_fail_dict, str): - print(f"STRING ERROR: {success_fail_dict = }") - success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) - time.sleep(13 * retry_num) # max is 65 - elif success_fail_dict['failure_ingest']: - print(f"Ingest failure -- Retry attempt {retry_num}. File: {success_fail_dict}") - # s3_paths = success_fail_dict['failure_ingest'] # retry only failed paths.... what if this is a URL instead? - success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) - time.sleep(13 * retry_num) # max is 65 - else: - break - - # Final failure / success check - if success_fail_dict['failure_ingest']: - print(f"INGEST FAILURE -- About to send to supabase. success_fail_dict: {success_fail_dict}") - document = { - "course_name": - course_name, - "s3_path": - s3_paths, - "readable_filename": - readable_filename, - "url": - url, - "base_url": - base_url, - "error": - success_fail_dict['failure_ingest']['error'] - if isinstance(success_fail_dict['failure_ingest'], dict) else success_fail_dict['failure_ingest'] - } - response = supabase_client.table('documents_failed').insert(document).execute() # type: ignore - print(f"Supabase ingest failure response: {response}") - else: - # Success case: rebuild nomic document map after all ingests are done - # rebuild_status = rebuild_map(str(course_name), map_type='document') - pass - - print(f"Final success_fail_dict: {success_fail_dict}") - return json.dumps(success_fail_dict) - - -class Ingest(): - - def __init__(self, qdrant_client, vectorstore, s3_client, supabase_client, posthog): - self.qdrant_client = qdrant_client - self.vectorstore = vectorstore - self.s3_client = s3_client - self.supabase_client = supabase_client - self.posthog = posthog - - def bulk_ingest(self, course_name: str, s3_paths: Union[str, List[str]], - **kwargs) -> Dict[str, None | str | Dict[str, str]]: - """ - Bulk ingest a list of s3 paths into the vectorstore, and also into the supabase database. 
- -> Dict[str, str | Dict[str, str]] - """ - - def _ingest_single(ingest_method: Callable, s3_path, *args, **kwargs): - """Handle running an arbitrary ingest function for an individual file.""" - # RUN INGEST METHOD - ret = ingest_method(s3_path, *args, **kwargs) - if ret == "Success": - success_status['success_ingest'] = str(s3_path) - else: - success_status['failure_ingest'] = {'s3_path': str(s3_path), 'error': str(ret)} - - # 👇👇👇👇 ADD NEW INGEST METHODS HERE 👇👇👇👇🎉 - file_ingest_methods = { - '.html': self._ingest_html, - '.py': self._ingest_single_py, - '.pdf': self._ingest_single_pdf, - '.txt': self._ingest_single_txt, - '.md': self._ingest_single_txt, - '.srt': self._ingest_single_srt, - '.vtt': self._ingest_single_vtt, - '.docx': self._ingest_single_docx, - '.ppt': self._ingest_single_ppt, - '.pptx': self._ingest_single_ppt, - '.xlsx': self._ingest_single_excel, - '.xls': self._ingest_single_excel, - '.csv': self._ingest_single_csv, - '.png': self._ingest_single_image, - '.jpg': self._ingest_single_image, - } - - # Ingest methods via MIME type (more general than filetype) - mimetype_ingest_methods = { - 'video': self._ingest_single_video, - 'audio': self._ingest_single_video, - 'text': self._ingest_single_txt, - 'image': self._ingest_single_image, - } - # 👆👆👆👆 ADD NEW INGEST METHODhe 👆👆👆👆🎉 - - print(f"Top of ingest, Course_name {course_name}. S3 paths {s3_paths}") - success_status: Dict[str, None | str | Dict[str, str]] = {"success_ingest": None, "failure_ingest": None} - try: - if isinstance(s3_paths, str): - s3_paths = [s3_paths] - - for s3_path in s3_paths: - file_extension = Path(s3_path).suffix - with NamedTemporaryFile(suffix=file_extension) as tmpfile: - self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) - mime_type = str(mimetypes.guess_type(tmpfile.name, strict=False)[0]) - mime_category = mime_type.split('/')[0] if '/' in mime_type else mime_type - - if file_extension in file_ingest_methods: - # Use specialized functions when possible, fallback to mimetype. Else raise error. - ingest_method = file_ingest_methods[file_extension] - _ingest_single(ingest_method, s3_path, course_name, **kwargs) - elif mime_category in mimetype_ingest_methods: - # fallback to MimeType - print("mime category", mime_category) - ingest_method = mimetype_ingest_methods[mime_category] - _ingest_single(ingest_method, s3_path, course_name, **kwargs) - else: - # No supported ingest... Fallback to attempting utf-8 decoding, otherwise fail. - try: - self._ingest_single_txt(s3_path, course_name) - success_status['success_ingest'] = s3_path - print(f"No ingest methods -- Falling back to UTF-8 INGEST... s3_path = {s3_path}") - except Exception as e: - print( - f"We don't have a ingest method for this filetype: {file_extension}. As a last-ditch effort, we tried to ingest the file as utf-8 text, but that failed too. File is unsupported: {s3_path}. 
UTF-8 ingest error: {e}" - ) - success_status['failure_ingest'] = { - 's3_path': - s3_path, - 'error': - f"We don't have a ingest method for this filetype: {file_extension} (with generic type {mime_type}), for file: {s3_path}" - } - self.posthog.capture( - 'distinct_id_of_the_user', - event='ingest_failure', - properties={ - 'course_name': - course_name, - 's3_path': - s3_paths, - 'kwargs': - kwargs, - 'error': - f"We don't have a ingest method for this filetype: {file_extension} (with generic type {mime_type}), for file: {s3_path}" - }) - - return success_status - except Exception as e: - err = f"❌❌ Error in /ingest: `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) # type: ignore - - success_status['failure_ingest'] = {'s3_path': s3_path, 'error': f"MAJOR ERROR DURING INGEST: {err}"} - self.posthog.capture('distinct_id_of_the_user', - event='ingest_failure', - properties={ - 'course_name': course_name, - 's3_path': s3_paths, - 'kwargs': kwargs, - 'error': err - }) - - sentry_sdk.capture_exception(e) - print(f"MAJOR ERROR IN /bulk_ingest: {str(e)}") - return success_status - - def ingest_single_web_text(self, course_name: str, base_url: str, url: str, content: str, readable_filename: str): - """Crawlee integration - """ - self.posthog.capture('distinct_id_of_the_user', - event='ingest_single_web_text_invoked', - properties={ - 'course_name': course_name, - 'base_url': base_url, - 'url': url, - 'content': content, - 'title': readable_filename - }) - success_or_failure: Dict[str, None | str | Dict[str, str]] = {"success_ingest": None, "failure_ingest": None} - try: - # if not, ingest the text - text = [content] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': '', - 'readable_filename': readable_filename, - 'pagenumber': '', - 'timestamp': '', - 'url': url, - 'base_url': base_url, - }] - self.split_and_upload(texts=text, metadatas=metadatas) - self.posthog.capture('distinct_id_of_the_user', - event='ingest_single_web_text_succeeded', - properties={ - 'course_name': course_name, - 'base_url': base_url, - 'url': url, - 'title': readable_filename - }) - - success_or_failure['success_ingest'] = url - return success_or_failure - except Exception as e: - err = f"❌❌ Error in (web text ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) # type: ignore - print(err) - sentry_sdk.capture_exception(e) - success_or_failure['failure_ingest'] = {'url': url, 'error': str(err)} - return success_or_failure - - def _ingest_single_py(self, s3_path: str, course_name: str, **kwargs): - try: - file_name = s3_path.split("/")[-1] - file_path = "media/" + file_name # download from s3 to local folder for ingest - - self.s3_client.download_file(os.getenv('S3_BUCKET_NAME'), s3_path, file_path) - - loader = PythonLoader(file_path) - documents = loader.load() - - texts = [doc.page_content for doc in documents] - - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - #print(texts) - os.remove(file_path) - - success_or_failure = self.split_and_upload(texts=texts, metadatas=metadatas) - print("Python ingest: ", success_or_failure) - return success_or_failure - - except Exception as e: - err = f"❌❌ Error in (Python ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", 
traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return err - - def _ingest_single_vtt(self, s3_path: str, course_name: str, **kwargs): - """ - Ingest a single .vtt file from S3. - """ - try: - with NamedTemporaryFile() as tmpfile: - # download from S3 into vtt_tmpfile - self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) - loader = TextLoader(tmpfile.name) - documents = loader.load() - texts = [doc.page_content for doc in documents] - - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - - success_or_failure = self.split_and_upload(texts=texts, metadatas=metadatas) - return success_or_failure - except Exception as e: - err = f"❌❌ Error in (VTT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return err - - def _ingest_html(self, s3_path: str, course_name: str, **kwargs) -> str: - print(f"IN _ingest_html s3_path `{s3_path}` kwargs: {kwargs}") - try: - response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) - raw_html = response['Body'].read().decode('utf-8') - - soup = BeautifulSoup(raw_html, 'html.parser') - title = s3_path.replace("courses/" + course_name, "") - title = title.replace(".html", "") - title = title.replace("_", " ") - title = title.replace("/", " ") - title = title.strip() - title = title[37:] # removing the uuid prefix - text = [soup.get_text()] - - metadata: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': str(title), # adding str to avoid error: unhashable type 'slice' - 'url': kwargs.get('url', ''), - 'base_url': kwargs.get('base_url', ''), - 'pagenumber': '', - 'timestamp': '', - }] - - success_or_failure = self.split_and_upload(text, metadata) - print(f"_ingest_html: {success_or_failure}") - return success_or_failure - except Exception as e: - err: str = f"ERROR IN _ingest_html: {e}\nTraceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore - print(err) - sentry_sdk.capture_exception(e) - return err - - def _ingest_single_video(self, s3_path: str, course_name: str, **kwargs) -> str: - """ - Ingest a single video file from S3. 
- """ - print("Starting ingest video or audio") - try: - # Ensure the media directory exists - media_dir = "media" - if not os.path.exists(media_dir): - os.makedirs(media_dir) - - # check for file extension - file_ext = Path(s3_path).suffix - openai.api_key = os.getenv('VLADS_OPENAI_KEY') - transcript_list = [] - with NamedTemporaryFile(suffix=file_ext) as video_tmpfile: - # download from S3 into an video tmpfile - self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=video_tmpfile) - # extract audio from video tmpfile - mp4_version = AudioSegment.from_file(video_tmpfile.name, file_ext[1:]) - - # save the extracted audio as a temporary webm file - with NamedTemporaryFile(suffix=".webm", dir=media_dir, delete=False) as webm_tmpfile: - mp4_version.export(webm_tmpfile, format="webm") - - # check file size - file_size = os.path.getsize(webm_tmpfile.name) - # split the audio into 25MB chunks - if file_size > 26214400: - # load the webm file into audio object - full_audio = AudioSegment.from_file(webm_tmpfile.name, "webm") - file_count = file_size // 26214400 + 1 - split_segment = 35 * 60 * 1000 - start = 0 - count = 0 - - while count < file_count: - with NamedTemporaryFile(suffix=".webm", dir=media_dir, delete=False) as split_tmp: - if count == file_count - 1: - # last segment - audio_chunk = full_audio[start:] - else: - audio_chunk = full_audio[start:split_segment] - - audio_chunk.export(split_tmp.name, format="webm") - - # transcribe the split file and store the text in dictionary - with open(split_tmp.name, "rb") as f: - transcript = openai.Audio.transcribe("whisper-1", f) - transcript_list.append(transcript['text']) # type: ignore - start += split_segment - split_segment += split_segment - count += 1 - os.remove(split_tmp.name) - else: - # transcribe the full audio - with open(webm_tmpfile.name, "rb") as f: - transcript = openai.Audio.transcribe("whisper-1", f) - transcript_list.append(transcript['text']) # type: ignore - - os.remove(webm_tmpfile.name) - - text = [txt for txt in transcript_list] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': text.index(txt), - 'url': '', - 'base_url': '', - } for txt in text] - - self.split_and_upload(texts=text, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (VIDEO ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def _ingest_single_docx(self, s3_path: str, course_name: str, **kwargs) -> str: - try: - with NamedTemporaryFile() as tmpfile: - self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) - - loader = Docx2txtLoader(tmpfile.name) - documents = loader.load() - - texts = [doc.page_content for doc in documents] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - - self.split_and_upload(texts=texts, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (DOCX ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return 
str(err) - - def _ingest_single_srt(self, s3_path: str, course_name: str, **kwargs) -> str: - try: - import pysrt - - # NOTE: slightly different method for .txt files, no need for download. It's part of the 'body' - response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) - raw_text = response['Body'].read().decode('utf-8') - - print("UTF-8 text to ingest as SRT:", raw_text) - parsed_info = pysrt.from_string(raw_text) - text = " ".join([t.text for t in parsed_info]) # type: ignore - print(f"Final SRT ingest: {text}") - - texts = [text] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - }] - if len(text) == 0: - return "Error: SRT file appears empty. Skipping." - - self.split_and_upload(texts=texts, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (SRT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def _ingest_single_excel(self, s3_path: str, course_name: str, **kwargs) -> str: - try: - with NamedTemporaryFile() as tmpfile: - # download from S3 into pdf_tmpfile - self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) - - loader = UnstructuredExcelLoader(tmpfile.name, mode="elements") - # loader = SRTLoader(tmpfile.name) - documents = loader.load() - - texts = [doc.page_content for doc in documents] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - - self.split_and_upload(texts=texts, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (Excel/xlsx ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def _ingest_single_image(self, s3_path: str, course_name: str, **kwargs) -> str: - try: - with NamedTemporaryFile() as tmpfile: - # download from S3 into pdf_tmpfile - self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) - """ - # Unstructured image loader makes the install too large (700MB --> 6GB. 3min -> 12 min build times). AND nobody uses it. - # The "hi_res" strategy will identify the layout of the document using detectron2. "ocr_only" uses pdfminer.six. 
https://unstructured-io.github.io/unstructured/core/partition.html#partition-image - loader = UnstructuredImageLoader(tmpfile.name, unstructured_kwargs={'strategy': "ocr_only"}) - documents = loader.load() - """ - - res_str = pytesseract.image_to_string(Image.open(tmpfile.name)) - print("IMAGE PARSING RESULT:", res_str) - documents = [Document(page_content=res_str)] - - texts = [doc.page_content for doc in documents] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - - self.split_and_upload(texts=texts, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (png/jpg ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def _ingest_single_csv(self, s3_path: str, course_name: str, **kwargs) -> str: - try: - with NamedTemporaryFile() as tmpfile: - # download from S3 into pdf_tmpfile - self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) - - loader = CSVLoader(file_path=tmpfile.name) - documents = loader.load() - - texts = [doc.page_content for doc in documents] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - - self.split_and_upload(texts=texts, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (CSV ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def _ingest_single_pdf(self, s3_path: str, course_name: str, **kwargs): - """ - Both OCR the PDF. And grab the first image as a PNG. - LangChain `Documents` have .metadata and .page_content attributes. - Be sure to use TemporaryFile() to avoid memory leaks! - """ - print("IN PDF ingest: s3_path: ", s3_path, "and kwargs:", kwargs) - - try: - with NamedTemporaryFile() as pdf_tmpfile: - # download from S3 into pdf_tmpfile - self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=pdf_tmpfile) - ### READ OCR of PDF - try: - doc = fitz.open(pdf_tmpfile.name) # type: ignore - except fitz.fitz.EmptyFileError as e: - print(f"Empty PDF file: {s3_path}") - return "Failed ingest: Could not detect ANY text in the PDF. OCR did not help. PDF appears empty of text." 
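
For context on the hunk above: `_ingest_single_pdf` pulls per-page text with PyMuPDF and only hands off to `_ocr_pdf` (pdfplumber + pytesseract) when no page contains any words. Below is a minimal standalone sketch of that extract-then-OCR-fallback decision; it assumes a local PDF path instead of the S3 download and temp-file handling used here, and the function name `extract_pdf_pages` is illustrative only, not part of this patch.

```python
from typing import Dict, List

import fitz  # PyMuPDF
import pdfplumber
import pytesseract


def extract_pdf_pages(pdf_path: str) -> List[Dict]:
    """Per-page text via PyMuPDF; OCR fallback only when the PDF has no embedded text."""
    doc = fitz.open(pdf_path)
    pages = [{"page_number": i, "text": page.get_text()} for i, page in enumerate(doc)]

    # Same check as the ingest code: if any page already has words, skip OCR entirely.
    if any(page["text"].strip() for page in pages):
        return pages

    ocr_pages: List[Dict] = []
    with pdfplumber.open(pdf_path) as pdf:
        for i, page in enumerate(pdf.pages):
            image = page.to_image()  # pdfplumber PageImage; .original is the underlying PIL image
            ocr_pages.append({"page_number": i, "text": pytesseract.image_to_string(image.original)})
    return ocr_pages
```

Usage would be something like `pages = extract_pdf_pages("lecture.pdf")`, after which each page dict can be paired with course/S3 metadata before `split_and_upload`.
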
- - # improve quality of the image - zoom_x = 2.0 # horizontal zoom - zoom_y = 2.0 # vertical zoom - mat = fitz.Matrix(zoom_x, zoom_y) # zoom factor 2 in each dimension - - pdf_pages_no_OCR: List[Dict] = [] - for i, page in enumerate(doc): # type: ignore - - # UPLOAD FIRST PAGE IMAGE to S3 - if i == 0: - with NamedTemporaryFile(suffix=".png") as first_page_png: - pix = page.get_pixmap(matrix=mat) - pix.save(first_page_png) # store image as a PNG - - s3_upload_path = str(Path(s3_path)).rsplit('.pdf')[0] + "-pg1-thumb.png" - first_page_png.seek(0) # Seek the file pointer back to the beginning - with open(first_page_png.name, 'rb') as f: - print("Uploading image png to S3") - self.s3_client.upload_fileobj(f, os.getenv('S3_BUCKET_NAME'), s3_upload_path) - - # Extract text - text = page.get_text().encode("utf8").decode("utf8", errors='ignore') # get plain text (is in UTF-8) - pdf_pages_no_OCR.append(dict(text=text, page_number=i, readable_filename=Path(s3_path).name[37:])) - - metadatas: List[Dict[str, Any]] = [ - { - 'course_name': course_name, - 's3_path': s3_path, - 'pagenumber': page['page_number'] + 1, # +1 for human indexing - 'timestamp': '', - 'readable_filename': kwargs.get('readable_filename', page['readable_filename']), - 'url': kwargs.get('url', ''), - 'base_url': kwargs.get('base_url', ''), - } for page in pdf_pages_no_OCR - ] - pdf_texts = [page['text'] for page in pdf_pages_no_OCR] - - # count the total number of words in the pdf_texts. If it's less than 100, we'll OCR the PDF - has_words = any(text.strip() for text in pdf_texts) - if has_words: - success_or_failure = self.split_and_upload(texts=pdf_texts, metadatas=metadatas) - else: - print("⚠️ PDF IS EMPTY -- OCR-ing the PDF.") - success_or_failure = self._ocr_pdf(s3_path=s3_path, course_name=course_name, **kwargs) - - return success_or_failure - except Exception as e: - err = f"❌❌ Error in PDF ingest (no OCR): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) # type: ignore - print(err) - sentry_sdk.capture_exception(e) - return err - return "Success" - - def _ocr_pdf(self, s3_path: str, course_name: str, **kwargs): - self.posthog.capture('distinct_id_of_the_user', - event='ocr_pdf_invoked', - properties={ - 'course_name': course_name, - 's3_path': s3_path, - }) - - pdf_pages_OCRed: List[Dict] = [] - try: - with NamedTemporaryFile() as pdf_tmpfile: - # download from S3 into pdf_tmpfile - self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=pdf_tmpfile) - - with pdfplumber.open(pdf_tmpfile.name) as pdf: - # for page in : - for i, page in enumerate(pdf.pages): - im = page.to_image() - text = pytesseract.image_to_string(im.original) - print("Page number: ", i, "Text: ", text[:100]) - pdf_pages_OCRed.append(dict(text=text, page_number=i, readable_filename=Path(s3_path).name[37:])) +# """ +# To deploy: beam deploy ingest.py --profile caii-ncsa +# Use CAII gmail to auth. 
+# """ +# import asyncio +# import inspect +# import json +# import logging +# import mimetypes +# import os +# import re +# import shutil +# import time +# import traceback +# import uuid +# from pathlib import Path +# from tempfile import NamedTemporaryFile +# from typing import Any, Callable, Dict, List, Optional, Union + +# import beam +# import boto3 +# import fitz +# import openai +# import pytesseract +# import pdfplumber +# import sentry_sdk +# import supabase +# from beam import App, QueueDepthAutoscaler, Runtime # RequestLatencyAutoscaler, +# from bs4 import BeautifulSoup +# from git.repo import Repo +# from langchain.document_loaders import ( +# Docx2txtLoader, +# GitLoader, +# PythonLoader, +# TextLoader, +# UnstructuredExcelLoader, +# UnstructuredPowerPointLoader, +# ) +# from langchain.document_loaders.csv_loader import CSVLoader +# from langchain.embeddings.openai import OpenAIEmbeddings +# from langchain.schema import Document +# from langchain.text_splitter import RecursiveCharacterTextSplitter +# from langchain.vectorstores import Qdrant +# from nomic_logging import delete_from_document_map, log_to_document_map, rebuild_map +# from OpenaiEmbeddings import OpenAIAPIProcessor +# from PIL import Image +# from posthog import Posthog +# from pydub import AudioSegment +# from qdrant_client import QdrantClient, models +# from qdrant_client.models import PointStruct + +# # from langchain.schema.output_parser import StrOutputParser +# # from langchain.chat_models import AzureChatOpenAI + +# requirements = [ +# "openai<1.0", +# "supabase==2.0.2", +# "tiktoken==0.5.1", +# "boto3==1.28.79", +# "qdrant-client==1.7.3", +# "langchain==0.0.331", +# "posthog==3.1.0", +# "pysrt==1.1.2", +# "docx2txt==0.8", +# "pydub==0.25.1", +# "ffmpeg-python==0.2.0", +# "ffprobe==0.5", +# "ffmpeg==1.4", +# "PyMuPDF==1.23.6", +# "pytesseract==0.3.10", # image OCR" +# "openpyxl==3.1.2", # excel" +# "networkx==3.2.1", # unused part of excel partitioning :(" +# "python-pptx==0.6.23", +# "unstructured==0.10.29", +# "GitPython==3.1.40", +# "beautifulsoup4==4.12.2", +# "sentry-sdk==1.39.1", +# "nomic==2.0.14", +# "pdfplumber==0.11.0", # PDF OCR, better performance than Fitz/PyMuPDF in my Gies PDF testing. +# ] + +# # TODO: consider adding workers. They share CPU and memory https://docs.beam.cloud/deployment/autoscaling#worker-use-cases +# app = App("ingest", +# runtime=Runtime( +# cpu=1, +# memory="3Gi", +# image=beam.Image( +# python_version="python3.10", +# python_packages=requirements, +# commands=["apt-get update && apt-get install -y ffmpeg tesseract-ocr"], +# ), +# )) + +# # MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation") +# OPENAI_API_TYPE = "azure" # "openai" or "azure" + + +# def loader(): +# """ +# The loader function will run once for each worker that starts up. 
https://docs.beam.cloud/deployment/loaders +# """ +# openai.api_key = os.getenv("VLADS_OPENAI_KEY") + +# # vector DB +# qdrant_client = QdrantClient( +# url=os.getenv('QDRANT_URL'), +# api_key=os.getenv('QDRANT_API_KEY'), +# ) + +# vectorstore = Qdrant(client=qdrant_client, +# collection_name=os.environ['QDRANT_COLLECTION_NAME'], +# embeddings=OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE, +# openai_api_key=os.getenv('VLADS_OPENAI_KEY'))) + +# # S3 +# s3_client = boto3.client( +# 's3', +# aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'), +# aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'), +# ) + +# # Create a Supabase client +# supabase_client = supabase.create_client( # type: ignore +# supabase_url=os.environ['SUPABASE_URL'], supabase_key=os.environ['SUPABASE_API_KEY']) + +# # llm = AzureChatOpenAI( +# # temperature=0, +# # deployment_name=os.getenv('AZURE_OPENAI_ENGINE'), #type:ignore +# # openai_api_base=os.getenv('AZURE_OPENAI_ENDPOINT'), #type:ignore +# # openai_api_key=os.getenv('AZURE_OPENAI_KEY'), #type:ignore +# # openai_api_version=os.getenv('OPENAI_API_VERSION'), #type:ignore +# # openai_api_type=OPENAI_API_TYPE) + +# posthog = Posthog(sync_mode=True, project_api_key=os.environ['POSTHOG_API_KEY'], host='https://app.posthog.com') +# sentry_sdk.init( +# dsn="https://examplePublicKey@o0.ingest.sentry.io/0", + +# # Enable performance monitoring +# enable_tracing=True, +# ) + +# return qdrant_client, vectorstore, s3_client, supabase_client, posthog + + +# # autoscaler = RequestLatencyAutoscaler(desired_latency=30, max_replicas=2) +# autoscaler = QueueDepthAutoscaler(max_tasks_per_replica=300, max_replicas=3) + + +# # Triggers determine how your app is deployed +# # @app.rest_api( +# @app.task_queue( +# workers=4, +# callback_url='https://uiuc-chat-git-ingestprogresstracking-kastanday.vercel.app/api/UIUC-api/ingestTaskCallback', +# max_pending_tasks=15_000, +# max_retries=3, +# timeout=-1, +# loader=loader, +# autoscaler=autoscaler) +# def ingest(**inputs: Dict[str, Any]): + +# qdrant_client, vectorstore, s3_client, supabase_client, posthog = inputs["context"] + +# course_name: List[str] | str = inputs.get('course_name', '') +# s3_paths: List[str] | str = inputs.get('s3_paths', '') +# url: List[str] | str | None = inputs.get('url', None) +# base_url: List[str] | str | None = inputs.get('base_url', None) +# readable_filename: List[str] | str = inputs.get('readable_filename', '') +# content: str | None = inputs.get('content', None) # is webtext if content exists + +# print( +# f"In top of /ingest route. 
course: {course_name}, s3paths: {s3_paths}, readable_filename: {readable_filename}, base_url: {base_url}, url: {url}, content: {content}" +# ) + +# ingester = Ingest(qdrant_client, vectorstore, s3_client, supabase_client, posthog) + +# def run_ingest(course_name, s3_paths, base_url, url, readable_filename, content): +# if content: +# return ingester.ingest_single_web_text(course_name, base_url, url, content, readable_filename) +# elif readable_filename == '': +# return ingester.bulk_ingest(course_name, s3_paths, base_url=base_url, url=url) +# else: +# return ingester.bulk_ingest(course_name, +# s3_paths, +# readable_filename=readable_filename, +# base_url=base_url, +# url=url) + +# # First try +# success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) + +# # retries +# num_retires = 5 +# for retry_num in range(1, num_retires): +# if isinstance(success_fail_dict, str): +# print(f"STRING ERROR: {success_fail_dict = }") +# success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) +# time.sleep(13 * retry_num) # max is 65 +# elif success_fail_dict['failure_ingest']: +# print(f"Ingest failure -- Retry attempt {retry_num}. File: {success_fail_dict}") +# # s3_paths = success_fail_dict['failure_ingest'] # retry only failed paths.... what if this is a URL instead? +# success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) +# time.sleep(13 * retry_num) # max is 65 +# else: +# break + +# # Final failure / success check +# if success_fail_dict['failure_ingest']: +# print(f"INGEST FAILURE -- About to send to supabase. success_fail_dict: {success_fail_dict}") +# document = { +# "course_name": +# course_name, +# "s3_path": +# s3_paths, +# "readable_filename": +# readable_filename, +# "url": +# url, +# "base_url": +# base_url, +# "error": +# success_fail_dict['failure_ingest']['error'] +# if isinstance(success_fail_dict['failure_ingest'], dict) else success_fail_dict['failure_ingest'] +# } +# response = supabase_client.table('documents_failed').insert(document).execute() # type: ignore +# print(f"Supabase ingest failure response: {response}") +# else: +# # Success case: rebuild nomic document map after all ingests are done +# # rebuild_status = rebuild_map(str(course_name), map_type='document') +# pass + +# print(f"Final success_fail_dict: {success_fail_dict}") +# return json.dumps(success_fail_dict) + + +# class Ingest(): + +# def __init__(self, qdrant_client, vectorstore, s3_client, supabase_client, posthog): +# self.qdrant_client = qdrant_client +# self.vectorstore = vectorstore +# self.s3_client = s3_client +# self.supabase_client = supabase_client +# self.posthog = posthog + +# def bulk_ingest(self, course_name: str, s3_paths: Union[str, List[str]], +# **kwargs) -> Dict[str, None | str | Dict[str, str]]: +# """ +# Bulk ingest a list of s3 paths into the vectorstore, and also into the supabase database. 
+# -> Dict[str, str | Dict[str, str]] +# """ + +# def _ingest_single(ingest_method: Callable, s3_path, *args, **kwargs): +# """Handle running an arbitrary ingest function for an individual file.""" +# # RUN INGEST METHOD +# ret = ingest_method(s3_path, *args, **kwargs) +# if ret == "Success": +# success_status['success_ingest'] = str(s3_path) +# else: +# success_status['failure_ingest'] = {'s3_path': str(s3_path), 'error': str(ret)} + +# # 👇👇👇👇 ADD NEW INGEST METHODS HERE 👇👇👇👇🎉 +# file_ingest_methods = { +# '.html': self._ingest_html, +# '.py': self._ingest_single_py, +# '.pdf': self._ingest_single_pdf, +# '.txt': self._ingest_single_txt, +# '.md': self._ingest_single_txt, +# '.srt': self._ingest_single_srt, +# '.vtt': self._ingest_single_vtt, +# '.docx': self._ingest_single_docx, +# '.ppt': self._ingest_single_ppt, +# '.pptx': self._ingest_single_ppt, +# '.xlsx': self._ingest_single_excel, +# '.xls': self._ingest_single_excel, +# '.csv': self._ingest_single_csv, +# '.png': self._ingest_single_image, +# '.jpg': self._ingest_single_image, +# } + +# # Ingest methods via MIME type (more general than filetype) +# mimetype_ingest_methods = { +# 'video': self._ingest_single_video, +# 'audio': self._ingest_single_video, +# 'text': self._ingest_single_txt, +# 'image': self._ingest_single_image, +# } +# # 👆👆👆👆 ADD NEW INGEST METHODhe 👆👆👆👆🎉 + +# print(f"Top of ingest, Course_name {course_name}. S3 paths {s3_paths}") +# success_status: Dict[str, None | str | Dict[str, str]] = {"success_ingest": None, "failure_ingest": None} +# try: +# if isinstance(s3_paths, str): +# s3_paths = [s3_paths] + +# for s3_path in s3_paths: +# file_extension = Path(s3_path).suffix +# with NamedTemporaryFile(suffix=file_extension) as tmpfile: +# self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) +# mime_type = str(mimetypes.guess_type(tmpfile.name, strict=False)[0]) +# mime_category = mime_type.split('/')[0] if '/' in mime_type else mime_type + +# if file_extension in file_ingest_methods: +# # Use specialized functions when possible, fallback to mimetype. Else raise error. +# ingest_method = file_ingest_methods[file_extension] +# _ingest_single(ingest_method, s3_path, course_name, **kwargs) +# elif mime_category in mimetype_ingest_methods: +# # fallback to MimeType +# print("mime category", mime_category) +# ingest_method = mimetype_ingest_methods[mime_category] +# _ingest_single(ingest_method, s3_path, course_name, **kwargs) +# else: +# # No supported ingest... Fallback to attempting utf-8 decoding, otherwise fail. +# try: +# self._ingest_single_txt(s3_path, course_name) +# success_status['success_ingest'] = s3_path +# print(f"No ingest methods -- Falling back to UTF-8 INGEST... s3_path = {s3_path}") +# except Exception as e: +# print( +# f"We don't have a ingest method for this filetype: {file_extension}. As a last-ditch effort, we tried to ingest the file as utf-8 text, but that failed too. File is unsupported: {s3_path}. 
UTF-8 ingest error: {e}" +# ) +# success_status['failure_ingest'] = { +# 's3_path': +# s3_path, +# 'error': +# f"We don't have a ingest method for this filetype: {file_extension} (with generic type {mime_type}), for file: {s3_path}" +# } +# self.posthog.capture( +# 'distinct_id_of_the_user', +# event='ingest_failure', +# properties={ +# 'course_name': +# course_name, +# 's3_path': +# s3_paths, +# 'kwargs': +# kwargs, +# 'error': +# f"We don't have a ingest method for this filetype: {file_extension} (with generic type {mime_type}), for file: {s3_path}" +# }) + +# return success_status +# except Exception as e: +# err = f"❌❌ Error in /ingest: `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) # type: ignore + +# success_status['failure_ingest'] = {'s3_path': s3_path, 'error': f"MAJOR ERROR DURING INGEST: {err}"} +# self.posthog.capture('distinct_id_of_the_user', +# event='ingest_failure', +# properties={ +# 'course_name': course_name, +# 's3_path': s3_paths, +# 'kwargs': kwargs, +# 'error': err +# }) + +# sentry_sdk.capture_exception(e) +# print(f"MAJOR ERROR IN /bulk_ingest: {str(e)}") +# return success_status + +# def ingest_single_web_text(self, course_name: str, base_url: str, url: str, content: str, readable_filename: str): +# """Crawlee integration +# """ +# self.posthog.capture('distinct_id_of_the_user', +# event='ingest_single_web_text_invoked', +# properties={ +# 'course_name': course_name, +# 'base_url': base_url, +# 'url': url, +# 'content': content, +# 'title': readable_filename +# }) +# success_or_failure: Dict[str, None | str | Dict[str, str]] = {"success_ingest": None, "failure_ingest": None} +# try: +# # if not, ingest the text +# text = [content] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': '', +# 'readable_filename': readable_filename, +# 'pagenumber': '', +# 'timestamp': '', +# 'url': url, +# 'base_url': base_url, +# }] +# self.split_and_upload(texts=text, metadatas=metadatas) +# self.posthog.capture('distinct_id_of_the_user', +# event='ingest_single_web_text_succeeded', +# properties={ +# 'course_name': course_name, +# 'base_url': base_url, +# 'url': url, +# 'title': readable_filename +# }) + +# success_or_failure['success_ingest'] = url +# return success_or_failure +# except Exception as e: +# err = f"❌❌ Error in (web text ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) # type: ignore +# print(err) +# sentry_sdk.capture_exception(e) +# success_or_failure['failure_ingest'] = {'url': url, 'error': str(err)} +# return success_or_failure + +# def _ingest_single_py(self, s3_path: str, course_name: str, **kwargs): +# try: +# file_name = s3_path.split("/")[-1] +# file_path = "media/" + file_name # download from s3 to local folder for ingest + +# self.s3_client.download_file(os.getenv('S3_BUCKET_NAME'), s3_path, file_path) + +# loader = PythonLoader(file_path) +# documents = loader.load() + +# texts = [doc.page_content for doc in documents] + +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] +# #print(texts) +# os.remove(file_path) + +# success_or_failure = self.split_and_upload(texts=texts, metadatas=metadatas) +# print("Python ingest: ", success_or_failure) +# return success_or_failure + +# except Exception as e: +# err = 
f"❌❌ Error in (Python ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# print(err) +# sentry_sdk.capture_exception(e) +# return err + +# def _ingest_single_vtt(self, s3_path: str, course_name: str, **kwargs): +# """ +# Ingest a single .vtt file from S3. +# """ +# try: +# with NamedTemporaryFile() as tmpfile: +# # download from S3 into vtt_tmpfile +# self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) +# loader = TextLoader(tmpfile.name) +# documents = loader.load() +# texts = [doc.page_content for doc in documents] + +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] + +# success_or_failure = self.split_and_upload(texts=texts, metadatas=metadatas) +# return success_or_failure +# except Exception as e: +# err = f"❌❌ Error in (VTT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# print(err) +# sentry_sdk.capture_exception(e) +# return err + +# def _ingest_html(self, s3_path: str, course_name: str, **kwargs) -> str: +# print(f"IN _ingest_html s3_path `{s3_path}` kwargs: {kwargs}") +# try: +# response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) +# raw_html = response['Body'].read().decode('utf-8') + +# soup = BeautifulSoup(raw_html, 'html.parser') +# title = s3_path.replace("courses/" + course_name, "") +# title = title.replace(".html", "") +# title = title.replace("_", " ") +# title = title.replace("/", " ") +# title = title.strip() +# title = title[37:] # removing the uuid prefix +# text = [soup.get_text()] + +# metadata: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': str(title), # adding str to avoid error: unhashable type 'slice' +# 'url': kwargs.get('url', ''), +# 'base_url': kwargs.get('base_url', ''), +# 'pagenumber': '', +# 'timestamp': '', +# }] + +# success_or_failure = self.split_and_upload(text, metadata) +# print(f"_ingest_html: {success_or_failure}") +# return success_or_failure +# except Exception as e: +# err: str = f"ERROR IN _ingest_html: {e}\nTraceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore +# print(err) +# sentry_sdk.capture_exception(e) +# return err + +# def _ingest_single_video(self, s3_path: str, course_name: str, **kwargs) -> str: +# """ +# Ingest a single video file from S3. 
+# """ +# print("Starting ingest video or audio") +# try: +# # Ensure the media directory exists +# media_dir = "media" +# if not os.path.exists(media_dir): +# os.makedirs(media_dir) + +# # check for file extension +# file_ext = Path(s3_path).suffix +# openai.api_key = os.getenv('VLADS_OPENAI_KEY') +# transcript_list = [] +# with NamedTemporaryFile(suffix=file_ext) as video_tmpfile: +# # download from S3 into an video tmpfile +# self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=video_tmpfile) +# # extract audio from video tmpfile +# mp4_version = AudioSegment.from_file(video_tmpfile.name, file_ext[1:]) + +# # save the extracted audio as a temporary webm file +# with NamedTemporaryFile(suffix=".webm", dir=media_dir, delete=False) as webm_tmpfile: +# mp4_version.export(webm_tmpfile, format="webm") + +# # check file size +# file_size = os.path.getsize(webm_tmpfile.name) +# # split the audio into 25MB chunks +# if file_size > 26214400: +# # load the webm file into audio object +# full_audio = AudioSegment.from_file(webm_tmpfile.name, "webm") +# file_count = file_size // 26214400 + 1 +# split_segment = 35 * 60 * 1000 +# start = 0 +# count = 0 + +# while count < file_count: +# with NamedTemporaryFile(suffix=".webm", dir=media_dir, delete=False) as split_tmp: +# if count == file_count - 1: +# # last segment +# audio_chunk = full_audio[start:] +# else: +# audio_chunk = full_audio[start:split_segment] + +# audio_chunk.export(split_tmp.name, format="webm") + +# # transcribe the split file and store the text in dictionary +# with open(split_tmp.name, "rb") as f: +# transcript = openai.Audio.transcribe("whisper-1", f) +# transcript_list.append(transcript['text']) # type: ignore +# start += split_segment +# split_segment += split_segment +# count += 1 +# os.remove(split_tmp.name) +# else: +# # transcribe the full audio +# with open(webm_tmpfile.name, "rb") as f: +# transcript = openai.Audio.transcribe("whisper-1", f) +# transcript_list.append(transcript['text']) # type: ignore + +# os.remove(webm_tmpfile.name) + +# text = [txt for txt in transcript_list] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': text.index(txt), +# 'url': '', +# 'base_url': '', +# } for txt in text] + +# self.split_and_upload(texts=text, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (VIDEO ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# print(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_docx(self, s3_path: str, course_name: str, **kwargs) -> str: +# try: +# with NamedTemporaryFile() as tmpfile: +# self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) + +# loader = Docx2txtLoader(tmpfile.name) +# documents = loader.load() + +# texts = [doc.page_content for doc in documents] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] + +# self.split_and_upload(texts=texts, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (DOCX ingest): `{inspect.currentframe().f_code.co_name}`: 
{e}\nTraceback:\n", traceback.format_exc( +# ) +# print(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_srt(self, s3_path: str, course_name: str, **kwargs) -> str: +# try: +# import pysrt + +# # NOTE: slightly different method for .txt files, no need for download. It's part of the 'body' +# response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) +# raw_text = response['Body'].read().decode('utf-8') + +# print("UTF-8 text to ingest as SRT:", raw_text) +# parsed_info = pysrt.from_string(raw_text) +# text = " ".join([t.text for t in parsed_info]) # type: ignore +# print(f"Final SRT ingest: {text}") + +# texts = [text] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# }] +# if len(text) == 0: +# return "Error: SRT file appears empty. Skipping." + +# self.split_and_upload(texts=texts, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (SRT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# print(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_excel(self, s3_path: str, course_name: str, **kwargs) -> str: +# try: +# with NamedTemporaryFile() as tmpfile: +# # download from S3 into pdf_tmpfile +# self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) + +# loader = UnstructuredExcelLoader(tmpfile.name, mode="elements") +# # loader = SRTLoader(tmpfile.name) +# documents = loader.load() + +# texts = [doc.page_content for doc in documents] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] + +# self.split_and_upload(texts=texts, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (Excel/xlsx ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# print(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_image(self, s3_path: str, course_name: str, **kwargs) -> str: +# try: +# with NamedTemporaryFile() as tmpfile: +# # download from S3 into pdf_tmpfile +# self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) +# """ +# # Unstructured image loader makes the install too large (700MB --> 6GB. 3min -> 12 min build times). AND nobody uses it. +# # The "hi_res" strategy will identify the layout of the document using detectron2. "ocr_only" uses pdfminer.six. 
https://unstructured-io.github.io/unstructured/core/partition.html#partition-image +# loader = UnstructuredImageLoader(tmpfile.name, unstructured_kwargs={'strategy': "ocr_only"}) +# documents = loader.load() +# """ + +# res_str = pytesseract.image_to_string(Image.open(tmpfile.name)) +# print("IMAGE PARSING RESULT:", res_str) +# documents = [Document(page_content=res_str)] + +# texts = [doc.page_content for doc in documents] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] + +# self.split_and_upload(texts=texts, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (png/jpg ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# print(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_csv(self, s3_path: str, course_name: str, **kwargs) -> str: +# try: +# with NamedTemporaryFile() as tmpfile: +# # download from S3 into pdf_tmpfile +# self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=tmpfile) + +# loader = CSVLoader(file_path=tmpfile.name) +# documents = loader.load() + +# texts = [doc.page_content for doc in documents] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] + +# self.split_and_upload(texts=texts, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (CSV ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# print(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_pdf(self, s3_path: str, course_name: str, **kwargs): +# """ +# Both OCR the PDF. And grab the first image as a PNG. +# LangChain `Documents` have .metadata and .page_content attributes. +# Be sure to use TemporaryFile() to avoid memory leaks! +# """ +# print("IN PDF ingest: s3_path: ", s3_path, "and kwargs:", kwargs) + +# try: +# with NamedTemporaryFile() as pdf_tmpfile: +# # download from S3 into pdf_tmpfile +# self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=pdf_tmpfile) +# ### READ OCR of PDF +# try: +# doc = fitz.open(pdf_tmpfile.name) # type: ignore +# except fitz.fitz.EmptyFileError as e: +# print(f"Empty PDF file: {s3_path}") +# return "Failed ingest: Could not detect ANY text in the PDF. OCR did not help. PDF appears empty of text." 
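# --- Illustrative sketch (not from this patch) -------------------------------------
# The PDF ingest here boils down to: read the embedded text layer with PyMuPDF first,
# and only rasterize + OCR (see the _ocr_pdf fallback further down) when the PDF has
# no text at all. A minimal standalone version of that decision; the file path and
# helper name are assumptions.
import fitz  # PyMuPDF
import pdfplumber
import pytesseract

def extract_pdf_text(pdf_path: str) -> list:
    """Return per-page text, falling back to OCR when no embedded text is found."""
    doc = fitz.open(pdf_path)
    pages = [page.get_text() for page in doc]
    if any(page.strip() for page in pages):
        return pages  # the PDF already carries a text layer, no OCR needed
    # no embedded text -> render each page to an image and run Tesseract on it
    with pdfplumber.open(pdf_path) as pdf:
        return [pytesseract.image_to_string(p.to_image().original) for p in pdf.pages]
# ------------------------------------------------------------------------------------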
+ +# # improve quality of the image +# zoom_x = 2.0 # horizontal zoom +# zoom_y = 2.0 # vertical zoom +# mat = fitz.Matrix(zoom_x, zoom_y) # zoom factor 2 in each dimension + +# pdf_pages_no_OCR: List[Dict] = [] +# for i, page in enumerate(doc): # type: ignore + +# # UPLOAD FIRST PAGE IMAGE to S3 +# if i == 0: +# with NamedTemporaryFile(suffix=".png") as first_page_png: +# pix = page.get_pixmap(matrix=mat) +# pix.save(first_page_png) # store image as a PNG + +# s3_upload_path = str(Path(s3_path)).rsplit('.pdf')[0] + "-pg1-thumb.png" +# first_page_png.seek(0) # Seek the file pointer back to the beginning +# with open(first_page_png.name, 'rb') as f: +# print("Uploading image png to S3") +# self.s3_client.upload_fileobj(f, os.getenv('S3_BUCKET_NAME'), s3_upload_path) + +# # Extract text +# text = page.get_text().encode("utf8").decode("utf8", errors='ignore') # get plain text (is in UTF-8) +# pdf_pages_no_OCR.append(dict(text=text, page_number=i, readable_filename=Path(s3_path).name[37:])) + +# metadatas: List[Dict[str, Any]] = [ +# { +# 'course_name': course_name, +# 's3_path': s3_path, +# 'pagenumber': page['page_number'] + 1, # +1 for human indexing +# 'timestamp': '', +# 'readable_filename': kwargs.get('readable_filename', page['readable_filename']), +# 'url': kwargs.get('url', ''), +# 'base_url': kwargs.get('base_url', ''), +# } for page in pdf_pages_no_OCR +# ] +# pdf_texts = [page['text'] for page in pdf_pages_no_OCR] + +# # count the total number of words in the pdf_texts. If it's less than 100, we'll OCR the PDF +# has_words = any(text.strip() for text in pdf_texts) +# if has_words: +# success_or_failure = self.split_and_upload(texts=pdf_texts, metadatas=metadatas) +# else: +# print("⚠️ PDF IS EMPTY -- OCR-ing the PDF.") +# success_or_failure = self._ocr_pdf(s3_path=s3_path, course_name=course_name, **kwargs) + +# return success_or_failure +# except Exception as e: +# err = f"❌❌ Error in PDF ingest (no OCR): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) # type: ignore +# print(err) +# sentry_sdk.capture_exception(e) +# return err +# return "Success" + +# def _ocr_pdf(self, s3_path: str, course_name: str, **kwargs): +# self.posthog.capture('distinct_id_of_the_user', +# event='ocr_pdf_invoked', +# properties={ +# 'course_name': course_name, +# 's3_path': s3_path, +# }) + +# pdf_pages_OCRed: List[Dict] = [] +# try: +# with NamedTemporaryFile() as pdf_tmpfile: +# # download from S3 into pdf_tmpfile +# self.s3_client.download_fileobj(Bucket=os.getenv('S3_BUCKET_NAME'), Key=s3_path, Fileobj=pdf_tmpfile) + +# with pdfplumber.open(pdf_tmpfile.name) as pdf: +# # for page in : +# for i, page in enumerate(pdf.pages): +# im = page.to_image() +# text = pytesseract.image_to_string(im.original) +# print("Page number: ", i, "Text: ", text[:100]) +# pdf_pages_OCRed.append(dict(text=text, page_number=i, readable_filename=Path(s3_path).name[37:])) - metadatas: List[Dict[str, Any]] = [ - { - 'course_name': course_name, - 's3_path': s3_path, - 'pagenumber': page['page_number'] + 1, # +1 for human indexing - 'timestamp': '', - 'readable_filename': kwargs.get('readable_filename', page['readable_filename']), - 'url': kwargs.get('url', ''), - 'base_url': kwargs.get('base_url', ''), - } for page in pdf_pages_OCRed - ] - pdf_texts = [page['text'] for page in pdf_pages_OCRed] - self.posthog.capture('distinct_id_of_the_user', - event='ocr_pdf_succeeded', - properties={ - 'course_name': course_name, - 's3_path': s3_path, - }) - - has_words = any(text.strip() for text in 
pdf_texts) - if not has_words: - raise ValueError("Failed ingest: Could not detect ANY text in the PDF. OCR did not help. PDF appears empty of text.") - - success_or_failure = self.split_and_upload(texts=pdf_texts, metadatas=metadatas) - return success_or_failure - except Exception as e: - err = f"❌❌ Error in PDF ingest (with OCR): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc() - print(err) - sentry_sdk.capture_exception(e) - return err - - def _ingest_single_txt(self, s3_path: str, course_name: str, **kwargs) -> str: - """Ingest a single .txt or .md file from S3. - Args: - s3_path (str): A path to a .txt file in S3 - course_name (str): The name of the course - Returns: - str: "Success" or an error message - """ - print("In text ingest, UTF-8") - try: - # NOTE: slightly different method for .txt files, no need for download. It's part of the 'body' - response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) - text = response['Body'].read().decode('utf-8') - print("UTF-8 text to ignest (from s3)", text) - text = [text] - - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - }] - print("Prior to ingest", metadatas) - - success_or_failure = self.split_and_upload(texts=text, metadatas=metadatas) - return success_or_failure - except Exception as e: - err = f"❌❌ Error in (TXT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def _ingest_single_ppt(self, s3_path: str, course_name: str, **kwargs) -> str: - """ - Ingest a single .ppt or .pptx file from S3. - """ - try: - with NamedTemporaryFile() as tmpfile: - # download from S3 into pdf_tmpfile - #print("in ingest PPTX") - self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) - - loader = UnstructuredPowerPointLoader(tmpfile.name) - documents = loader.load() - - texts = [doc.page_content for doc in documents] - metadatas: List[Dict[str, Any]] = [{ - 'course_name': course_name, - 's3_path': s3_path, - 'readable_filename': kwargs.get('readable_filename', - Path(s3_path).name[37:]), - 'pagenumber': '', - 'timestamp': '', - 'url': '', - 'base_url': '', - } for doc in documents] - - self.split_and_upload(texts=texts, metadatas=metadatas) - return "Success" - except Exception as e: - err = f"❌❌ Error in (PPTX ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( - ) - print(err) - sentry_sdk.capture_exception(e) - return str(err) - - def ingest_github(self, github_url: str, course_name: str) -> str: - """ - Clones the given GitHub URL and uses Langchain to load data. - 1. Clone the repo - 2. Use Langchain to load the data - 3. Pass to split_and_upload() - Args: - github_url (str): The Github Repo URL to be ingested. - course_name (str): The name of the course in our system. - - Returns: - _type_: Success or error message. 
- """ - try: - repo_path = "media/cloned_repo" - repo = Repo.clone_from(github_url, to_path=repo_path, depth=1, clone_submodules=False) - branch = repo.head.reference - - loader = GitLoader(repo_path="media/cloned_repo", branch=str(branch)) - data = loader.load() - shutil.rmtree("media/cloned_repo") - # create metadata for each file in data - - for doc in data: - texts = doc.page_content - metadatas: Dict[str, Any] = { - 'course_name': course_name, - 's3_path': '', - 'readable_filename': doc.metadata['file_name'], - 'url': f"{github_url}/blob/main/{doc.metadata['file_path']}", - 'pagenumber': '', - 'timestamp': '', - } - self.split_and_upload(texts=[texts], metadatas=[metadatas]) - return "Success" - except Exception as e: - err = f"❌❌ Error in (GITHUB ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n{traceback.format_exc()}" - print(err) - sentry_sdk.capture_exception(e) - return err - - def split_and_upload(self, texts: List[str], metadatas: List[Dict[str, Any]]): - """ This is usually the last step of document ingest. Chunk & upload to Qdrant (and Supabase.. todo). - Takes in Text and Metadata (from Langchain doc loaders) and splits / uploads to Qdrant. - - good examples here: https://langchain.readthedocs.io/en/latest/modules/utils/combine_docs_examples/textsplitter.html - - Args: - texts (List[str]): _description_ - metadatas (List[Dict[str, Any]]): _description_ - """ - # return "Success" - self.posthog.capture('distinct_id_of_the_user', - event='split_and_upload_invoked', - properties={ - 'course_name': metadatas[0].get('course_name', None), - 's3_path': metadatas[0].get('s3_path', None), - 'readable_filename': metadatas[0].get('readable_filename', None), - 'url': metadatas[0].get('url', None), - 'base_url': metadatas[0].get('base_url', None), - }) - - print(f"In split and upload. Metadatas: {metadatas}") - print(f"Texts: {texts}") - assert len(texts) == len( - metadatas - ), f'must have equal number of text strings and metadata dicts. len(texts) is {len(texts)}. len(metadatas) is {len(metadatas)}' - - try: - text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( - chunk_size=1000, - chunk_overlap=150, - separators=[ - "\n\n", "\n", ". ", " ", "" - ] # try to split on paragraphs... 
fallback to sentences, then chars, ensure we always fit in context window - ) - contexts: List[Document] = text_splitter.create_documents(texts=texts, metadatas=metadatas) - input_texts = [{'input': context.page_content, 'model': 'text-embedding-ada-002'} for context in contexts] - - # check for duplicates - is_duplicate = self.check_for_duplicates(input_texts, metadatas) - if is_duplicate: - self.posthog.capture('distinct_id_of_the_user', - event='split_and_upload_succeeded', - properties={ - 'course_name': metadatas[0].get('course_name', None), - 's3_path': metadatas[0].get('s3_path', None), - 'readable_filename': metadatas[0].get('readable_filename', None), - 'url': metadatas[0].get('url', None), - 'base_url': metadatas[0].get('base_url', None), - 'is_duplicate': True, - }) - return "Success" - - # adding chunk index to metadata for parent doc retrieval - for i, context in enumerate(contexts): - context.metadata['chunk_index'] = i - - oai = OpenAIAPIProcessor( - input_prompts_list=input_texts, - request_url='https://api.openai.com/v1/embeddings', - api_key=os.getenv('VLADS_OPENAI_KEY'), - # request_url='https://uiuc-chat-canada-east.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2023-05-15', - # api_key=os.getenv('AZURE_OPENAI_KEY'), - max_requests_per_minute=5_000, - max_tokens_per_minute=300_000, - max_attempts=20, - logging_level=logging.INFO, - token_encoding_name='cl100k_base') # nosec -- reasonable bandit error suppression - asyncio.run(oai.process_api_requests_from_file()) - # parse results into dict of shape page_content -> embedding - embeddings_dict: dict[str, List[float]] = { - item[0]['input']: item[1]['data'][0]['embedding'] for item in oai.results - } - - ### BULK upload to Qdrant ### - vectors: list[PointStruct] = [] - for context in contexts: - # !DONE: Updated the payload so each key is top level (no more payload.metadata.course_name. Instead, use payload.course_name), great for creating indexes. 
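# --- Illustrative sketch (not from this patch) -------------------------------------
# The "top level payload keys" note just above is what makes Qdrant payload indexes
# straightforward. With qdrant-client, an index on the flattened course_name key looks
# like this; the URL and collection name are assumptions.
from qdrant_client import QdrantClient
from qdrant_client.http import models as qdrant_models

qdrant = QdrantClient(url="http://localhost:6333")  # assumed local instance
qdrant.create_payload_index(
    collection_name="uiuc-chat",  # assumption: whatever QDRANT_COLLECTION_NAME points to
    field_name="course_name",     # a top-level key thanks to the flattened payload
    field_schema=qdrant_models.PayloadSchemaType.KEYWORD,
)
# ------------------------------------------------------------------------------------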
- upload_metadata = {**context.metadata, "page_content": context.page_content} - vectors.append( - PointStruct(id=str(uuid.uuid4()), vector=embeddings_dict[context.page_content], payload=upload_metadata)) - - self.qdrant_client.upsert( - collection_name=os.environ['QDRANT_COLLECTION_NAME'], # type: ignore - points=vectors # type: ignore - ) - ### Supabase SQL ### - contexts_for_supa = [{ - "text": context.page_content, - "pagenumber": context.metadata.get('pagenumber'), - "timestamp": context.metadata.get('timestamp'), - "chunk_index": context.metadata.get('chunk_index'), - "embedding": embeddings_dict[context.page_content] - } for context in contexts] - - document = { - "course_name": contexts[0].metadata.get('course_name'), - "s3_path": contexts[0].metadata.get('s3_path'), - "readable_filename": contexts[0].metadata.get('readable_filename'), - "url": contexts[0].metadata.get('url'), - "base_url": contexts[0].metadata.get('base_url'), - "contexts": contexts_for_supa, - } - - response = self.supabase_client.table( - os.getenv('SUPABASE_DOCUMENTS_TABLE')).insert(document).execute() # type: ignore - - # add to Nomic document map - if len(response.data) > 0: - course_name = contexts[0].metadata.get('course_name') - log_to_document_map(course_name) - - self.posthog.capture('distinct_id_of_the_user', - event='split_and_upload_succeeded', - properties={ - 'course_name': metadatas[0].get('course_name', None), - 's3_path': metadatas[0].get('s3_path', None), - 'readable_filename': metadatas[0].get('readable_filename', None), - 'url': metadatas[0].get('url', None), - 'base_url': metadatas[0].get('base_url', None), - }) - print("successful END OF split_and_upload") - return "Success" - except Exception as e: - err: str = f"ERROR IN split_and_upload(): Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore - print(err) - sentry_sdk.capture_exception(e) - return err - - def check_for_duplicates(self, texts: List[Dict], metadatas: List[Dict[str, Any]]) -> bool: - """ - For given metadata, fetch docs from Supabase based on S3 path or URL. - If docs exists, concatenate the texts and compare with current texts, if same, return True. - """ - doc_table = os.getenv('SUPABASE_DOCUMENTS_TABLE', '') - course_name = metadatas[0]['course_name'] - incoming_s3_path = metadatas[0]['s3_path'] - url = metadatas[0]['url'] - original_filename = incoming_s3_path.split('/')[-1][37:] # remove the 37-char uuid prefix - - # check if uuid exists in s3_path -- not all s3_paths have uuids! - incoming_filename = incoming_s3_path.split('/')[-1] - pattern = re.compile(r'[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}', - re.I) # uuid V4 pattern, and v4 only. 
- if bool(pattern.search(incoming_filename)): - # uuid pattern exists -- remove the uuid and proceed with duplicate checking - original_filename = incoming_filename[37:] - else: - # do not remove anything and proceed with duplicate checking - original_filename = incoming_filename - - if incoming_s3_path: - filename = incoming_s3_path - supabase_contents = self.supabase_client.table(doc_table).select('id', 'contexts', 's3_path').eq( - 'course_name', course_name).like('s3_path', '%' + original_filename + '%').order('id', desc=True).execute() - supabase_contents = supabase_contents.data - elif url: - filename = url - supabase_contents = self.supabase_client.table(doc_table).select('id', 'contexts', 's3_path').eq( - 'course_name', course_name).eq('url', url).order('id', desc=True).execute() - supabase_contents = supabase_contents.data - else: - filename = None - supabase_contents = [] - - supabase_whole_text = "" - if len(supabase_contents) > 0: # if a doc with same filename exists in Supabase - # concatenate texts - supabase_contexts = supabase_contents[0] - for text in supabase_contexts['contexts']: - supabase_whole_text += text['text'] - - current_whole_text = "" - for text in texts: - current_whole_text += text['input'] - - if supabase_whole_text == current_whole_text: # matches the previous file - print(f"Duplicate ingested! 📄 s3_path: {filename}.") - return True - - else: # the file is updated - print(f"Updated file detected! Same filename, new contents. 📄 s3_path: {filename}") - - # call the delete function on older docs - for content in supabase_contents: - print("older s3_path to be deleted: ", content['s3_path']) - delete_status = self.delete_data(course_name, content['s3_path'], '') - print("delete_status: ", delete_status) - return False - - else: # filename does not already exist in Supabase, so its a brand new file - print(f"NOT a duplicate! 📄s3_path: {filename}") - return False - - def delete_data(self, course_name: str, s3_path: str, source_url: str): - """Delete file from S3, Qdrant, and Supabase.""" - print(f"Deleting {s3_path} from S3, Qdrant, and Supabase for course {course_name}") - # add delete from doc map logic here - try: - # Delete file from S3 - bucket_name = os.getenv('S3_BUCKET_NAME') - - # Delete files by S3 path - if s3_path: - try: - self.s3_client.delete_object(Bucket=bucket_name, Key=s3_path) - except Exception as e: - print("Error in deleting file from s3:", e) - sentry_sdk.capture_exception(e) - # Delete from Qdrant - # docs for nested keys: https://qdrant.tech/documentation/concepts/filtering/#nested-key - # Qdrant "points" look like this: Record(id='000295ca-bd28-ac4a-6f8d-c245f7377f90', payload={'metadata': {'course_name': 'zotero-extreme', 'pagenumber_or_timestamp': 15, 'readable_filename': 'Dunlosky et al. - 2013 - Improving Students’ Learning With Effective Learni.pdf', 's3_path': 'courses/zotero-extreme/Dunlosky et al. - 2013 - Improving Students’ Learning With Effective Learni.pdf'}, 'page_content': '18 \nDunlosky et al.\n3.3 Effects in representative educational contexts. Sev-\neral of the large summarization-training studies have been \nconducted in regular classrooms, indicating the feasibility of \ndoing so. For example, the study by A. King (1992) took place \nin the context of a remedial study-skills course for undergrad-\nuates, and the study by Rinehart et al. (1986) took place in \nsixth-grade classrooms, with the instruction led by students \nregular teachers. In these and other cases, students benefited \nfrom the classroom training. 
We suspect it may actually be \nmore feasible to conduct these kinds of training ... - try: - self.qdrant_client.delete( - collection_name=os.environ['QDRANT_COLLECTION_NAME'], - points_selector=models.Filter(must=[ - models.FieldCondition( - key="s3_path", - match=models.MatchValue(value=s3_path), - ), - ]), - ) - except Exception as e: - if "timed out" in str(e): - # Timed out is fine. Still deletes. - # https://github.com/qdrant/qdrant/issues/3654#issuecomment-1955074525 - pass - else: - print("Error in deleting file from Qdrant:", e) - sentry_sdk.capture_exception(e) - try: - # delete from Nomic - response = self.supabase_client.from_( - os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq('s3_path', s3_path).eq( - 'course_name', course_name).execute() - data = response.data[0] #single record fetched - nomic_ids_to_delete = [] - context_count = len(data['contexts']) - for i in range(1, context_count + 1): - nomic_ids_to_delete.append(str(data['id']) + "_" + str(i)) - - # delete from Nomic - delete_from_document_map(course_name, nomic_ids_to_delete) - except Exception as e: - print("Error in deleting file from Nomic:", e) - sentry_sdk.capture_exception(e) - - try: - self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('s3_path', s3_path).eq( - 'course_name', course_name).execute() - except Exception as e: - print("Error in deleting file from supabase:", e) - sentry_sdk.capture_exception(e) - - # Delete files by their URL identifier - elif source_url: - try: - # Delete from Qdrant - self.qdrant_client.delete( - collection_name=os.environ['QDRANT_COLLECTION_NAME'], - points_selector=models.Filter(must=[ - models.FieldCondition( - key="url", - match=models.MatchValue(value=source_url), - ), - ]), - ) - except Exception as e: - if "timed out" in str(e): - # Timed out is fine. Still deletes. - # https://github.com/qdrant/qdrant/issues/3654#issuecomment-1955074525 - pass - else: - print("Error in deleting file from Qdrant:", e) - sentry_sdk.capture_exception(e) - try: - # delete from Nomic - response = self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, url, contexts").eq( - 'url', source_url).eq('course_name', course_name).execute() - data = response.data[0] #single record fetched - nomic_ids_to_delete = [] - context_count = len(data['contexts']) - for i in range(1, context_count + 1): - nomic_ids_to_delete.append(str(data['id']) + "_" + str(i)) - - # delete from Nomic - delete_from_document_map(course_name, nomic_ids_to_delete) - except Exception as e: - print("Error in deleting file from Nomic:", e) - sentry_sdk.capture_exception(e) - - try: - # delete from Supabase - self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('url', source_url).eq( - 'course_name', course_name).execute() - except Exception as e: - print("Error in deleting file from supabase:", e) - sentry_sdk.capture_exception(e) - - # Delete from Supabase - return "Success" - except Exception as e: - err: str = f"ERROR IN delete_data: Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore - print(err) - sentry_sdk.capture_exception(e) - return err - - # def ingest_coursera(self, coursera_course_name: str, course_name: str) -> str: - # """ Download all the files from a coursera course and ingest them. - - # 1. Download the coursera content. - # 2. Upload to S3 (so users can view it) - # 3. Run everything through the ingest_bulk method. 
- - # Args: - # coursera_course_name (str): The name of the coursera course. - # course_name (str): The name of the course in our system. - - # Returns: - # _type_: Success or error message. - # """ - # certificate = "-ca 'FVhVoDp5cb-ZaoRr5nNJLYbyjCLz8cGvaXzizqNlQEBsG5wSq7AHScZGAGfC1nI0ehXFvWy1NG8dyuIBF7DLMA.X3cXsDvHcOmSdo3Fyvg27Q.qyGfoo0GOHosTVoSMFy-gc24B-_BIxJtqblTzN5xQWT3hSntTR1DMPgPQKQmfZh_40UaV8oZKKiF15HtZBaLHWLbpEpAgTg3KiTiU1WSdUWueo92tnhz-lcLeLmCQE2y3XpijaN6G4mmgznLGVsVLXb-P3Cibzz0aVeT_lWIJNrCsXrTFh2HzFEhC4FxfTVqS6cRsKVskPpSu8D9EuCQUwJoOJHP_GvcME9-RISBhi46p-Z1IQZAC4qHPDhthIJG4bJqpq8-ZClRL3DFGqOfaiu5y415LJcH--PRRKTBnP7fNWPKhcEK2xoYQLr9RxBVL3pzVPEFyTYtGg6hFIdJcjKOU11AXAnQ-Kw-Gb_wXiHmu63veM6T8N2dEkdqygMre_xMDT5NVaP3xrPbA4eAQjl9yov4tyX4AQWMaCS5OCbGTpMTq2Y4L0Mbz93MHrblM2JL_cBYa59bq7DFK1IgzmOjFhNG266mQlC9juNcEhc'" - # always_use_flags = "-u kastanvday@gmail.com -p hSBsLaF5YM469# --ignore-formats mp4 --subtitle-language en --path ./coursera-dl" - - # try: - # subprocess.run( - # f"coursera-dl {always_use_flags} {certificate} {coursera_course_name}", - # check=True, - # shell=True, # nosec -- reasonable bandit error suppression - # stdout=subprocess.PIPE, - # stderr=subprocess.PIPE) # capture_output=True, - # dl_results_path = os.path.join('coursera-dl', coursera_course_name) - # s3_paths: Union[List, None] = upload_data_files_to_s3(course_name, dl_results_path) - - # if s3_paths is None: - # return "Error: No files found in the coursera-dl directory" - - # print("starting bulk ingest") - # start_time = time.monotonic() - # self.bulk_ingest(s3_paths, course_name) - # print("completed bulk ingest") - # print(f"⏰ Runtime: {(time.monotonic() - start_time):.2f} seconds") - - # # Cleanup the coursera downloads - # shutil.rmtree(dl_results_path) - - # return "Success" - # except Exception as e: - # err: str = f"Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore - # print(err) - # return err - - # def list_files_recursively(self, bucket, prefix): - # all_files = [] - # continuation_token = None - - # while True: - # list_objects_kwargs = { - # 'Bucket': bucket, - # 'Prefix': prefix, - # } - # if continuation_token: - # list_objects_kwargs['ContinuationToken'] = continuation_token - - # response = self.s3_client.list_objects_v2(**list_objects_kwargs) - - # if 'Contents' in response: - # for obj in response['Contents']: - # all_files.append(obj['Key']) - - # if response['IsTruncated']: - # continuation_token = response['NextContinuationToken'] - # else: - # break - - # return all_files - - -if __name__ == "__main__": - raise NotImplementedError("This file is not meant to be run directly") - text = "Testing 123" - # ingest(text=text) +# metadatas: List[Dict[str, Any]] = [ +# { +# 'course_name': course_name, +# 's3_path': s3_path, +# 'pagenumber': page['page_number'] + 1, # +1 for human indexing +# 'timestamp': '', +# 'readable_filename': kwargs.get('readable_filename', page['readable_filename']), +# 'url': kwargs.get('url', ''), +# 'base_url': kwargs.get('base_url', ''), +# } for page in pdf_pages_OCRed +# ] +# pdf_texts = [page['text'] for page in pdf_pages_OCRed] +# self.posthog.capture('distinct_id_of_the_user', +# event='ocr_pdf_succeeded', +# properties={ +# 'course_name': course_name, +# 's3_path': s3_path, +# }) + +# has_words = any(text.strip() for text in pdf_texts) +# if not has_words: +# raise ValueError("Failed ingest: Could not detect ANY text in the PDF. OCR did not help. 
PDF appears empty of text.") + +# success_or_failure = self.split_and_upload(texts=pdf_texts, metadatas=metadatas) +# return success_or_failure +# except Exception as e: +# err = f"❌❌ Error in PDF ingest (with OCR): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc() +# print(err) +# sentry_sdk.capture_exception(e) +# return err + +# def _ingest_single_txt(self, s3_path: str, course_name: str, **kwargs) -> str: +# """Ingest a single .txt or .md file from S3. +# Args: +# s3_path (str): A path to a .txt file in S3 +# course_name (str): The name of the course +# Returns: +# str: "Success" or an error message +# """ +# print("In text ingest, UTF-8") +# try: +# # NOTE: slightly different method for .txt files, no need for download. It's part of the 'body' +# response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) +# text = response['Body'].read().decode('utf-8') +# print("UTF-8 text to ignest (from s3)", text) +# text = [text] + +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# }] +# print("Prior to ingest", metadatas) + +# success_or_failure = self.split_and_upload(texts=text, metadatas=metadatas) +# return success_or_failure +# except Exception as e: +# err = f"❌❌ Error in (TXT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# print(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def _ingest_single_ppt(self, s3_path: str, course_name: str, **kwargs) -> str: +# """ +# Ingest a single .ppt or .pptx file from S3. +# """ +# try: +# with NamedTemporaryFile() as tmpfile: +# # download from S3 into pdf_tmpfile +# #print("in ingest PPTX") +# self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) + +# loader = UnstructuredPowerPointLoader(tmpfile.name) +# documents = loader.load() + +# texts = [doc.page_content for doc in documents] +# metadatas: List[Dict[str, Any]] = [{ +# 'course_name': course_name, +# 's3_path': s3_path, +# 'readable_filename': kwargs.get('readable_filename', +# Path(s3_path).name[37:]), +# 'pagenumber': '', +# 'timestamp': '', +# 'url': '', +# 'base_url': '', +# } for doc in documents] + +# self.split_and_upload(texts=texts, metadatas=metadatas) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (PPTX ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( +# ) +# print(err) +# sentry_sdk.capture_exception(e) +# return str(err) + +# def ingest_github(self, github_url: str, course_name: str) -> str: +# """ +# Clones the given GitHub URL and uses Langchain to load data. +# 1. Clone the repo +# 2. Use Langchain to load the data +# 3. Pass to split_and_upload() +# Args: +# github_url (str): The Github Repo URL to be ingested. +# course_name (str): The name of the course in our system. + +# Returns: +# _type_: Success or error message. 
+# """ +# try: +# repo_path = "media/cloned_repo" +# repo = Repo.clone_from(github_url, to_path=repo_path, depth=1, clone_submodules=False) +# branch = repo.head.reference + +# loader = GitLoader(repo_path="media/cloned_repo", branch=str(branch)) +# data = loader.load() +# shutil.rmtree("media/cloned_repo") +# # create metadata for each file in data + +# for doc in data: +# texts = doc.page_content +# metadatas: Dict[str, Any] = { +# 'course_name': course_name, +# 's3_path': '', +# 'readable_filename': doc.metadata['file_name'], +# 'url': f"{github_url}/blob/main/{doc.metadata['file_path']}", +# 'pagenumber': '', +# 'timestamp': '', +# } +# self.split_and_upload(texts=[texts], metadatas=[metadatas]) +# return "Success" +# except Exception as e: +# err = f"❌❌ Error in (GITHUB ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n{traceback.format_exc()}" +# print(err) +# sentry_sdk.capture_exception(e) +# return err + +# def split_and_upload(self, texts: List[str], metadatas: List[Dict[str, Any]]): +# """ This is usually the last step of document ingest. Chunk & upload to Qdrant (and Supabase.. todo). +# Takes in Text and Metadata (from Langchain doc loaders) and splits / uploads to Qdrant. + +# good examples here: https://langchain.readthedocs.io/en/latest/modules/utils/combine_docs_examples/textsplitter.html + +# Args: +# texts (List[str]): _description_ +# metadatas (List[Dict[str, Any]]): _description_ +# """ +# # return "Success" +# self.posthog.capture('distinct_id_of_the_user', +# event='split_and_upload_invoked', +# properties={ +# 'course_name': metadatas[0].get('course_name', None), +# 's3_path': metadatas[0].get('s3_path', None), +# 'readable_filename': metadatas[0].get('readable_filename', None), +# 'url': metadatas[0].get('url', None), +# 'base_url': metadatas[0].get('base_url', None), +# }) + +# print(f"In split and upload. Metadatas: {metadatas}") +# print(f"Texts: {texts}") +# assert len(texts) == len( +# metadatas +# ), f'must have equal number of text strings and metadata dicts. len(texts) is {len(texts)}. len(metadatas) is {len(metadatas)}' + +# try: +# text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( +# chunk_size=1000, +# chunk_overlap=150, +# separators=[ +# "\n\n", "\n", ". ", " ", "" +# ] # try to split on paragraphs... 
fallback to sentences, then chars, ensure we always fit in context window +# ) +# contexts: List[Document] = text_splitter.create_documents(texts=texts, metadatas=metadatas) +# input_texts = [{'input': context.page_content, 'model': 'text-embedding-ada-002'} for context in contexts] + +# # check for duplicates +# is_duplicate = self.check_for_duplicates(input_texts, metadatas) +# if is_duplicate: +# self.posthog.capture('distinct_id_of_the_user', +# event='split_and_upload_succeeded', +# properties={ +# 'course_name': metadatas[0].get('course_name', None), +# 's3_path': metadatas[0].get('s3_path', None), +# 'readable_filename': metadatas[0].get('readable_filename', None), +# 'url': metadatas[0].get('url', None), +# 'base_url': metadatas[0].get('base_url', None), +# 'is_duplicate': True, +# }) +# return "Success" + +# # adding chunk index to metadata for parent doc retrieval +# for i, context in enumerate(contexts): +# context.metadata['chunk_index'] = i + +# oai = OpenAIAPIProcessor( +# input_prompts_list=input_texts, +# request_url='https://api.openai.com/v1/embeddings', +# api_key=os.getenv('VLADS_OPENAI_KEY'), +# # request_url='https://uiuc-chat-canada-east.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2023-05-15', +# # api_key=os.getenv('AZURE_OPENAI_KEY'), +# max_requests_per_minute=5_000, +# max_tokens_per_minute=300_000, +# max_attempts=20, +# logging_level=logging.INFO, +# token_encoding_name='cl100k_base') # nosec -- reasonable bandit error suppression +# asyncio.run(oai.process_api_requests_from_file()) +# # parse results into dict of shape page_content -> embedding +# embeddings_dict: dict[str, List[float]] = { +# item[0]['input']: item[1]['data'][0]['embedding'] for item in oai.results +# } + +# ### BULK upload to Qdrant ### +# vectors: list[PointStruct] = [] +# for context in contexts: +# # !DONE: Updated the payload so each key is top level (no more payload.metadata.course_name. Instead, use payload.course_name), great for creating indexes. 
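# --- Illustrative sketch (not from this patch) -------------------------------------
# The chunk -> embed steps that split_and_upload performs above, condensed and using the
# plain openai<1.0 client instead of the repo's OpenAIAPIProcessor batch helper. Chunk
# sizes mirror the values above; the env var and function name are assumptions.
import os
import openai
from langchain.text_splitter import RecursiveCharacterTextSplitter

openai.api_key = os.getenv("OPENAI_API_KEY")  # assumed env var

def chunk_and_embed(texts, metadatas):
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=1000, chunk_overlap=150, separators=["\n\n", "\n", ". ", " ", ""])
    contexts = splitter.create_documents(texts=texts, metadatas=metadatas)
    response = openai.Embedding.create(
        input=[c.page_content for c in contexts], model="text-embedding-ada-002")
    embeddings = [item["embedding"] for item in response["data"]]
    return contexts, embeddings  # ready to zip into Qdrant PointStructs
# ------------------------------------------------------------------------------------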
+# upload_metadata = {**context.metadata, "page_content": context.page_content} +# vectors.append( +# PointStruct(id=str(uuid.uuid4()), vector=embeddings_dict[context.page_content], payload=upload_metadata)) + +# self.qdrant_client.upsert( +# collection_name=os.environ['QDRANT_COLLECTION_NAME'], # type: ignore +# points=vectors # type: ignore +# ) +# ### Supabase SQL ### +# contexts_for_supa = [{ +# "text": context.page_content, +# "pagenumber": context.metadata.get('pagenumber'), +# "timestamp": context.metadata.get('timestamp'), +# "chunk_index": context.metadata.get('chunk_index'), +# "embedding": embeddings_dict[context.page_content] +# } for context in contexts] + +# document = { +# "course_name": contexts[0].metadata.get('course_name'), +# "s3_path": contexts[0].metadata.get('s3_path'), +# "readable_filename": contexts[0].metadata.get('readable_filename'), +# "url": contexts[0].metadata.get('url'), +# "base_url": contexts[0].metadata.get('base_url'), +# "contexts": contexts_for_supa, +# } + +# response = self.supabase_client.table( +# os.getenv('SUPABASE_DOCUMENTS_TABLE')).insert(document).execute() # type: ignore + +# # add to Nomic document map +# if len(response.data) > 0: +# course_name = contexts[0].metadata.get('course_name') +# log_to_document_map(course_name) + +# self.posthog.capture('distinct_id_of_the_user', +# event='split_and_upload_succeeded', +# properties={ +# 'course_name': metadatas[0].get('course_name', None), +# 's3_path': metadatas[0].get('s3_path', None), +# 'readable_filename': metadatas[0].get('readable_filename', None), +# 'url': metadatas[0].get('url', None), +# 'base_url': metadatas[0].get('base_url', None), +# }) +# print("successful END OF split_and_upload") +# return "Success" +# except Exception as e: +# err: str = f"ERROR IN split_and_upload(): Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore +# print(err) +# sentry_sdk.capture_exception(e) +# return err + +# def check_for_duplicates(self, texts: List[Dict], metadatas: List[Dict[str, Any]]) -> bool: +# """ +# For given metadata, fetch docs from Supabase based on S3 path or URL. +# If docs exists, concatenate the texts and compare with current texts, if same, return True. +# """ +# doc_table = os.getenv('SUPABASE_DOCUMENTS_TABLE', '') +# course_name = metadatas[0]['course_name'] +# incoming_s3_path = metadatas[0]['s3_path'] +# url = metadatas[0]['url'] +# original_filename = incoming_s3_path.split('/')[-1][37:] # remove the 37-char uuid prefix + +# # check if uuid exists in s3_path -- not all s3_paths have uuids! +# incoming_filename = incoming_s3_path.split('/')[-1] +# pattern = re.compile(r'[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}', +# re.I) # uuid V4 pattern, and v4 only. 
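# --- Illustrative sketch (not from this patch) -------------------------------------
# The core of check_for_duplicates: strip the 36-char UUID prefix plus separator when
# one is present (as the regex just above does), then compare the concatenated context
# texts against what Supabase already holds (as the code below does). Pure-Python
# helpers; the names are assumptions.
import re

UUID_V4 = re.compile(r"[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}", re.I)

def original_filename(s3_path: str) -> str:
    """Drop the 'uuid-' prefix from the basename when one is present."""
    name = s3_path.split("/")[-1]
    return name[37:] if UUID_V4.search(name) else name

def is_duplicate(existing_contexts, incoming_texts) -> bool:
    """True when the concatenated texts match the previously ingested document."""
    existing = "".join(c["text"] for c in existing_contexts)
    incoming = "".join(t["input"] for t in incoming_texts)
    return existing == incoming
# ------------------------------------------------------------------------------------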
+# if bool(pattern.search(incoming_filename)): +# # uuid pattern exists -- remove the uuid and proceed with duplicate checking +# original_filename = incoming_filename[37:] +# else: +# # do not remove anything and proceed with duplicate checking +# original_filename = incoming_filename + +# if incoming_s3_path: +# filename = incoming_s3_path +# supabase_contents = self.supabase_client.table(doc_table).select('id', 'contexts', 's3_path').eq( +# 'course_name', course_name).like('s3_path', '%' + original_filename + '%').order('id', desc=True).execute() +# supabase_contents = supabase_contents.data +# elif url: +# filename = url +# supabase_contents = self.supabase_client.table(doc_table).select('id', 'contexts', 's3_path').eq( +# 'course_name', course_name).eq('url', url).order('id', desc=True).execute() +# supabase_contents = supabase_contents.data +# else: +# filename = None +# supabase_contents = [] + +# supabase_whole_text = "" +# if len(supabase_contents) > 0: # if a doc with same filename exists in Supabase +# # concatenate texts +# supabase_contexts = supabase_contents[0] +# for text in supabase_contexts['contexts']: +# supabase_whole_text += text['text'] + +# current_whole_text = "" +# for text in texts: +# current_whole_text += text['input'] + +# if supabase_whole_text == current_whole_text: # matches the previous file +# print(f"Duplicate ingested! 📄 s3_path: {filename}.") +# return True + +# else: # the file is updated +# print(f"Updated file detected! Same filename, new contents. 📄 s3_path: {filename}") + +# # call the delete function on older docs +# for content in supabase_contents: +# print("older s3_path to be deleted: ", content['s3_path']) +# delete_status = self.delete_data(course_name, content['s3_path'], '') +# print("delete_status: ", delete_status) +# return False + +# else: # filename does not already exist in Supabase, so its a brand new file +# print(f"NOT a duplicate! 📄s3_path: {filename}") +# return False + +# def delete_data(self, course_name: str, s3_path: str, source_url: str): +# """Delete file from S3, Qdrant, and Supabase.""" +# print(f"Deleting {s3_path} from S3, Qdrant, and Supabase for course {course_name}") +# # add delete from doc map logic here +# try: +# # Delete file from S3 +# bucket_name = os.getenv('S3_BUCKET_NAME') + +# # Delete files by S3 path +# if s3_path: +# try: +# self.s3_client.delete_object(Bucket=bucket_name, Key=s3_path) +# except Exception as e: +# print("Error in deleting file from s3:", e) +# sentry_sdk.capture_exception(e) +# # Delete from Qdrant +# # docs for nested keys: https://qdrant.tech/documentation/concepts/filtering/#nested-key +# # Qdrant "points" look like this: Record(id='000295ca-bd28-ac4a-6f8d-c245f7377f90', payload={'metadata': {'course_name': 'zotero-extreme', 'pagenumber_or_timestamp': 15, 'readable_filename': 'Dunlosky et al. - 2013 - Improving Students’ Learning With Effective Learni.pdf', 's3_path': 'courses/zotero-extreme/Dunlosky et al. - 2013 - Improving Students’ Learning With Effective Learni.pdf'}, 'page_content': '18 \nDunlosky et al.\n3.3 Effects in representative educational contexts. Sev-\neral of the large summarization-training studies have been \nconducted in regular classrooms, indicating the feasibility of \ndoing so. For example, the study by A. King (1992) took place \nin the context of a remedial study-skills course for undergrad-\nuates, and the study by Rinehart et al. (1986) took place in \nsixth-grade classrooms, with the instruction led by students \nregular teachers. 
In these and other cases, students benefited \nfrom the classroom training. We suspect it may actually be \nmore feasible to conduct these kinds of training ... +# try: +# self.qdrant_client.delete( +# collection_name=os.environ['QDRANT_COLLECTION_NAME'], +# points_selector=models.Filter(must=[ +# models.FieldCondition( +# key="s3_path", +# match=models.MatchValue(value=s3_path), +# ), +# ]), +# ) +# except Exception as e: +# if "timed out" in str(e): +# # Timed out is fine. Still deletes. +# # https://github.com/qdrant/qdrant/issues/3654#issuecomment-1955074525 +# pass +# else: +# print("Error in deleting file from Qdrant:", e) +# sentry_sdk.capture_exception(e) +# try: +# # delete from Nomic +# response = self.supabase_client.from_( +# os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq('s3_path', s3_path).eq( +# 'course_name', course_name).execute() +# data = response.data[0] #single record fetched +# nomic_ids_to_delete = [] +# context_count = len(data['contexts']) +# for i in range(1, context_count + 1): +# nomic_ids_to_delete.append(str(data['id']) + "_" + str(i)) + +# # delete from Nomic +# delete_from_document_map(course_name, nomic_ids_to_delete) +# except Exception as e: +# print("Error in deleting file from Nomic:", e) +# sentry_sdk.capture_exception(e) + +# try: +# self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('s3_path', s3_path).eq( +# 'course_name', course_name).execute() +# except Exception as e: +# print("Error in deleting file from supabase:", e) +# sentry_sdk.capture_exception(e) + +# # Delete files by their URL identifier +# elif source_url: +# try: +# # Delete from Qdrant +# self.qdrant_client.delete( +# collection_name=os.environ['QDRANT_COLLECTION_NAME'], +# points_selector=models.Filter(must=[ +# models.FieldCondition( +# key="url", +# match=models.MatchValue(value=source_url), +# ), +# ]), +# ) +# except Exception as e: +# if "timed out" in str(e): +# # Timed out is fine. Still deletes. +# # https://github.com/qdrant/qdrant/issues/3654#issuecomment-1955074525 +# pass +# else: +# print("Error in deleting file from Qdrant:", e) +# sentry_sdk.capture_exception(e) +# try: +# # delete from Nomic +# response = self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, url, contexts").eq( +# 'url', source_url).eq('course_name', course_name).execute() +# data = response.data[0] #single record fetched +# nomic_ids_to_delete = [] +# context_count = len(data['contexts']) +# for i in range(1, context_count + 1): +# nomic_ids_to_delete.append(str(data['id']) + "_" + str(i)) + +# # delete from Nomic +# delete_from_document_map(course_name, nomic_ids_to_delete) +# except Exception as e: +# print("Error in deleting file from Nomic:", e) +# sentry_sdk.capture_exception(e) + +# try: +# # delete from Supabase +# self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('url', source_url).eq( +# 'course_name', course_name).execute() +# except Exception as e: +# print("Error in deleting file from supabase:", e) +# sentry_sdk.capture_exception(e) + +# # Delete from Supabase +# return "Success" +# except Exception as e: +# err: str = f"ERROR IN delete_data: Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore +# print(err) +# sentry_sdk.capture_exception(e) +# return err + +# # def ingest_coursera(self, coursera_course_name: str, course_name: str) -> str: +# # """ Download all the files from a coursera course and ingest them. 
+ +# # 1. Download the coursera content. +# # 2. Upload to S3 (so users can view it) +# # 3. Run everything through the ingest_bulk method. + +# # Args: +# # coursera_course_name (str): The name of the coursera course. +# # course_name (str): The name of the course in our system. + +# # Returns: +# # _type_: Success or error message. +# # """ +# # certificate = "-ca 'FVhVoDp5cb-ZaoRr5nNJLYbyjCLz8cGvaXzizqNlQEBsG5wSq7AHScZGAGfC1nI0ehXFvWy1NG8dyuIBF7DLMA.X3cXsDvHcOmSdo3Fyvg27Q.qyGfoo0GOHosTVoSMFy-gc24B-_BIxJtqblTzN5xQWT3hSntTR1DMPgPQKQmfZh_40UaV8oZKKiF15HtZBaLHWLbpEpAgTg3KiTiU1WSdUWueo92tnhz-lcLeLmCQE2y3XpijaN6G4mmgznLGVsVLXb-P3Cibzz0aVeT_lWIJNrCsXrTFh2HzFEhC4FxfTVqS6cRsKVskPpSu8D9EuCQUwJoOJHP_GvcME9-RISBhi46p-Z1IQZAC4qHPDhthIJG4bJqpq8-ZClRL3DFGqOfaiu5y415LJcH--PRRKTBnP7fNWPKhcEK2xoYQLr9RxBVL3pzVPEFyTYtGg6hFIdJcjKOU11AXAnQ-Kw-Gb_wXiHmu63veM6T8N2dEkdqygMre_xMDT5NVaP3xrPbA4eAQjl9yov4tyX4AQWMaCS5OCbGTpMTq2Y4L0Mbz93MHrblM2JL_cBYa59bq7DFK1IgzmOjFhNG266mQlC9juNcEhc'" +# # always_use_flags = "-u kastanvday@gmail.com -p hSBsLaF5YM469# --ignore-formats mp4 --subtitle-language en --path ./coursera-dl" + +# # try: +# # subprocess.run( +# # f"coursera-dl {always_use_flags} {certificate} {coursera_course_name}", +# # check=True, +# # shell=True, # nosec -- reasonable bandit error suppression +# # stdout=subprocess.PIPE, +# # stderr=subprocess.PIPE) # capture_output=True, +# # dl_results_path = os.path.join('coursera-dl', coursera_course_name) +# # s3_paths: Union[List, None] = upload_data_files_to_s3(course_name, dl_results_path) + +# # if s3_paths is None: +# # return "Error: No files found in the coursera-dl directory" + +# # print("starting bulk ingest") +# # start_time = time.monotonic() +# # self.bulk_ingest(s3_paths, course_name) +# # print("completed bulk ingest") +# # print(f"⏰ Runtime: {(time.monotonic() - start_time):.2f} seconds") + +# # # Cleanup the coursera downloads +# # shutil.rmtree(dl_results_path) + +# # return "Success" +# # except Exception as e: +# # err: str = f"Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore +# # print(err) +# # return err + +# # def list_files_recursively(self, bucket, prefix): +# # all_files = [] +# # continuation_token = None + +# # while True: +# # list_objects_kwargs = { +# # 'Bucket': bucket, +# # 'Prefix': prefix, +# # } +# # if continuation_token: +# # list_objects_kwargs['ContinuationToken'] = continuation_token + +# # response = self.s3_client.list_objects_v2(**list_objects_kwargs) + +# # if 'Contents' in response: +# # for obj in response['Contents']: +# # all_files.append(obj['Key']) + +# # if response['IsTruncated']: +# # continuation_token = response['NextContinuationToken'] +# # else: +# # break + +# # return all_files + + +# if __name__ == "__main__": +# raise NotImplementedError("This file is not meant to be run directly") +# text = "Testing 123" +# # ingest(text=text) diff --git a/ai_ta_backend/beam/nomic_logging.py b/ai_ta_backend/beam/nomic_logging.py index 92db8a62..d15c616e 100644 --- a/ai_ta_backend/beam/nomic_logging.py +++ b/ai_ta_backend/beam/nomic_logging.py @@ -1,438 +1,438 @@ -import datetime -import os - -import nomic -import numpy as np -import pandas as pd -import sentry_sdk -import supabase -from langchain.embeddings import OpenAIEmbeddings -from nomic import AtlasProject, atlas - -OPENAI_API_TYPE = "azure" - -SUPABASE_CLIENT = supabase.create_client( # type: ignore - supabase_url=os.getenv('SUPABASE_URL'), # type: ignore - 
supabase_key=os.getenv('SUPABASE_API_KEY')) # type: ignore - -NOMIC_MAP_NAME_PREFIX = 'Document Map for ' - -## -------------------------------- DOCUMENT MAP FUNCTIONS --------------------------------- ## - -def create_document_map(course_name: str): - """ - This is a function which creates a document map for a given course from scratch - 1. Gets count of documents for the course - 2. If less than 20, returns a message that a map cannot be created - 3. If greater than 20, iteratively fetches documents in batches of 25 - 4. Prepares metadata and embeddings for nomic upload - 5. Creates a new map and uploads the data - - Args: - course_name: str - Returns: - str: success or failed - """ - print("in create_document_map()") - nomic.login(os.getenv('NOMIC_API_KEY')) +# import datetime +# import os + +# import nomic +# import numpy as np +# import pandas as pd +# import sentry_sdk +# import supabase +# from langchain.embeddings import OpenAIEmbeddings +# from nomic import AtlasProject, atlas + +# OPENAI_API_TYPE = "azure" + +# SUPABASE_CLIENT = supabase.create_client( # type: ignore +# supabase_url=os.getenv('SUPABASE_URL'), # type: ignore +# supabase_key=os.getenv('SUPABASE_API_KEY')) # type: ignore + +# NOMIC_MAP_NAME_PREFIX = 'Document Map for ' + +# ## -------------------------------- DOCUMENT MAP FUNCTIONS --------------------------------- ## + +# def create_document_map(course_name: str): +# """ +# This is a function which creates a document map for a given course from scratch +# 1. Gets count of documents for the course +# 2. If less than 20, returns a message that a map cannot be created +# 3. If greater than 20, iteratively fetches documents in batches of 25 +# 4. Prepares metadata and embeddings for nomic upload +# 5. Creates a new map and uploads the data + +# Args: +# course_name: str +# Returns: +# str: success or failed +# """ +# print("in create_document_map()") +# nomic.login(os.getenv('NOMIC_API_KEY')) - try: - # check if map exists - response = SUPABASE_CLIENT.table("projects").select("doc_map_id").eq("course_name", course_name).execute() - if response.data: - if response.data[0]['doc_map_id']: - return "Map already exists for this course." - - # fetch relevant document data from Supabase - response = SUPABASE_CLIENT.table("documents").select("id", - count="exact").eq("course_name", - course_name).order('id', - desc=False).execute() - if not response.count: - return "No documents found for this course." +# try: +# # check if map exists +# response = SUPABASE_CLIENT.table("projects").select("doc_map_id").eq("course_name", course_name).execute() +# if response.data: +# if response.data[0]['doc_map_id']: +# return "Map already exists for this course." + +# # fetch relevant document data from Supabase +# response = SUPABASE_CLIENT.table("documents").select("id", +# count="exact").eq("course_name", +# course_name).order('id', +# desc=False).execute() +# if not response.count: +# return "No documents found for this course." - total_doc_count = response.count - print("Total number of documents in Supabase: ", total_doc_count) +# total_doc_count = response.count +# print("Total number of documents in Supabase: ", total_doc_count) - # minimum 20 docs needed to create map - if total_doc_count < 20: - return "Cannot create a map because there are less than 20 documents in the course." +# # minimum 20 docs needed to create map +# if total_doc_count < 20: +# return "Cannot create a map because there are less than 20 documents in the course." 
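# --- Illustrative sketch (not from this patch) -------------------------------------
# The batched fetch that create_document_map runs below is plain keyset pagination over
# the Supabase `documents` table: fetch 25 rows at a time ordered by id, then restart
# from the last id seen. Client setup and the generator name are assumptions.
import os
import supabase

client = supabase.create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_API_KEY"))

def iter_documents(course_name: str, batch_size: int = 25):
    """Yield document rows for a course in id order, batch_size rows per round trip."""
    first_id = 0
    while True:
        rows = (client.table("documents")
                .select("id, created_at, s3_path, url, base_url, readable_filename, contexts")
                .eq("course_name", course_name)
                .gte("id", first_id)
                .order("id", desc=False)
                .limit(batch_size)
                .execute()).data
        if not rows:
            return
        yield from rows
        first_id = rows[-1]["id"] + 1
# ------------------------------------------------------------------------------------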
- first_id = response.data[0]['id'] +# first_id = response.data[0]['id'] - combined_dfs = [] - curr_total_doc_count = 0 - doc_count = 0 - first_batch = True +# combined_dfs = [] +# curr_total_doc_count = 0 +# doc_count = 0 +# first_batch = True - # iteratively query in batches of 25 - while curr_total_doc_count < total_doc_count: +# # iteratively query in batches of 25 +# while curr_total_doc_count < total_doc_count: - response = SUPABASE_CLIENT.table("documents").select( - "id, created_at, s3_path, url, base_url, readable_filename, contexts").eq("course_name", course_name).gte( - 'id', first_id).order('id', desc=False).limit(25).execute() - df = pd.DataFrame(response.data) - combined_dfs.append(df) # list of dfs - - curr_total_doc_count += len(response.data) - doc_count += len(response.data) - - if doc_count >= 1000: # upload to Nomic in batches of 1000 - - # concat all dfs from the combined_dfs list - final_df = pd.concat(combined_dfs, ignore_index=True) - - # prep data for nomic upload - embeddings, metadata = data_prep_for_doc_map(final_df) - - if first_batch: - # create a new map - print("Creating new map...") - project_name = NOMIC_MAP_NAME_PREFIX + course_name - index_name = course_name + "_doc_index" - topic_label_field = "text" - colorable_fields = ["readable_filename", "text", "base_url", "created_at"] - result = create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) +# response = SUPABASE_CLIENT.table("documents").select( +# "id, created_at, s3_path, url, base_url, readable_filename, contexts").eq("course_name", course_name).gte( +# 'id', first_id).order('id', desc=False).limit(25).execute() +# df = pd.DataFrame(response.data) +# combined_dfs.append(df) # list of dfs + +# curr_total_doc_count += len(response.data) +# doc_count += len(response.data) + +# if doc_count >= 1000: # upload to Nomic in batches of 1000 + +# # concat all dfs from the combined_dfs list +# final_df = pd.concat(combined_dfs, ignore_index=True) + +# # prep data for nomic upload +# embeddings, metadata = data_prep_for_doc_map(final_df) + +# if first_batch: +# # create a new map +# print("Creating new map...") +# project_name = NOMIC_MAP_NAME_PREFIX + course_name +# index_name = course_name + "_doc_index" +# topic_label_field = "text" +# colorable_fields = ["readable_filename", "text", "base_url", "created_at"] +# result = create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) - if result == "success": - # update flag - first_batch = False - # log project info to supabase - project = AtlasProject(name=project_name, add_datums_if_exists=True) - project_id = project.id - last_id = int(final_df['id'].iloc[-1]) - project_info = {'course_name': course_name, 'doc_map_id': project_id, 'last_uploaded_doc_id': last_id} - project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() - if project_response.data: - update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) - else: - insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() - print("Insert Response from supabase: ", insert_response) +# if result == "success": +# # update flag +# first_batch = False +# # log project info to supabase +# project = AtlasProject(name=project_name, add_datums_if_exists=True) +# project_id = project.id +# last_id = int(final_df['id'].iloc[-1]) +# project_info = {'course_name': 
course_name, 'doc_map_id': project_id, 'last_uploaded_doc_id': last_id} +# project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() +# if project_response.data: +# update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() +# print("Response from supabase: ", update_response) +# else: +# insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() +# print("Insert Response from supabase: ", insert_response) - else: - # append to existing map - print("Appending data to existing map...") - project_name = NOMIC_MAP_NAME_PREFIX + course_name - # add project lock logic here - result = append_to_map(embeddings, metadata, project_name) - if result == "success": - # update the last uploaded id in supabase - last_id = int(final_df['id'].iloc[-1]) - info = {'last_uploaded_doc_id': last_id} - update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) +# else: +# # append to existing map +# print("Appending data to existing map...") +# project_name = NOMIC_MAP_NAME_PREFIX + course_name +# # add project lock logic here +# result = append_to_map(embeddings, metadata, project_name) +# if result == "success": +# # update the last uploaded id in supabase +# last_id = int(final_df['id'].iloc[-1]) +# info = {'last_uploaded_doc_id': last_id} +# update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() +# print("Response from supabase: ", update_response) - # reset variables - combined_dfs = [] - doc_count = 0 - print("Records uploaded: ", curr_total_doc_count) - - # set first_id for next iteration - first_id = response.data[-1]['id'] + 1 - - # upload last set of docs - if doc_count > 0: - final_df = pd.concat(combined_dfs, ignore_index=True) - embeddings, metadata = data_prep_for_doc_map(final_df) - project_name = NOMIC_MAP_NAME_PREFIX + course_name - if first_batch: - index_name = course_name + "_doc_index" - topic_label_field = "text" - colorable_fields = ["readable_filename", "text", "base_url", "created_at"] - result = create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) - else: - result = append_to_map(embeddings, metadata, project_name) - - # update the last uploaded id in supabase - if result == "success": - # update the last uploaded id in supabase - last_id = int(final_df['id'].iloc[-1]) - project = AtlasProject(name=project_name, add_datums_if_exists=True) - project_id = project.id - project_info = {'course_name': course_name, 'doc_map_id': project_id, 'last_uploaded_doc_id': last_id} - print("project_info: ", project_info) - project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() - if project_response.data: - update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) - else: - insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() - print("Insert Response from supabase: ", insert_response) +# # reset variables +# combined_dfs = [] +# doc_count = 0 +# print("Records uploaded: ", curr_total_doc_count) + +# # set first_id for next iteration +# first_id = response.data[-1]['id'] + 1 + +# # upload last set of docs +# if doc_count > 0: +# final_df = pd.concat(combined_dfs, ignore_index=True) +# embeddings, 
metadata = data_prep_for_doc_map(final_df) +# project_name = NOMIC_MAP_NAME_PREFIX + course_name +# if first_batch: +# index_name = course_name + "_doc_index" +# topic_label_field = "text" +# colorable_fields = ["readable_filename", "text", "base_url", "created_at"] +# result = create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) +# else: +# result = append_to_map(embeddings, metadata, project_name) + +# # update the last uploaded id in supabase +# if result == "success": +# # update the last uploaded id in supabase +# last_id = int(final_df['id'].iloc[-1]) +# project = AtlasProject(name=project_name, add_datums_if_exists=True) +# project_id = project.id +# project_info = {'course_name': course_name, 'doc_map_id': project_id, 'last_uploaded_doc_id': last_id} +# print("project_info: ", project_info) +# project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() +# if project_response.data: +# update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() +# print("Response from supabase: ", update_response) +# else: +# insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() +# print("Insert Response from supabase: ", insert_response) - # rebuild the map - rebuild_map(course_name, "document") +# # rebuild the map +# rebuild_map(course_name, "document") - except Exception as e: - print(e) - sentry_sdk.capture_exception(e) - return "failed" - -def delete_from_document_map(course_name: str, ids: list): - """ - This function is used to delete datapoints from a document map. - Currently used within the delete_data() function in vector_database.py - Args: - course_name: str - ids: list of str - """ - print("in delete_from_document_map()") - - try: - # check if project exists - response = SUPABASE_CLIENT.table("projects").select("doc_map_id").eq("course_name", course_name).execute() - if response.data: - project_id = response.data[0]['doc_map_id'] - else: - return "No document map found for this course" - - # fetch project from Nomic - project = AtlasProject(project_id=project_id, add_datums_if_exists=True) - - # delete the ids from Nomic - print("Deleting point from document map:", project.delete_data(ids)) - with project.wait_for_project_lock(): - project.rebuild_maps() - return "success" - except Exception as e: - print(e) - sentry_sdk.capture_exception(e) - return "Error in deleting from document map: {e}" - - -def log_to_document_map(course_name: str): - """ - This is a function which appends new documents to an existing document map. It's called - at the end of split_and_upload() after inserting data to Supabase. - Args: - data: dict - the response data from Supabase insertion - """ - print("in add_to_document_map()") - - try: - # check if map exists - response = SUPABASE_CLIENT.table("projects").select("doc_map_id, last_uploaded_doc_id").eq("course_name", course_name).execute() - if response.data: - if response.data[0]['doc_map_id']: - project_id = response.data[0]['doc_map_id'] - last_uploaded_doc_id = response.data[0]['last_uploaded_doc_id'] - else: - # entry present in supabase, but doc map not present - create_document_map(course_name) - return "Document map not present, triggering map creation." - - else: - # create a map - create_document_map(course_name) - return "Document map not present, triggering map creation." 
+# except Exception as e: +# print(e) +# sentry_sdk.capture_exception(e) +# return "failed" + +# def delete_from_document_map(course_name: str, ids: list): +# """ +# This function is used to delete datapoints from a document map. +# Currently used within the delete_data() function in vector_database.py +# Args: +# course_name: str +# ids: list of str +# """ +# print("in delete_from_document_map()") + +# try: +# # check if project exists +# response = SUPABASE_CLIENT.table("projects").select("doc_map_id").eq("course_name", course_name).execute() +# if response.data: +# project_id = response.data[0]['doc_map_id'] +# else: +# return "No document map found for this course" + +# # fetch project from Nomic +# project = AtlasProject(project_id=project_id, add_datums_if_exists=True) + +# # delete the ids from Nomic +# print("Deleting point from document map:", project.delete_data(ids)) +# with project.wait_for_project_lock(): +# project.rebuild_maps() +# return "success" +# except Exception as e: +# print(e) +# sentry_sdk.capture_exception(e) +# return "Error in deleting from document map: {e}" + + +# def log_to_document_map(course_name: str): +# """ +# This is a function which appends new documents to an existing document map. It's called +# at the end of split_and_upload() after inserting data to Supabase. +# Args: +# data: dict - the response data from Supabase insertion +# """ +# print("in add_to_document_map()") + +# try: +# # check if map exists +# response = SUPABASE_CLIENT.table("projects").select("doc_map_id, last_uploaded_doc_id").eq("course_name", course_name).execute() +# if response.data: +# if response.data[0]['doc_map_id']: +# project_id = response.data[0]['doc_map_id'] +# last_uploaded_doc_id = response.data[0]['last_uploaded_doc_id'] +# else: +# # entry present in supabase, but doc map not present +# create_document_map(course_name) +# return "Document map not present, triggering map creation." + +# else: +# # create a map +# create_document_map(course_name) +# return "Document map not present, triggering map creation." - project = AtlasProject(project_id=project_id, add_datums_if_exists=True) - project_name = "Document Map for " + course_name +# project = AtlasProject(project_id=project_id, add_datums_if_exists=True) +# project_name = "Document Map for " + course_name - # check if project is LOCKED, if yes -> skip logging - if not project.is_accepting_data: - return "Skipping Nomic logging because project is locked." +# # check if project is LOCKED, if yes -> skip logging +# if not project.is_accepting_data: +# return "Skipping Nomic logging because project is locked." 
- # fetch count of records greater than last_uploaded_doc_id - print("last uploaded doc id: ", last_uploaded_doc_id) - response = SUPABASE_CLIENT.table("documents").select("id", count="exact").eq("course_name", course_name).gt("id", last_uploaded_doc_id).execute() - print("Number of new documents: ", response.count) - - total_doc_count = response.count - current_doc_count = 0 - combined_dfs = [] - doc_count = 0 - first_id = last_uploaded_doc_id - while current_doc_count < total_doc_count: - # fetch all records from supabase greater than last_uploaded_doc_id - response = SUPABASE_CLIENT.table("documents").select("id, created_at, s3_path, url, base_url, readable_filename, contexts").eq("course_name", course_name).gt("id", first_id).limit(25).execute() - df = pd.DataFrame(response.data) - combined_dfs.append(df) # list of dfs - - current_doc_count += len(response.data) - doc_count += len(response.data) - - if doc_count >= 1000: # upload to Nomic in batches of 1000 - # concat all dfs from the combined_dfs list - final_df = pd.concat(combined_dfs, ignore_index=True) - # prep data for nomic upload - embeddings, metadata = data_prep_for_doc_map(final_df) - - # append to existing map - print("Appending data to existing map...") +# # fetch count of records greater than last_uploaded_doc_id +# print("last uploaded doc id: ", last_uploaded_doc_id) +# response = SUPABASE_CLIENT.table("documents").select("id", count="exact").eq("course_name", course_name).gt("id", last_uploaded_doc_id).execute() +# print("Number of new documents: ", response.count) + +# total_doc_count = response.count +# current_doc_count = 0 +# combined_dfs = [] +# doc_count = 0 +# first_id = last_uploaded_doc_id +# while current_doc_count < total_doc_count: +# # fetch all records from supabase greater than last_uploaded_doc_id +# response = SUPABASE_CLIENT.table("documents").select("id, created_at, s3_path, url, base_url, readable_filename, contexts").eq("course_name", course_name).gt("id", first_id).limit(25).execute() +# df = pd.DataFrame(response.data) +# combined_dfs.append(df) # list of dfs + +# current_doc_count += len(response.data) +# doc_count += len(response.data) + +# if doc_count >= 1000: # upload to Nomic in batches of 1000 +# # concat all dfs from the combined_dfs list +# final_df = pd.concat(combined_dfs, ignore_index=True) +# # prep data for nomic upload +# embeddings, metadata = data_prep_for_doc_map(final_df) + +# # append to existing map +# print("Appending data to existing map...") - result = append_to_map(embeddings, metadata, project_name) - if result == "success": - # update the last uploaded id in supabase - last_id = int(final_df['id'].iloc[-1]) - info = {'last_uploaded_doc_id': last_id} - update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) +# result = append_to_map(embeddings, metadata, project_name) +# if result == "success": +# # update the last uploaded id in supabase +# last_id = int(final_df['id'].iloc[-1]) +# info = {'last_uploaded_doc_id': last_id} +# update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() +# print("Response from supabase: ", update_response) - # reset variables - combined_dfs = [] - doc_count = 0 - print("Records uploaded: ", current_doc_count) +# # reset variables +# combined_dfs = [] +# doc_count = 0 +# print("Records uploaded: ", current_doc_count) - # set first_id for next iteration - first_id = response.data[-1]['id'] + 1 
+# # set first_id for next iteration +# first_id = response.data[-1]['id'] + 1 - # upload last set of docs - if doc_count > 0: - final_df = pd.concat(combined_dfs, ignore_index=True) - embeddings, metadata = data_prep_for_doc_map(final_df) - result = append_to_map(embeddings, metadata, project_name) - - # update the last uploaded id in supabase - if result == "success": - # update the last uploaded id in supabase - last_id = int(final_df['id'].iloc[-1]) - project_info = {'last_uploaded_doc_id': last_id} - update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() - print("Response from supabase: ", update_response) +# # upload last set of docs +# if doc_count > 0: +# final_df = pd.concat(combined_dfs, ignore_index=True) +# embeddings, metadata = data_prep_for_doc_map(final_df) +# result = append_to_map(embeddings, metadata, project_name) + +# # update the last uploaded id in supabase +# if result == "success": +# # update the last uploaded id in supabase +# last_id = int(final_df['id'].iloc[-1]) +# project_info = {'last_uploaded_doc_id': last_id} +# update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() +# print("Response from supabase: ", update_response) - return "success" - except Exception as e: - print(e) - return "failed" +# return "success" +# except Exception as e: +# print(e) +# return "failed" -def create_map(embeddings, metadata, map_name, index_name, topic_label_field, colorable_fields): - """ - Generic function to create a Nomic map from given parameters. - Args: - embeddings: np.array of embeddings - metadata: pd.DataFrame of metadata - map_name: str - index_name: str - topic_label_field: str - colorable_fields: list of str - """ - nomic.login(os.getenv('NOMIC_API_KEY')) - try: - project = atlas.map_embeddings(embeddings=embeddings, - data=metadata, - id_field="id", - build_topic_model=True, - topic_label_field=topic_label_field, - name=map_name, - colorable_fields=colorable_fields, - add_datums_if_exists=True) - project.create_index(name=index_name, build_topic_model=True) - return "success" - except Exception as e: - print(e) - return "Error in creating map: {e}" - -def append_to_map(embeddings, metadata, map_name): - """ - Generic function to append new data to an existing Nomic map. - Args: - embeddings: np.array of embeddings - metadata: pd.DataFrame of Nomic upload metadata - map_name: str - """ - - nomic.login(os.getenv('NOMIC_API_KEY')) - try: - project = atlas.AtlasProject(name=map_name, add_datums_if_exists=True) - with project.wait_for_project_lock(): - project.add_embeddings(embeddings=embeddings, data=metadata) - return "success" - except Exception as e: - print(e) - return "Error in appending to map: {e}" - -def data_prep_for_doc_map(df: pd.DataFrame): - """ - This function prepares embeddings and metadata for nomic upload in document map creation. 
- Args: - df: pd.DataFrame - the dataframe of documents from Supabase - Returns: - embeddings: np.array of embeddings - metadata: pd.DataFrame of metadata - """ - print("in data_prep_for_doc_map()") - - metadata = [] - embeddings = [] - - texts = [] - - for index, row in df.iterrows(): - current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - created_at = datetime.datetime.strptime(row['created_at'], "%Y-%m-%dT%H:%M:%S.%f%z").strftime("%Y-%m-%d %H:%M:%S") - if row['url'] == None: - row['url'] = "" - if row['base_url'] == None: - row['base_url'] = "" - # iterate through all contexts and create separate entries for each - context_count = 0 - for context in row['contexts']: - context_count += 1 - text_row = context['text'] - embeddings_row = context['embedding'] - - meta_row = { - "id": str(row['id']) + "_" + str(context_count), - "created_at": created_at, - "s3_path": row['s3_path'], - "url": row['url'], - "base_url": row['base_url'], - "readable_filename": row['readable_filename'], - "modified_at": current_time, - "text": text_row - } - - embeddings.append(embeddings_row) - metadata.append(meta_row) - texts.append(text_row) - - embeddings_np = np.array(embeddings, dtype=object) - print("Shape of embeddings: ", embeddings_np.shape) - - # check dimension if embeddings_np is (n, 1536) - if len(embeddings_np.shape) < 2: - print("Creating new embeddings...") - - embeddings_model = OpenAIEmbeddings(openai_api_type="openai", - openai_api_base="https://api.openai.com/v1/", - openai_api_key=os.getenv('VLADS_OPENAI_KEY')) # type: ignore - embeddings = embeddings_model.embed_documents(texts) - - metadata = pd.DataFrame(metadata) - embeddings = np.array(embeddings) - - return embeddings, metadata - -def rebuild_map(course_name:str, map_type:str): - """ - This function rebuilds a given map in Nomic. - """ - print("in rebuild_map()") - nomic.login(os.getenv('NOMIC_API_KEY')) - if map_type.lower() == 'document': - NOMIC_MAP_NAME_PREFIX = 'Document Map for ' - else: - NOMIC_MAP_NAME_PREFIX = 'Conversation Map for ' - - try: - # fetch project from Nomic - project_name = NOMIC_MAP_NAME_PREFIX + course_name - project = AtlasProject(name=project_name, add_datums_if_exists=True) - - if project.is_accepting_data: # temporary fix - will skip rebuilding if project is locked - project.rebuild_maps() - return "success" - except Exception as e: - print(e) - sentry_sdk.capture_exception(e) - return "Error in rebuilding map: {e}" - - - -if __name__ == '__main__': - pass +# def create_map(embeddings, metadata, map_name, index_name, topic_label_field, colorable_fields): +# """ +# Generic function to create a Nomic map from given parameters. +# Args: +# embeddings: np.array of embeddings +# metadata: pd.DataFrame of metadata +# map_name: str +# index_name: str +# topic_label_field: str +# colorable_fields: list of str +# """ +# nomic.login(os.getenv('NOMIC_API_KEY')) +# try: +# project = atlas.map_embeddings(embeddings=embeddings, +# data=metadata, +# id_field="id", +# build_topic_model=True, +# topic_label_field=topic_label_field, +# name=map_name, +# colorable_fields=colorable_fields, +# add_datums_if_exists=True) +# project.create_index(name=index_name, build_topic_model=True) +# return "success" +# except Exception as e: +# print(e) +# return "Error in creating map: {e}" + +# def append_to_map(embeddings, metadata, map_name): +# """ +# Generic function to append new data to an existing Nomic map. 
+# Args: +# embeddings: np.array of embeddings +# metadata: pd.DataFrame of Nomic upload metadata +# map_name: str +# """ + +# nomic.login(os.getenv('NOMIC_API_KEY')) +# try: +# project = atlas.AtlasProject(name=map_name, add_datums_if_exists=True) +# with project.wait_for_project_lock(): +# project.add_embeddings(embeddings=embeddings, data=metadata) +# return "success" +# except Exception as e: +# print(e) +# return "Error in appending to map: {e}" + +# def data_prep_for_doc_map(df: pd.DataFrame): +# """ +# This function prepares embeddings and metadata for nomic upload in document map creation. +# Args: +# df: pd.DataFrame - the dataframe of documents from Supabase +# Returns: +# embeddings: np.array of embeddings +# metadata: pd.DataFrame of metadata +# """ +# print("in data_prep_for_doc_map()") + +# metadata = [] +# embeddings = [] + +# texts = [] + +# for index, row in df.iterrows(): +# current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") +# created_at = datetime.datetime.strptime(row['created_at'], "%Y-%m-%dT%H:%M:%S.%f%z").strftime("%Y-%m-%d %H:%M:%S") +# if row['url'] == None: +# row['url'] = "" +# if row['base_url'] == None: +# row['base_url'] = "" +# # iterate through all contexts and create separate entries for each +# context_count = 0 +# for context in row['contexts']: +# context_count += 1 +# text_row = context['text'] +# embeddings_row = context['embedding'] + +# meta_row = { +# "id": str(row['id']) + "_" + str(context_count), +# "created_at": created_at, +# "s3_path": row['s3_path'], +# "url": row['url'], +# "base_url": row['base_url'], +# "readable_filename": row['readable_filename'], +# "modified_at": current_time, +# "text": text_row +# } + +# embeddings.append(embeddings_row) +# metadata.append(meta_row) +# texts.append(text_row) + +# embeddings_np = np.array(embeddings, dtype=object) +# print("Shape of embeddings: ", embeddings_np.shape) + +# # check dimension if embeddings_np is (n, 1536) +# if len(embeddings_np.shape) < 2: +# print("Creating new embeddings...") + +# embeddings_model = OpenAIEmbeddings(openai_api_type="openai", +# openai_api_base="https://api.openai.com/v1/", +# openai_api_key=os.getenv('VLADS_OPENAI_KEY')) # type: ignore +# embeddings = embeddings_model.embed_documents(texts) + +# metadata = pd.DataFrame(metadata) +# embeddings = np.array(embeddings) + +# return embeddings, metadata + +# def rebuild_map(course_name:str, map_type:str): +# """ +# This function rebuilds a given map in Nomic. 
+# """ +# print("in rebuild_map()") +# nomic.login(os.getenv('NOMIC_API_KEY')) +# if map_type.lower() == 'document': +# NOMIC_MAP_NAME_PREFIX = 'Document Map for ' +# else: +# NOMIC_MAP_NAME_PREFIX = 'Conversation Map for ' + +# try: +# # fetch project from Nomic +# project_name = NOMIC_MAP_NAME_PREFIX + course_name +# project = AtlasProject(name=project_name, add_datums_if_exists=True) + +# if project.is_accepting_data: # temporary fix - will skip rebuilding if project is locked +# project.rebuild_maps() +# return "success" +# except Exception as e: +# print(e) +# sentry_sdk.capture_exception(e) +# return "Error in rebuilding map: {e}" + + + +# if __name__ == '__main__': +# pass diff --git a/ai_ta_backend/database/database_impl/storage/aws.py b/ai_ta_backend/database/aws.py similarity index 91% rename from ai_ta_backend/database/database_impl/storage/aws.py rename to ai_ta_backend/database/aws.py index a7042116..1e6f397d 100644 --- a/ai_ta_backend/database/database_impl/storage/aws.py +++ b/ai_ta_backend/database/aws.py @@ -3,10 +3,8 @@ import boto3 from injector import inject -from ai_ta_backend.database.base_storage import BaseStorageDatabase - -class AWSStorage(BaseStorageDatabase): +class AWSStorage(): @inject def __init__(self): diff --git a/ai_ta_backend/database/base_sql.py b/ai_ta_backend/database/base_sql.py deleted file mode 100644 index ce808e1a..00000000 --- a/ai_ta_backend/database/base_sql.py +++ /dev/null @@ -1,69 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, Tuple -from postgrest.base_request_builder import APIResponse - -class BaseSQLDatabase(ABC): - - @abstractmethod - def getAllMaterialsForCourse(self, course_name: str) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def getMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def getMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def deleteMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): - pass - @abstractmethod - def deleteMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): - pass - @abstractmethod - def getProjectsMapForCourse(self, course_name: str) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def getDocumentsBetweenDates(self, course_name: str, from_date: str, to_date: str, table_name: str) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def getAllFromTableForDownloadType(self, course_name: str, download_type: str, first_id: int) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def getAllConversationsBetweenIds(self, course_name: str, first_id: int, last_id: int, limit: int = 50) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def getDocsForIdsGte(self, course_name: str, first_id: int, fields: str = "*", limit: int = 100) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def insertProjectInfo(self, project_info) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def getAllFromLLMConvoMonitor(self, course_name: str) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def getCountFromLLMConvoMonitor(self, course_name: str, last_id: int) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def getDocMapFromProjects(self, course_name: str) -> APIResponse[Tuple[Dict[str, 
Any], int]]: - pass - @abstractmethod - def getConvoMapFromProjects(self, course_name: str) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def updateProjects(self, course_name: str, data: dict) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def getLatestWorkflowId(self) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def lockWorkflow(self, id: str) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def deleteLatestWorkflowId(self, id: str) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def unlockWorkflow(self, id: str) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass - @abstractmethod - def getConversation(self, course_name: str, key: str, value: str) -> APIResponse[Tuple[Dict[str, Any], int]]: - pass \ No newline at end of file diff --git a/ai_ta_backend/database/base_storage.py b/ai_ta_backend/database/base_storage.py deleted file mode 100644 index 733b387d..00000000 --- a/ai_ta_backend/database/base_storage.py +++ /dev/null @@ -1,15 +0,0 @@ -from abc import ABC, abstractmethod - -class BaseStorageDatabase(ABC): - @abstractmethod - def upload_file(self, file_path: str, bucket_name: str, object_name: str): - pass - @abstractmethod - def download_file(self, object_name: str, bucket_name: str, file_path: str): - pass - @abstractmethod - def delete_file(self, bucket_name: str, s3_path: str): - pass - @abstractmethod - def generatePresignedUrl(self, object: str, bucket_name: str, s3_path: str, expiration: int = 3600): - pass \ No newline at end of file diff --git a/ai_ta_backend/database/base_vector.py b/ai_ta_backend/database/base_vector.py deleted file mode 100644 index 551fd328..00000000 --- a/ai_ta_backend/database/base_vector.py +++ /dev/null @@ -1,12 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List - -class BaseVectorDatabase(ABC): - - @abstractmethod - def vector_search(self, search_query, course_name, doc_groups: List[str], user_query_embedding, top_n): - pass - - @abstractmethod - def delete_data(self, collection_name: str, key: str, value: str): - pass diff --git a/ai_ta_backend/database/database_impl/__init__.py b/ai_ta_backend/database/database_impl/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ai_ta_backend/database/database_impl/vector/qdrant.py b/ai_ta_backend/database/qdrant.py similarity index 96% rename from ai_ta_backend/database/database_impl/vector/qdrant.py rename to ai_ta_backend/database/qdrant.py index b6cc6bd5..725cb8fc 100644 --- a/ai_ta_backend/database/database_impl/vector/qdrant.py +++ b/ai_ta_backend/database/qdrant.py @@ -6,12 +6,11 @@ from langchain.vectorstores import Qdrant from qdrant_client import QdrantClient, models -from ai_ta_backend.database.base_vector import BaseVectorDatabase OPENAI_API_TYPE = "azure" # "openai" or "azure" -class VectorDatabase(BaseVectorDatabase): +class VectorDatabase(): """ Contains all methods for building and using vector databases. 
""" diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/sql.py new file mode 100644 index 00000000..f1474a9f --- /dev/null +++ b/ai_ta_backend/database/sql.py @@ -0,0 +1,230 @@ +from typing import List +from injector import inject +from flask_sqlalchemy import SQLAlchemy +import ai_ta_backend.model.models as models +import logging + +from ai_ta_backend.model.response import DatabaseResponse + +class SQLAlchemyDatabase: + + @inject + def __init__(self, db: SQLAlchemy): + logging.info("Initializing SQLAlchemyDatabase") + self.db = db + # Ensure an app context is pushed (Flask-Injector will handle this) + # with current_app.app_context(): + # db.init_app(current_app) + # db.create_all() # Create tables + + def getAllMaterialsForCourse(self, course_name: str): + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name, models.Document.s3_path == s3_path) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name, getattr(models.Document, key) == value) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def deleteMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): + try: + query = self.db.delete(models.Document).where(models.Document.course_name == course_name, getattr(models.Document, key) == value) + self.db.session.execute(query) + self.db.session.commit() + finally: + self.db.session.close() + + def deleteMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): + try: + query = self.db.delete(models.Document).where(models.Document.course_name == course_name, models.Document.s3_path == s3_path) + self.db.session.execute(query) + self.db.session.commit() + finally: + self.db.session.close() + + def getProjectsMapForCourse(self, course_name: str): + try: + query = self.db.select(models.Project.doc_map_id).where(models.Project.course_name == course_name) + result = self.db.session.execute(query).scalars().all() + projects: List[models.Project] = [doc for doc in result] + return DatabaseResponse[models.Project](data=projects, count=len(result)) + finally: + self.db.session.close() + + def getDocumentsBetweenDates(self, course_name: str, from_date: str, to_date: str): + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name) + if from_date: + query = query.filter(models.Document.created_at >= from_date) + if to_date: + query = query.filter(models.Document.created_at <= to_date) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return 
DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getConversationsBetweenDates(self, course_name: str, from_date: str, to_date: str): + try: + query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name) + if from_date: + query = query.filter(models.LlmConvoMonitor.created_at >= from_date) + if to_date: + query = query.filter(models.LlmConvoMonitor.created_at <= to_date) + result = self.db.session.execute(query).scalars().all() + documents: List[models.LlmConvoMonitor] = [doc for doc in result] + return DatabaseResponse[models.LlmConvoMonitor](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getAllDocumentsForDownload(self, course_name: str, first_id: int): + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name, models.Document.id >= first_id) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getAllConversationsForDownload(self, course_name: str, first_id: int): + try: + query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name, models.LlmConvoMonitor.id >= first_id) + result = self.db.session.execute(query).scalars().all() + conversations: List[models.LlmConvoMonitor] = [doc for doc in result] + return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result)) + finally: + self.db.session.close() + + def getAllConversationsBetweenIds(self, course_name: str, first_id: int, last_id: int, limit: int = 50): + try: + query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name, models.LlmConvoMonitor.id > first_id) + if last_id != 0: + query = query.filter(models.LlmConvoMonitor.id <= last_id) + query = query.limit(limit) + result = self.db.session.execute(query).scalars().all() + conversations: List[models.LlmConvoMonitor] = [doc for doc in result] + return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result)) + finally: + self.db.session.close() + + def getDocsForIdsGte(self, course_name: str, first_id: int, fields: str = "*", limit: int = 100): + try: + fields_to_select = [getattr(models.Document, field) for field in fields.split(", ")] + query = self.db.select(*fields_to_select).where(models.Document.course_name == course_name, models.Document.id >= first_id).limit(limit) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def insertProjectInfo(self, project_info): + try: + self.db.session.execute(self.db.insert(models.Project).values(**project_info)) + self.db.session.commit() + finally: + self.db.session.close() + + def getAllFromLLMConvoMonitor(self, course_name: str): + try: + query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name) + result = self.db.session.execute(query).scalars().all() + conversations: List[models.LlmConvoMonitor] = [doc for doc in result] + return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result)) + finally: + self.db.session.close() + + def getCountFromLLMConvoMonitor(self, course_name: str, last_id: int): + 
try:
+      query = self.db.select(models.LlmConvoMonitor.id).where(models.LlmConvoMonitor.course_name == course_name)
+      if last_id != 0:
+        query = query.filter(models.LlmConvoMonitor.id > last_id)
+      count_query = self.db.select(self.db.func.count()).select_from(query.subquery())
+      count = self.db.session.execute(count_query).scalar()
+      return DatabaseResponse[models.LlmConvoMonitor](data=[], count=count or 0)
+    finally:
+      self.db.session.close()
+
+  def getDocMapFromProjects(self, course_name: str):
+    try:
+      query = self.db.select(models.Project.doc_map_id).where(models.Project.course_name == course_name)
+      result = self.db.session.execute(query).scalars().all()
+      documents: List[models.Project] = [doc for doc in result]
+      return DatabaseResponse[models.Project](data=documents, count=len(result))
+    finally:
+      self.db.session.close()
+
+  def getConvoMapFromProjects(self, course_name: str):
+    try:
+      query = self.db.select(models.Project).where(models.Project.course_name == course_name)
+      result = self.db.session.execute(query).scalars().all()
+      conversations: List[models.Project] = [doc for doc in result]
+      return DatabaseResponse[models.Project](data=conversations, count=len(result))
+    finally:
+      self.db.session.close()
+
+  def updateProjects(self, course_name: str, data: dict):
+    try:
+      query = self.db.update(models.Project).where(models.Project.course_name == course_name).values(**data)
+      self.db.session.execute(query)
+      self.db.session.commit()
+    finally:
+      self.db.session.close()
+
+  def getLatestWorkflowId(self):
+    try:
+      query = self.db.select(models.N8nWorkflows.latest_workflow_id)
+      result = self.db.session.execute(query).fetchone()
+      return result
+    finally:
+      self.db.session.close()
+
+  def lockWorkflow(self, id: str):
+    try:
+      new_workflow = models.N8nWorkflows(is_locked=True)
+      self.db.session.add(new_workflow)
+      self.db.session.commit()
+    finally:
+      self.db.session.close()
+
+  def deleteLatestWorkflowId(self, id: str):
+    try:
+      query = self.db.delete(models.N8nWorkflows).where(models.N8nWorkflows.latest_workflow_id == id)
+      self.db.session.execute(query)
+      self.db.session.commit()
+    finally:
+      self.db.session.close()
+
+  def unlockWorkflow(self, id: str):
+    try:
+      query = self.db.update(models.N8nWorkflows).where(models.N8nWorkflows.latest_workflow_id == id).values(is_locked=False)
+      self.db.session.execute(query)
+      self.db.session.commit()
+    finally:
+      self.db.session.close()
+
+  def getConversation(self, course_name: str, key: str, value: str):
+    try:
+      query = self.db.select(models.LlmConvoMonitor).where(getattr(models.LlmConvoMonitor, key) == value)
+      result = self.db.session.execute(query).scalars().all()
+      conversations: List[models.LlmConvoMonitor] = [doc for doc in result]
+      return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result))
+    finally:
+      self.db.session.close()
\ No newline at end of file
diff --git a/ai_ta_backend/database/database_impl/sql/supabase.py b/ai_ta_backend/database/supabase.py
similarity index 98%
rename from ai_ta_backend/database/database_impl/sql/supabase.py
rename to ai_ta_backend/database/supabase.py
index a20ff756..b0bbbdb8 100644
--- a/ai_ta_backend/database/database_impl/sql/supabase.py
+++ b/ai_ta_backend/database/supabase.py
@@ -3,10 +3,7 @@
 import supabase
 from injector import inject
 
-from ai_ta_backend.database.base_sql import BaseSQLDatabase
-
-
-class SQLDatabase(BaseSQLDatabase):
+class SQLDatabase():
 
   @inject
   def __init__(self):
diff --git a/ai_ta_backend/extensions.py b/ai_ta_backend/extensions.py
new file mode 100644
index 00000000..589c64fc --- /dev/null +++ b/ai_ta_backend/extensions.py @@ -0,0 +1,2 @@ +from flask_sqlalchemy import SQLAlchemy +db = SQLAlchemy() \ No newline at end of file diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 82901304..6c3237da 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -1,4 +1,5 @@ import json +import logging import os import time from typing import List @@ -17,9 +18,9 @@ from flask_executor import Executor from flask_injector import FlaskInjector, RequestScope from injector import Binder, SingletonScope -from ai_ta_backend.database.base_sql import BaseSQLDatabase -from ai_ta_backend.database.base_storage import BaseStorageDatabase -from ai_ta_backend.database.base_vector import BaseVectorDatabase +from ai_ta_backend.database.sql import SQLAlchemyDatabase +from ai_ta_backend.database.aws import AWSStorage +from ai_ta_backend.database.qdrant import VectorDatabase from ai_ta_backend.executors.flask_executor import ( ExecutorInterface, @@ -39,8 +40,9 @@ from ai_ta_backend.service.retrieval_service import RetrievalService from ai_ta_backend.service.sentry_service import SentryService -from ai_ta_backend.beam.nomic_logging import create_document_map +# from ai_ta_backend.beam.nomic_logging import create_document_map from ai_ta_backend.service.workflow_service import WorkflowService +from ai_ta_backend.extensions import db app = Flask(__name__) CORS(app) @@ -51,7 +53,6 @@ # load API keys from globally-availabe .env file load_dotenv() - @app.route('/') def index() -> Response: """_summary_ @@ -191,19 +192,19 @@ def nomic_map(service: NomicService): return response -@app.route('/createDocumentMap', methods=['GET']) -def createDocumentMap(service: NomicService): - course_name: str = request.args.get('course_name', default='', type=str) +# @app.route('/createDocumentMap', methods=['GET']) +# def createDocumentMap(service: NomicService): +# course_name: str = request.args.get('course_name', default='', type=str) - if course_name == '': - # proper web error "400 Bad request" - abort(400, description=f"Missing required parameter: 'course_name' must be provided. Course name: `{course_name}`") +# if course_name == '': +# # proper web error "400 Bad request" +# abort(400, description=f"Missing required parameter: 'course_name' must be provided. 
Course name: `{course_name}`") - map_id = create_document_map(course_name) +# map_id = create_document_map(course_name) - response = jsonify(map_id) - response.headers.add('Access-Control-Allow-Origin', '*') - return response +# response = jsonify(map_id) +# response.headers.add('Access-Control-Allow-Origin', '*') +# return response @app.route('/createConversationMap', methods=['GET']) def createConversationMap(service: NomicService): @@ -476,44 +477,68 @@ def configure(binder: Binder) -> None: vector_bound = False sql_bound = False storage_bound = False + + # Define database URLs with conditional checks for environment variables + DB_URLS = { + 'supabase': f"postgresql://{os.getenv('SUPABASE_KEY')}@{os.getenv('SUPABASE_URL')}" if os.getenv('SUPABASE_KEY') and os.getenv('SUPABASE_URL') else None, + 'sqlite': f"sqlite:///{os.getenv('SQLITE_DB_NAME')}" if os.getenv('SQLITE_DB_NAME') else None, + 'postgres': f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@{os.getenv('POSTGRES_URL')}" if os.getenv('POSTGRES_USER') and os.getenv('POSTGRES_PASSWORD') and os.getenv('POSTGRES_URL') else None + } + + # Bind to the first available SQL database configuration + for db_type, url in DB_URLS.items(): + if url: + logging.info(f"Binding to {db_type} database with URL: {url}") + with app.app_context(): + app.config['SQLALCHEMY_DATABASE_URI'] = url + db.init_app(app) + db.create_all() + binder.bind(SQLAlchemyDatabase, to=db, scope=SingletonScope) + sql_bound = True + break # Conditionally bind databases based on the availability of their respective secrets - if any(os.getenv(key) for key in ["QDRANT_URL", "QDRANT_API_KEY", "QDRANT_COLLECTION_NAME"]) or any(os.getenv(key) for key in ["PINECONE_API_KEY", "PINECONE_PROJECT_NAME"]): - binder.bind(BaseVectorDatabase, to=BaseVectorDatabase, scope=SingletonScope) + if all(os.getenv(key) for key in ["QDRANT_URL", "QDRANT_API_KEY", "QDRANT_COLLECTION_NAME"]) or any(os.getenv(key) for key in ["PINECONE_API_KEY", "PINECONE_PROJECT_NAME"]): + logging.info("Binding to Qdrant database") + binder.bind(VectorDatabase, to=VectorDatabase, scope=SingletonScope) vector_bound = True - - if any(os.getenv(key) for key in ["SUPABASE_URL", "SUPABASE_API_KEY", "SUPABASE_DOCUMENTS_TABLE"]) or any(["SQLITE_DB_PATH", "SQLITE_DB_NAME", "SQLITE_DOCUMENTS_TABLE"]): - binder.bind(BaseSQLDatabase, to=BaseSQLDatabase, scope=SingletonScope) - sql_bound = True - if any(os.getenv(key) for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]) or any(os.getenv(key) for key in ["MINIO_ACCESS_KEY", "MINIO_SECRET_KEY", "MINIO_URL"]): - binder.bind(BaseStorageDatabase, to=BaseStorageDatabase, scope=SingletonScope) + if all(os.getenv(key) for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "S3_BUCKET_NAME"]) or any(os.getenv(key) for key in ["MINIO_ACCESS_KEY", "MINIO_SECRET_KEY", "MINIO_URL"]): + logging.info("Binding to AWS S3 storage") + binder.bind(AWSStorage, to=AWSStorage, scope=SingletonScope) storage_bound = True # Conditionally bind services based on the availability of their respective secrets if os.getenv("NOMIC_API_KEY"): + logging.info("Binding to Nomic service") binder.bind(NomicService, to=NomicService, scope=SingletonScope) if os.getenv("POSTHOG_API_KEY"): + logging.info("Binding to Posthog service") binder.bind(PosthogService, to=PosthogService, scope=SingletonScope) if os.getenv("SENTRY_DSN"): + logging.info("Binding to Sentry service") binder.bind(SentryService, to=SentryService, scope=SingletonScope) if os.getenv("EMAIL_SENDER"): + 
logging.info("Binding to Export service") binder.bind(ExportService, to=ExportService, scope=SingletonScope) if os.getenv("N8N_URL"): + logging.info("Binding to Workflow service") binder.bind(WorkflowService, to=WorkflowService, scope=SingletonScope) if vector_bound and sql_bound and storage_bound: + logging.info("Binding to Retrieval service") binder.bind(RetrievalService, to=RetrievalService, scope=RequestScope) # Always bind the executor and its adapters binder.bind(ExecutorInterface, to=FlaskExecutorAdapter(executor), scope=SingletonScope) binder.bind(ThreadPoolExecutorInterface, to=ThreadPoolExecutorAdapter, scope=SingletonScope) binder.bind(ProcessPoolExecutorInterface, to=ProcessPoolExecutorAdapter, scope=SingletonScope) + logging.info("Configured all services and adapters", binder._bindings) FlaskInjector(app=app, modules=[configure]) diff --git a/ai_ta_backend/model/models.py b/ai_ta_backend/model/models.py new file mode 100644 index 00000000..6b35ab44 --- /dev/null +++ b/ai_ta_backend/model/models.py @@ -0,0 +1,89 @@ +from sqlalchemy import Column, BigInteger, Text, DateTime, Boolean, ForeignKey, Index, JSON +from sqlalchemy.sql import func +from ai_ta_backend.extensions import db + +class Base(db.Model): + __abstract__ = True + +class Document(Base): + __tablename__ = 'documents' + id = Column(BigInteger, primary_key=True, autoincrement=True) + created_at = Column(DateTime, default=func.now()) + s3_path = Column(Text) + readable_filename = Column(Text) + course_name = Column(Text) + url = Column(Text) + contexts = Column(JSON, default=lambda: [ + { + "text": "", + "timestamp": "", + "embedding": "", + "pagenumber": "" + } + ]) + base_url = Column(Text) + + __table_args__ = ( + Index('documents_course_name_idx', 'course_name', postgresql_using='hash'), + Index('documents_created_at_idx', 'created_at', postgresql_using='btree'), + Index('idx_doc_s3_path', 's3_path', postgresql_using='btree'), + ) + +class DocumentDocGroup(Base): + __tablename__ = 'documents_doc_groups' + document_id = Column(BigInteger, primary_key=True) + doc_group_id = Column(BigInteger, ForeignKey('doc_groups.id', ondelete='CASCADE'), primary_key=True) + created_at = Column(DateTime, default=func.now()) + + __table_args__ = ( + Index('documents_doc_groups_doc_group_id_idx', 'doc_group_id', postgresql_using='btree'), + Index('documents_doc_groups_document_id_idx', 'document_id', postgresql_using='btree'), + ) + +class DocGroup(Base): + __tablename__ = 'doc_groups' + id = Column(BigInteger, primary_key=True, autoincrement=True) + name = Column(Text, nullable=False) + course_name = Column(Text, nullable=False) + created_at = Column(DateTime, default=func.now()) + enabled = Column(Boolean, default=True) + private = Column(Boolean, default=True) + doc_count = Column(BigInteger) + + __table_args__ = ( + Index('doc_groups_enabled_course_name_idx', 'enabled', 'course_name', postgresql_using='btree'), + ) + +class Project(Base): + __tablename__ = 'projects' + id = Column(BigInteger, primary_key=True, autoincrement=True) + created_at = Column(DateTime, default=func.now()) + course_name = Column(Text) + doc_map_id = Column(Text) + convo_map_id = Column(Text) + n8n_api_key = Column(Text) + last_uploaded_doc_id = Column(BigInteger) + last_uploaded_convo_id = Column(BigInteger) + subscribed = Column(BigInteger, ForeignKey('doc_groups.id', onupdate='CASCADE', ondelete='SET NULL')) + +class N8nWorkflows(Base): + __tablename__ = 'n8n_workflows' + latest_workflow_id = Column(BigInteger, primary_key=True, autoincrement=True) + 
is_locked = Column(Boolean, nullable=False) + + def __init__(self, is_locked: bool): + self.is_locked = is_locked + +class LlmConvoMonitor(Base): + __tablename__ = 'llm_convo_monitor' + id = Column(BigInteger, primary_key=True, autoincrement=True) + created_at = Column(DateTime, default=func.now()) + convo = Column(JSON) + convo_id = Column(Text, unique=True) + course_name = Column(Text) + user_email = Column(Text) + + __table_args__ = ( + Index('llm_convo_monitor_course_name_idx', 'course_name', postgresql_using='hash'), + Index('llm_convo_monitor_convo_id_idx', 'convo_id', postgresql_using='hash'), + ) \ No newline at end of file diff --git a/ai_ta_backend/model/response.py b/ai_ta_backend/model/response.py new file mode 100644 index 00000000..2263d8e3 --- /dev/null +++ b/ai_ta_backend/model/response.py @@ -0,0 +1,9 @@ +from typing import List, TypeVar, Generic +from flask_sqlalchemy.model import Model + +T = TypeVar('T', bound=Model) + +class DatabaseResponse(Generic[T]): + def __init__(self, data: List[T], count: int): + self.data = data + self.count = count \ No newline at end of file diff --git a/ai_ta_backend/service/export_service.py b/ai_ta_backend/service/export_service.py index 15c33ccd..30a959c6 100644 --- a/ai_ta_backend/service/export_service.py +++ b/ai_ta_backend/service/export_service.py @@ -8,16 +8,17 @@ import requests from injector import inject -from ai_ta_backend.database.database_impl.storage.aws import AWSStorage -from ai_ta_backend.database.database_impl.sql.supabase import SQLDatabase +from ai_ta_backend.database.aws import AWSStorage +from ai_ta_backend.database.sql import SQLAlchemyDatabase from ai_ta_backend.service.sentry_service import SentryService from ai_ta_backend.utils.emails import send_email +from ai_ta_backend.extensions import db class ExportService: @inject - def __init__(self, sql: SQLDatabase, s3: AWSStorage, sentry: SentryService): + def __init__(self, sql: SQLAlchemyDatabase, s3: AWSStorage, sentry: SentryService): self.sql = sql self.s3 = s3 self.sentry = sentry @@ -33,7 +34,7 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''): to_date (str, optional): The end date for the data export. Defaults to ''. 
""" - response = self.sql.getDocumentsBetweenDates(course_name, from_date, to_date, 'documents') + response = self.sql.getDocumentsBetweenDates(course_name, from_date, to_date) # add a condition to route to direct download or s3 download if response.count and response.count > 500: @@ -51,8 +52,8 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''): if response.count and response.count > 0: # batch download total_doc_count = response.count - first_id = response.data[0]['id'] - last_id = response.data[-1]['id'] + first_id = int(str(response.data[0].id)) + last_id = int(str(response.data[-1].id)) print("total_doc_count: ", total_doc_count) print("first_id: ", first_id) @@ -76,7 +77,7 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''): df.to_json(file_path, orient='records', lines=True, mode='a') if len(response.data) > 0: - first_id = response.data[-1]['id'] + 1 + first_id = int(str(response.data[-1].id)) + 1 # Download file try: @@ -106,7 +107,7 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): """ print("Exporting conversation history to json file...") - response = self.sql.getDocumentsBetweenDates(course_name, from_date, to_date, 'llm-convo-monitor') + response = self.sql.getConversationsBetweenDates(course_name, from_date, to_date) if response.count > 500: # call background task to upload to s3 @@ -120,8 +121,8 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): # Fetch data if response.count > 0: print("id count greater than zero") - first_id = response.data[0]['id'] - last_id = response.data[-1]['id'] + first_id = int(str(response.data[0].id)) + last_id = int(str(response.data[-1].id)) total_count = response.count filename = course_name + '_' + str(uuid.uuid4()) + '_convo_history.jsonl' @@ -143,7 +144,7 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): # Update first_id if len(response.data) > 0: - first_id = response.data[-1]['id'] + 1 + first_id = int(str(response.data[-1].id)) + 1 print("updated first_id: ", first_id) # Download file @@ -170,7 +171,7 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e """ print("Exporting conversation history to json file...") - response = self.sql.getDocumentsBetweenDates(course_name, from_date, to_date, 'llm-convo-monitor') + response = self.sql.getConversationsBetweenDates(course_name, from_date, to_date) if response.count > 500: # call background task to upload to s3 @@ -184,8 +185,8 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e # Fetch data if response.count > 0: print("id count greater than zero") - first_id = response.data[0]['id'] - last_id = response.data[-1]['id'] + first_id = int(str(response.data[0].id)) + last_id = int(str(response.data[-1].id)) total_count = response.count filename = course_name + '_' + str(uuid.uuid4()) + '_convo_history.jsonl' @@ -207,7 +208,7 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e # Update first_id if len(response.data) > 0: - first_id = response.data[-1]['id'] + 1 + first_id = int(str(response.data[-1].id)) + 1 print("updated first_id: ", first_id) # Download file @@ -245,7 +246,7 @@ def export_data_in_bg(response, download_type, course_name, s3_path): s3_path (str): The S3 path where the file will be uploaded. 
""" s3 = AWSStorage() - sql = SQLDatabase() + sql = SQLAlchemyDatabase(db) total_doc_count = response.count first_id = response.data[0]['id'] @@ -259,7 +260,10 @@ def export_data_in_bg(response, download_type, course_name, s3_path): # download data in batches of 100 while curr_doc_count < total_doc_count: print("Fetching data from id: ", first_id) - response = sql.getAllFromTableForDownloadType(course_name, download_type, first_id) + if download_type == "documents": + response = sql.getAllDocumentsForDownload(course_name, first_id) + else: + response = sql.getAllConversationsForDownload(course_name, first_id) df = pd.DataFrame(response.data) curr_doc_count += len(response.data) @@ -270,7 +274,7 @@ def export_data_in_bg(response, download_type, course_name, s3_path): df.to_json(file_path, orient='records', lines=True, mode='a') if len(response.data) > 0: - first_id = response.data[-1]['id'] + 1 + first_id = int(str(response.data[-1].id)) + 1 # zip file zip_filename = filename.split('.')[0] + '.zip' @@ -354,7 +358,7 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai s3_path (str): The S3 path where the file will be uploaded. """ s3 = AWSStorage() - sql = SQLDatabase() + sql = SQLAlchemyDatabase(db) total_doc_count = response.count first_id = response.data[0]['id'] @@ -368,7 +372,10 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai # download data in batches of 100 while curr_doc_count < total_doc_count: print("Fetching data from id: ", first_id) - response = sql.getAllFromTableForDownloadType(course_name, download_type, first_id) + if download_type == "documents": + response = sql.getAllDocumentsForDownload(course_name, first_id) + else: + response = sql.getAllConversationsForDownload(course_name, first_id) df = pd.DataFrame(response.data) curr_doc_count += len(response.data) @@ -379,7 +386,7 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai df.to_json(file_path, orient='records', lines=True, mode='a') if len(response.data) > 0: - first_id = response.data[-1]['id'] + 1 + first_id = int(str(response.data[-1].id)) + 1 # zip file zip_filename = filename.split('.')[0] + '.zip' diff --git a/ai_ta_backend/service/nomic_service.py b/ai_ta_backend/service/nomic_service.py index a0dedd45..1900279c 100644 --- a/ai_ta_backend/service/nomic_service.py +++ b/ai_ta_backend/service/nomic_service.py @@ -10,8 +10,8 @@ from injector import inject from langchain.embeddings.openai import OpenAIEmbeddings from nomic import AtlasProject, atlas +from ai_ta_backend.database.sql import SQLAlchemyDatabase -from ai_ta_backend.database.database_impl.sql.supabase import SQLDatabase from ai_ta_backend.service.sentry_service import SentryService LOCK_EXCEPTIONS = [ @@ -24,7 +24,7 @@ class NomicService(): @inject - def __init__(self, sentry: SentryService, sql: SQLDatabase): + def __init__(self, sentry: SentryService, sql: SQLAlchemyDatabase): nomic.login(os.environ['NOMIC_API_KEY']) self.sentry = sentry self.sql = sql @@ -84,12 +84,12 @@ def log_to_conversation_map(self, course_name: str, conversation): return self.create_conversation_map(course_name) # entry present for doc map, but not convo map - elif not response.data[0]['convo_map_id']: + elif not response.data[0].convo_map_id is None: print("Map does not exist for this course. 
Redirecting to map creation...") return self.create_conversation_map(course_name) - project_id = response.data[0]['convo_map_id'] - last_uploaded_convo_id = response.data[0]['last_uploaded_convo_id'] + project_id = response.data[0].convo_map_id + last_uploaded_convo_id: int = int(str(response.data[0].last_uploaded_convo_id)) # check if project is accepting data project = AtlasProject(project_id=project_id, add_datums_if_exists=True) @@ -141,7 +141,7 @@ def log_to_conversation_map(self, course_name: str, conversation): print("Records uploaded: ", current_convo_count) # set first_id for next iteration - first_id = response.data[-1]['id'] + 1 + first_id = int(str(response.data[-1].id)) + 1 # upload last set of convos if convo_count > 0: @@ -180,7 +180,7 @@ def log_to_existing_conversation(self, course_name: str, conversation): project_name = 'Conversation Map for ' + course_name project = AtlasProject(name=project_name, add_datums_if_exists=True) - prev_id = incoming_id_response.data[0]['id'] + prev_id = str(incoming_id_response.data[0].id) uploaded_data = project.get_data(ids=[prev_id]) # fetch data point from nomic prev_convo = uploaded_data[0]['conversation'] @@ -248,7 +248,7 @@ def create_conversation_map(self, course_name: str): response = self.sql.getConvoMapFromProjects(course_name) print("Response from supabase: ", response.data) if response.data: - if response.data[0]['convo_map_id']: + if response.data[0].convo_map_id is not None: return "Map already exists for this course." # if no, fetch total count of records @@ -264,7 +264,7 @@ def create_conversation_map(self, course_name: str): total_convo_count = response.count print("Total number of conversations in Supabase: ", total_convo_count) - first_id = response.data[0]['id'] - 1 + first_id = int(str(response.data[0].id)) - 1 combined_dfs = [] current_convo_count = 0 convo_count = 0 @@ -332,10 +332,10 @@ def create_conversation_map(self, course_name: str): # set first_id for next iteration try: - print("response: ", response.data[-1]['id']) + print("response: ", response.data[-1].id) except: print("response: ", response.data) - first_id = response.data[-1]['id'] + 1 + first_id = int(str(response.data[-1].id)) + 1 print("Convo count: ", convo_count) # upload last set of convos @@ -481,6 +481,7 @@ def data_prep_for_convo_map(self, df: pd.DataFrame): user_email = row['user_email'] messages = row['convo']['messages'] + first_message = "" # some conversations include images, so the data structure is different if isinstance(messages[0]['content'], list): @@ -493,6 +494,7 @@ def data_prep_for_convo_map(self, df: pd.DataFrame): # construct metadata for multi-turn conversation for message in messages: + text = "" if message['role'] == 'user': emoji = "🙋 " else: diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index 5ab9a0ea..beb87a89 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -9,9 +9,9 @@ from langchain.chat_models import AzureChatOpenAI from langchain.embeddings.openai import OpenAIEmbeddings from langchain.schema import Document -from ai_ta_backend.database.base_sql import BaseSQLDatabase -from ai_ta_backend.database.base_storage import BaseStorageDatabase -from ai_ta_backend.database.base_vector import BaseVectorDatabase +from ai_ta_backend.database.sql import SQLAlchemyDatabase +from ai_ta_backend.database.aws import AWSStorage +from ai_ta_backend.database.qdrant import VectorDatabase from ai_ta_backend.service.nomic_service 
import NomicService from ai_ta_backend.service.posthog_service import PosthogService from ai_ta_backend.service.sentry_service import SentryService @@ -24,7 +24,7 @@ class RetrievalService: """ @inject - def __init__(self, vdb: BaseVectorDatabase, sqlDb: BaseSQLDatabase, aws: BaseStorageDatabase, posthog: Optional[PosthogService], + def __init__(self, vdb: VectorDatabase, sqlDb: SQLAlchemyDatabase, aws: AWSStorage, posthog: Optional[PosthogService], sentry: Optional[SentryService], nomicService: Optional[NomicService]): self.vdb = vdb self.sqlDb = sqlDb @@ -146,7 +146,7 @@ def getAll( distinct_dicts = [] for item in data: - combination = (item['s3_path'], item['readable_filename'], item['course_name'], item['url'], item['base_url']) + combination = (item.s3_path, item.readable_filename, item.course_name, item.url, item.base_url) if combination not in unique_combinations: unique_combinations.add(combination) distinct_dicts.append(item) @@ -333,14 +333,15 @@ def delete_from_nomic_and_supabase(self, course_name: str, identifier_key: str, if not data: raise Exception(f"No materials found for {course_name} using {identifier_key}: {identifier_value}") data = data[0] # single record fetched - nomic_ids_to_delete = [str(data['id']) + "_" + str(i) for i in range(1, len(data['contexts']) + 1)] + contexts_list = data.contexts if isinstance(data.contexts, list) else [] + nomic_ids_to_delete = [str(data.id) + "_" + str(i) for i in range(1, len(contexts_list) + 1)] # delete from Nomic response = self.sqlDb.getProjectsMapForCourse(course_name) - data, count = response + data, count = response.data, response.count if not data: raise Exception(f"No document map found for this course: {course_name}") - project_id = data[0]['doc_map_id'] + project_id = str(data[0].doc_map_id) if self.nomicService is not None: self.nomicService.delete_from_document_map(project_id, nomic_ids_to_delete) except Exception as e: diff --git a/ai_ta_backend/service/workflow_service.py b/ai_ta_backend/service/workflow_service.py index badc38ad..0627b53a 100644 --- a/ai_ta_backend/service/workflow_service.py +++ b/ai_ta_backend/service/workflow_service.py @@ -5,7 +5,7 @@ from urllib.parse import quote import json from injector import inject -from ai_ta_backend.database.database_impl.sql.supabase import SQLDatabase +from ai_ta_backend.database.supabase import SQLDatabase class WorkflowService: diff --git a/requirements.txt b/requirements.txt index 848c10d0..eed670aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ mkdocs-material==9.4.7 itsdangerous==2.1.2 Jinja2==3.1.2 mkdocs==1.5.3 -SQLAlchemy==2.0.22 +Flask-SQLAlchemy==3.1.1 tabulate==0.9.0 typing-inspect==0.9.0 typing_extensions==4.8.0 From 60a23ce542c97e8a8f8f96de3489542c5638d050 Mon Sep 17 00:00:00 2001 From: Rohan Salvi Date: Mon, 10 Jun 2024 17:07:16 -0400 Subject: [PATCH 03/11] POI Backend build over revamp with a poi_service agent, and an endpoint. 
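This patch wires a Plants-of-India (POI) SQL agent into the backend: a poi_agent_service that builds a LangChain SQL agent over the plants_of_India_demo.db SQLite database, plus a /query_sql_agent Flask endpoint that forwards a natural-language question to that agent and returns its answer. As a minimal sketch of how the new endpoint could be exercised once the app is running (the localhost:8000 address, the `requests` client, and the sample question are illustrative assumptions, not part of the patch):

    import requests  # assumed third-party HTTP client on the caller's side

    # The endpoint added below expects a JSON body with a "query" field and, in this
    # patch, returns {"response": ...} on success or {"error": ...} with a 400/500 status.
    resp = requests.post(
        "http://localhost:8000/query_sql_agent",             # assumed local dev address
        json={"query": "List all genera within each family"},  # one of the few-shot example questions
        timeout=120,  # the agent makes LLM calls, so allow a generous timeout
    )
    resp.raise_for_status()
    print(resp.json()["response"])
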
--- .env.template | 34 --- ai_ta_backend/main.py | 20 ++ ai_ta_backend/plants_of_India_demo.db | Bin 0 -> 61440 bytes ai_ta_backend/service/plants_of_India_demo.db | Bin 0 -> 61440 bytes ai_ta_backend/service/poi_agent_service.py | 223 ++++++++++++++++++ plants_of_India_demo.db | 0 6 files changed, 243 insertions(+), 34 deletions(-) delete mode 100644 .env.template create mode 100644 ai_ta_backend/plants_of_India_demo.db create mode 100644 ai_ta_backend/service/plants_of_India_demo.db create mode 100644 ai_ta_backend/service/poi_agent_service.py create mode 100644 plants_of_India_demo.db diff --git a/.env.template b/.env.template deleted file mode 100644 index b007d62b..00000000 --- a/.env.template +++ /dev/null @@ -1,34 +0,0 @@ -# Supabase SQL -SUPABASE_URL= -SUPABASE_API_KEY= -SUPABASE_READ_ONLY= -SUPABASE_JWT_SECRET= - -MATERIALS_SUPABASE_TABLE=uiuc_chatbot -SUPABASE_DOCUMENTS_TABLE=documents - -# QDRANT -QDRANT_COLLECTION_NAME=uiuc-chatbot -DEV_QDRANT_COLLECTION_NAME=dev -QDRANT_URL= -QDRANT_API_KEY= - -REFACTORED_MATERIALS_SUPABASE_TABLE= - -# AWS -S3_BUCKET_NAME=uiuc-chatbot -AWS_ACCESS_KEY_ID= -AWS_SECRET_ACCESS_KEY= - -OPENAI_API_KEY= - -NOMIC_API_KEY= -LINTRULE_SECRET= - -# Github Agent -GITHUB_APP_ID= -GITHUB_APP_PRIVATE_KEY="-----BEGIN RSA PRIVATE KEY----- - ------END RSA PRIVATE KEY-----" - -NUMEXPR_MAX_THREADS=2 diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 6c3237da..41d8edde 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -39,11 +39,14 @@ from ai_ta_backend.service.posthog_service import PosthogService from ai_ta_backend.service.retrieval_service import RetrievalService from ai_ta_backend.service.sentry_service import SentryService +from ai_ta_backend.service.poi_agent_service import generate_response ## need to add langchain-community langchain-core langchain-openai to requirements.txt # from ai_ta_backend.beam.nomic_logging import create_document_map from ai_ta_backend.service.workflow_service import WorkflowService from ai_ta_backend.extensions import db + + app = Flask(__name__) CORS(app) executor = Executor(app) @@ -220,6 +223,23 @@ def createConversationMap(service: NomicService): response.headers.add('Access-Control-Allow-Origin', '*') return response + +@app.route('/query_sql_agent', methods=['POST']) +def query_sql_agent(): + data = request.get_json() + user_input = data.get('query') + if not user_input: + return jsonify({"error": "No query provided"}), 400 + + try: + response = generate_response(user_input) + return jsonify({"response": response}), 200 + except Exception as e: + return jsonify({"error": str(e)}), 500 + + + + @app.route('/logToConversationMap', methods=['GET']) def logToConversationMap(service: NomicService, flaskExecutor: ExecutorInterface): course_name: str = request.args.get('course_name', default='', type=str) diff --git a/ai_ta_backend/plants_of_India_demo.db b/ai_ta_backend/plants_of_India_demo.db new file mode 100644 index 0000000000000000000000000000000000000000..634e45966f8816d7e43994082988c43c40c6f893 GIT binary patch literal 61440 zcmeHw34C0~dEeUuSlkdrQZNL;cmPc;AOcG)7Qh0eC>a1DfW$!%07+34Ej{c$0FT(+ z2kk);l4WUDuH!hin>0;xIrr%_SDmX(>?DnP#Ay!Oi4!MI;@EELT#cKfZc`_9|KH4e zZ|A+;#X-r9f4=}f3B>Hq@y$2ieDhuN&C2C@Bd@FLW;U(m)tGXh;`1qQR#in&{P@2W z|Jy%4{1LPZ_?OCi{rHv0wTsQDF>r{1 zf3g^O$4kCYTYJ0jJ=gQvT1wweX_z8jLiOXLe$ zYDUlKS*`5j71!F7eJv*F!O(gkm6E@3z@{&jN*EB$O5ch#E~M9VEd7PW)#;h(Wj53vfkO-&V&D)1hZs1-z##?>F>r{1Lkt{Z;1B~}0T^g0 zo&*%l#TNW-Vt4Vokv)Uo4Qv*_>)8N)*RkXH9eOK%)h6+)W)8oC!}t|=kbLtGDxXkV 
[GIT binary patch payloads omitted: ai_ta_backend/plants_of_India_demo.db and ai_ta_backend/service/plants_of_India_demo.db, two identical 61440-byte SQLite database files added by this patch (blob 634e45966f8816d7e43994082988c43c40c6f893)]
z_Nx#OYeb?~Vl8$y)rB1`lh(2~ZG%P7*C@fTQf%-NYP-mULfFS9f>>3}w6vPl^-`}5 zQm^b{k@9X`0uge@fRLx#seo?rA9cI3xI}m6G#xX8jX+(40AXcG6@>0c)Xvncnq>_< z^~=TS_&6lb1ocx$@6|dG^e24v@@!^{b!8mC=(88fD}2%@C8L0a{(eepJUcOohv@J? zUO|tjPh@pOS%RH&;e+@+*gEcrP^>i_2lR*zUJQq8mV!^t;BFNVq-F8UEF6C$B)d`# znh<8ewXH69-pu_|+YVzz;xuTGxJV387e!9o?~@QZim+o^zMslpzdm4Twe2duX5 zq>7>*p1c=g87ad^_<}TjP}m-G7J{52WsL&CHD5@E`Gq2-{SqwV#K3^XA}(-4WtosA z`c^j%z#=3F_!gGR4rT3aa`SSDGMk$P(q2OZf|{>ZntrjRSeBnGTzt_czkTwFOj0Ai z@x$B+f;)R9Yr_Qg9e)i8BD@9<2;NfWT!olO#@H@o$@s8AX=isdo73!;zS*JpCFGSg zwxIy%xCOAfsqu^;{pwZRV9#i2JW(N@X*0Akn}~ho!(mzhUyX#3m16L9VOjxSQ0kTC zij+*uK8ZR&X1E90KEsW-)wae#$Sle8HbbvTa}p>O4@_GyV}{=4vU&)r7r$-VuRKa1 z@T@N&O}p4oEYDH)N{xrS9C?<8C~-I0b;ph8oG&?!I?@i&fVjIzS`F!zD{{+~9}EU> zgXTt$zkzMU@ehAFxZIUjx<>+0dFIuh&gOFix5->)+{$pHZy#MOLll02ydR)Go_KEp z=>gC^JxDwbDO4aU5b*gWY_O9>%8|y2fJ8_bPhbRU50gUon#GYSmT7L3=8Lt#aU}o~ zzpb^Lj|2kG_gR;BO7*DI9bIAs}6hjAB_H@Zm@)NmZ! zzgYlP)-K?`q5K~i%r~0`ljcDBf1C1kO6y3=r&^Yqzoz+E(?a7H5CPCz|6<)A)b)ja zwDuQjJ|6tz;Cln_^}pTsDPO|Zgz8_upSwj$edS_z;T42wiKCdc4C+(y< z-wA1K)?K7LQR7e+xn)X3Hnmfs!cMZtCaod);x-;=PUF^I_Idi$WFQ@7AY^*>#zeu zzAmX-mMXpx{fnVH zav)Tb@x4)=?|P9kS50(5)+SCJ@APY!FBPiKH=}RRB?Qde>A1?ycWOEnd4P;J9s9^2 zr89!am*lfVRzMwd9NzyN%f|_Xf@xo;JkzZrrN?qCvN(cwFjLDSTV-X};Y@=A3=c2w z$~sEbIP2K;=miIO#3d=(o^m*8={Z()Vy%$JJv0N((FFWcChqHz1~xwE;**{s1EE05 z_mDK*yTuOg{wSHgy`jp^>T%QemCP1m!ZIq7=|cX>!_)-}7aX0}vh=bvvZM>s)}#}B z74rK!q@meh@j-!U@h0L~p1PPnkXW6jws37jgyhcP4oZrwk(G`Tr<$l$<>`-cAo&Fs zx;*UyA^(=|K^f+UitS~_Mb>At)7{TrS^X{`Z@WpxLj?Lz4u%9`S+Obe26oFuT&RG? zuwbSNiG+r?WO3ubM5t(n{)bmKf$a1mItc!Rkgr{aFI=Q-SPu0fR0+G)rrs(B!xD`g zgnB&@OcWDn88HP$)l#>d+5@3mzBfpq7VCQkay4?o}gbZJsgbVgqu`IDynBBC=>dSFT+eSmLt+2w;O3JHa z8+6XuTHgcG3|5QuGS=&G&er;lOWm@m;xY6)hjRzAn;q)YP6hCY3iM2>ti23~d*wv3 zhvw^4ch}yR@>`{Cpmx)DOqxc%c+`uXwi}++tzxs2)`R#0XR>@&(pA5ChsDp(Cl+tq?o)Nc6hm#0x*JmN*)L5*ko z_PjY3Z->xpd`Bc;R*Lj^m7CD-#RXU6yHDzs%@kYF?E7~KkjRQ6QbziGA71IA3>R(V=A8vVf^RG4EZu;5AziPbR z@GT8f`0()O5Cex8IK;pq1`aWBh=D^49Ae=0!~i>AJngkkgqh({H1=>aSFq^Xx}qIm zB8;?^46~D)aKj(U^yu`gX>P$9j0DkSsl)4>OMPcci-lU`y+Dwf6{A+VMcLr%m*JW& z($i1SGEZ?g`ZW>5N(G>`ik2zn3}csNc9B&U$t7f0JD72XUs4^WN7mi;8BcG4hMez| zJdv^DBjwxY@|S!2+~bz1_dO!PzFO>ieYQ-!uTSb0TZZ0zI>?rRrU@3XuPtL$y@xGR zrDdQ#>+6-lf1nulu4o(U^sH!=Ed$Go0eg$Z{AknY8Q6z&o(*FK$O*#?!tOG-KO=7` z%gYj|_l0H1t`_Opr_!dG=7!qcH;tI}o}0$5b}yS|a8x7;7EYy+&%FHg$j#&4MV{^h M_0RbpmS*<<0AF|jiU0rr literal 0 HcmV?d00001 diff --git a/ai_ta_backend/service/poi_agent_service.py b/ai_ta_backend/service/poi_agent_service.py new file mode 100644 index 00000000..17cf8b62 --- /dev/null +++ b/ai_ta_backend/service/poi_agent_service.py @@ -0,0 +1,223 @@ +from dotenv import load_dotenv +from langchain_community.utilities.sql_database import SQLDatabase +from langchain_openai import ChatOpenAI, OpenAI +from langchain_community.agent_toolkits import create_sql_agent +from langchain_core.prompts import ( + ChatPromptTemplate, + FewShotPromptTemplate, + MessagesPlaceholder, + PromptTemplate, + SystemMessagePromptTemplate, +) + +def initialize_agent(): + load_dotenv() + + # Set up the Database + db = SQLDatabase.from_uri("sqlite:///plants_of_India_demo.db") + + # Define the system prefixes and examples + system_prefix_asmita = """Given the following user question, corresponding SQL query, and SQL result, answer the user question. + Return the SQL query and SQL result along with the final answer. + Limit your search to the "Plants" table. + Restrict to data found in the "plants" table only. If the SQL result is empty, mention that the query returned 0 rows. 
Do not come up with answers from your own knowledge. + If the SQL result is empty, mention that the query returned 0 rows. + Any SQL queries related to string comparison like location, author name, etc. should use LIKE statement. + Refer to the description below to obtain additional information about the columns. + Column Description: + 1. recordType: This column contains text codes which indicate what information is present in the current row about the plant. + FA = Family Name, Genus Name, Sciecific Name. + TY = Type + GE = Genus Name + AN = Family Name (contains Accepted Name), Genus Name, Scientific Name, Author Name, Publication, Volume:Page, and Year of Publication + HB = Habit + DB = Distribution or location of the plant + RE = Remarks + SN = Family Name (contains Synonym), Genus Name, Sciecific Name, Author Name, Publication, Volume:Page, and Year of Publication + + 2. familyName: This columns contains the Family Name of the plant. A family is made of several genera and genus is made of several species (scientific names). + 3. genusName: This columns contains the Genus Name of the plant. + 4. scientificName: This columns contains the Scientific Name of the plant species. + 5. publication: Name of the journal or book where the plant discovery information is published. Use LIKE statement when creating SQL queries related to publication name. + 6. volumePage: The volume and page number of the publication containing the plant information. + 7. yearOfpublication: The year in which the plant discovery information was published or the year the plant was found / discovered. + 8. author: This column may contain multiple authors separated by &. When creating SQL queries related to authors or names, always use a LIKE statement. + 9. additionalDetail2: This column contains 4 types of information - type, habit, distribution and remarks. Use LIKE statement when creating SQL queries related to any of the three fields of additional details. + - Type mentions information about location. + - Remarks mentions location information about where the plant is cultivated or native to. + - Distribution mentions the locations where the plant is distributed. It may contain multiple locations, so always use a LIKE statement when creating SQL queries related to distribution or plant location. + + 10. groups: This column contains one of the two values - Gymnosperms and Angiosperms. + 11. familyNumber: This column contains a number which is assigned to the family. + 12. genusNumber: This coulumn contains a number which is assigned to a genus within a family. + 13. acceptedNameNumber: This coulumn contains a number which is assigned to a accepted name of a genus within a family + 14. synonymNumber: This coulumn contains a number which is assigned to a synonym name associated with a an accepted name of a genus within a family +DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. + +If the question does not seem related to the database, just return "I don't know" as the answer. 
+ +Here are some examples of user inputs and their corresponding SQL queries: +""" + + + examples = [ + {"input": "List all the accepted names under the family 'Gnetaceae'.", + "query": """ +SELECT DISTINCT scientificName FROM plants +JOIN ( + SELECT familyNumber, genusNumber, acceptedNameNumber + FROM plants + WHERE familyNumber IN ( + SELECT DISTINCT familyNumber FROM plants WHERE familyName = 'Gnetaceae' + ) +) b +ON plants.genusNumber = b.genusNumber AND plants.acceptedNameNumber = b.acceptedNameNumber AND plants.familyNumber = b.familyNumber +WHERE plants.recordType = 'AN'; +"""}, + { + "input": "List all the accepted species that are introduced.", + "query": """ +SELECT DISTINCT scientificName FROM plants +JOIN ( + SELECT familyNumber, genusNumber, acceptedNameNumber + FROM plants + WHERE recordType = 'RE'and additionalDetail2 LIKE '%cultivated%' + +) b +ON plants.genusNumber = b.genusNumber AND plants.acceptedNameNumber = b.acceptedNameNumber AND plants.familyNumber = b.familyNumber +WHERE plants.recordType = 'AN'; +""", + }, + { + "input": "List all the accepted names with type 'Cycad'", + "query": """ +SELECT DISTINCT scientificName FROM plants +JOIN ( + SELECT familyNumber, genusNumber, acceptedNameNumber + FROM plants + WHERE recordType = 'HB'and additionalDetail2 LIKE '%Cycad%' + +) b +ON plants.genusNumber = b.genusNumber AND plants.acceptedNameNumber = b.acceptedNameNumber AND plants.familyNumber = b.familyNumber +WHERE plants.recordType = 'AN'; +""", + }, + { + "input": "List all the accepted names under the genus 'Cycas' with more than two synonyms.", + "query": """ +SELECT DISTINCT scientificName FROM plants +JOIN ( + SELECT familyNumber, genusNumber, acceptedNameNumber + FROM plants + WHERE genusNumber IN ( + SELECT DISTINCT genusNumber FROM plants WHERE GenusName = 'Cycas' + ) + AND familyNumber IN ( + SELECT DISTINCT familyNumber FROM plants WHERE GenusName = 'Cycas' + ) + AND synonymNumber > 2 +) b +ON plants.genusNumber = b.genusNumber AND plants.acceptedNameNumber = b.acceptedNameNumber AND plants.familyNumber = b.familyNumber +WHERE plants.recordType = 'AN'; +""", + }, + { + "input":'List all the accepted names published in Asian J. Conservation Biol.', + "query": """ + SELECT DISTINCT scientificName + FROM plants + WHERE recordType = 'AN' AND publication LIKE '%Asian J. Conservation Biol%'; + +""", + }, + { + "input": 'List all the accepted names linked with endemic tag.', + "query": """ +SELECT DISTINCT scientificName FROM plants +JOIN ( + SELECT familyNumber, genusNumber, acceptedNameNumber + FROM plants + WHERE recordType = 'DB'and additionalDetail2 LIKE '%Endemic%' + +) b +ON plants.genusNumber = b.genusNumber AND plants.acceptedNameNumber = b.acceptedNameNumber AND plants.familyNumber = b.familyNumber +WHERE plants.recordType = 'AN'; +""", + }, + { + "input": 'List all the accepted names that have no synonyms.' , + "query": """ +SELECT DISTINCT a.scientificName FROM plants a +group by a.familyNumber,a.genusNumber,a.acceptedNameNumber +HAVING SUM(a.synonymNumber) = 0 AND a.acceptedNameNumber > 0; +""", + }, + { + "input": 'List all the accepted names authored by Roxb.', + "query": """ +SELECT scientificName +FROM plants +WHERE recordType = 'AN'and actualAuthorName LIKE '%Roxb%'; +""", + }, + { + "input": 'List all genera within each family', + "query": """ +SELECT familyName, genusName +FROM plants +WHERE recordType = 'GE'; +""", + }, + { + "input": 'Did Minq. 
discovered Cycas ryumphii?', + "query": """SELECT + CASE + WHEN EXISTS ( + SELECT 1 + FROM plants as a + WHERE a.scientificName = 'Cycas rumphii' + AND a.author = 'Miq.' + ) THEN 'TRUE' + ELSE 'FALSE' + END AS ExistsCheck; +""", + }, +] + + # Define the prompt template + example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}") + prompt = FewShotPromptTemplate( + examples=examples[:5], + example_prompt=example_prompt, + prefix=system_prefix_asmita, + suffix="", + input_variables=["input", "top_k", "table_info"], + ) + + # Define the full prompt structure + full_prompt = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate(prompt=prompt), + ("human", "{input}"), + MessagesPlaceholder("agent_scratchpad"), + ] + ) + + # Initialize the OpenAI models + llm = ChatOpenAI(model="gpt-4-turbo", temperature=0) + + # Create the SQL agent + agent = create_sql_agent( + llm=llm, + db=db, + prompt=full_prompt, + verbose=True, + agent_type="openai-tools", + ) + + return agent + +def generate_response(user_input): + agent = initialize_agent() + output = agent.invoke(user_input) + return output diff --git a/plants_of_India_demo.db b/plants_of_India_demo.db new file mode 100644 index 00000000..e69de29b From bb88811190106599c7afe5f0eb31a687c5bbf22e Mon Sep 17 00:00:00 2001 From: Kastan Day Date: Tue, 18 Jun 2024 17:20:42 -0700 Subject: [PATCH 04/11] Initial commit; working 'docker compose up' command --- Dockerfile | 26 +++++++++++++++++++ ai_ta_backend/main.py | 3 ++- docker-compose.yaml | 55 ++++++++++++++++++++++++++++++++++++++++ init-scripts/init-db.sql | 15 +++++++++++ railway.json | 9 ++++--- 5 files changed, 104 insertions(+), 4 deletions(-) create mode 100644 Dockerfile create mode 100644 docker-compose.yaml create mode 100644 init-scripts/init-db.sql diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..5eda60c9 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,26 @@ +# Use an official Python runtime as a parent image +FROM python:3.10-slim + +# Set the working directory in the container +WORKDIR /usr/src/app + +# Copy the local directory contents into the container +COPY . . 
+ +# Install any needed packages specified in requirements.txt +RUN pip install --no-cache-dir -r requirements.txt + +# Set the Python path to include the ai_ta_backend directory +ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/ai_ta_backend" + +RUN echo $PYTHONPATH +RUN ls -la /usr/src/app/ + +# Make port 8000 available to the world outside this container +EXPOSE 8000 + +# Define environment variable for Gunicorn to bind to 0.0.0.0:8000 +ENV GUNICORN_CMD_ARGS="--bind=0.0.0.0:8000" + +# Run the application using Gunicorn with specified configuration +CMD ["gunicorn", "--workers=1", "--threads=100", "--worker-class=gthread", "ai_ta_backend.main:app", "--timeout=1800", "--bind=0.0.0.0:8000"] diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 6c3237da..5175ee70 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -504,7 +504,8 @@ def configure(binder: Binder) -> None: vector_bound = True if all(os.getenv(key) for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "S3_BUCKET_NAME"]) or any(os.getenv(key) for key in ["MINIO_ACCESS_KEY", "MINIO_SECRET_KEY", "MINIO_URL"]): - logging.info("Binding to AWS S3 storage") + if os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY"): logging.info("Binding to AWS storage") + elif os.getenv("MINIO_ACCESS_KEY") and os.getenv("MINIO_SECRET_KEY"): logging.info("Binding to Minio storage") binder.bind(AWSStorage, to=AWSStorage, scope=SingletonScope) storage_bound = True diff --git a/docker-compose.yaml b/docker-compose.yaml new file mode 100644 index 00000000..33ae37ee --- /dev/null +++ b/docker-compose.yaml @@ -0,0 +1,55 @@ +version: '3.8' + +services: + # sqlite: + # image: nouchka/sqlite3 + # volumes: + # - sqlite-data:/root/db + # - ./init-scripts:/docker-entrypoint-initdb.d # Mount initialization scripts + # command: [ "sqlite3", "/root/db/sqlite.db", "-cmd", ".tables" ] + + redis: + image: redis:latest + ports: + - "6379:6379" + volumes: + - redis-data:/data + + qdrant: + image: generall/qdrant:latest + volumes: + - qdrant-data:/qdrant/storage + ports: + - "6333:6333" # HTTP API + - "6334:6334" # gRPC API + + minio: + image: minio/minio + environment: + MINIO_ROOT_USER: 'minioadmin' # Customize access key + MINIO_ROOT_PASSWORD: 'minioadmin' # Customize secret key + command: server /data + ports: + - "9000:9000" # Console access + - "9001:9001" # API access + volumes: + - minio-data:/data + + flask_app: + # build: . 
# Directory with Dockerfile for Flask app + image: kastanday/ai-ta-backend:gunicorn + ports: + - "8000:8000" + volumes: + - ./db:/app/db # Mount local directory to store SQLite database + depends_on: + # - sqlite + - redis + - qdrant + - minio + +volumes: + # sqlite-data: + redis-data: + qdrant-data: + minio-data: diff --git a/init-scripts/init-db.sql b/init-scripts/init-db.sql new file mode 100644 index 00000000..49b1b646 --- /dev/null +++ b/init-scripts/init-db.sql @@ -0,0 +1,15 @@ +CREATE TABLE public.documents ( + id BIGINT GENERATED BY DEFAULT AS IDENTITY, + created_at TIMESTAMP WITH TIME ZONE NULL DEFAULT NOW(), + s3_path TEXT NULL, + readable_filename TEXT NULL, + course_name TEXT NULL, + url TEXT NULL, + contexts JSONB NULL, + base_url TEXT NULL, + CONSTRAINT documents_pkey PRIMARY KEY (id) +) TABLESPACE pg_default; + +CREATE INDEX IF NOT EXISTS documents_course_name_idx ON public.documents USING hash (course_name) TABLESPACE pg_default; + +CREATE INDEX IF NOT EXISTS documents_created_at_idx ON public.documents USING btree (created_at) TABLESPACE pg_default; diff --git a/railway.json b/railway.json index d6d92535..4147197e 100644 --- a/railway.json +++ b/railway.json @@ -9,11 +9,14 @@ "cmds": [ "python -m venv --copies /opt/venv && . /opt/venv/bin/activate", "pip install pip==23.3.1", - "pip install -r requirements.txt" + "pip install -r ai_ta_backend/requirements.txt" ] }, "setup": { - "nixPkgs": ["python310", "gcc"] + "nixPkgs": [ + "python310", + "gcc" + ] } } } @@ -23,4 +26,4 @@ "restartPolicyType": "ON_FAILURE", "restartPolicyMaxRetries": 1 } -} +} \ No newline at end of file From e8c96dd2460f51552f2ed8cc394c4d07fe23a67a Mon Sep 17 00:00:00 2001 From: Rohan Salvi Date: Wed, 19 Jun 2024 00:23:26 -0400 Subject: [PATCH 05/11] POI SQL_Agent with Langrapgh and dynamic few shot template. --- ai_ta_backend/database/poi_sql.py | 14 + ai_ta_backend/main.py | 14 +- ai_ta_backend/service/poi_agent_service.py | 501 ++++++++++++++++----- 3 files changed, 404 insertions(+), 125 deletions(-) create mode 100644 ai_ta_backend/database/poi_sql.py diff --git a/ai_ta_backend/database/poi_sql.py b/ai_ta_backend/database/poi_sql.py new file mode 100644 index 00000000..8cf94d4e --- /dev/null +++ b/ai_ta_backend/database/poi_sql.py @@ -0,0 +1,14 @@ +from typing import List +from injector import inject +from flask_sqlalchemy import SQLAlchemy +import ai_ta_backend.model.models as models +import logging + +from ai_ta_backend.model.response import DatabaseResponse + +class POISQLDatabase: + + @inject + def __init__(self, db: SQLAlchemy): + logging.info("Initializing SQLAlchemyDatabase") + self.db = db \ No newline at end of file diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 41d8edde..38bbc12c 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -3,7 +3,7 @@ import os import time from typing import List - +from ai_ta_backend.database.poi_sql import POISQLDatabase from dotenv import load_dotenv from flask import ( Flask, @@ -46,6 +46,7 @@ from ai_ta_backend.extensions import db +from langchain_core.messages import HumanMessage, SystemMessage app = Flask(__name__) CORS(app) @@ -228,12 +229,16 @@ def createConversationMap(service: NomicService): def query_sql_agent(): data = request.get_json() user_input = data.get('query') + system_message = SystemMessage(content="you are a helpful assistant and need to provide answers in text format about the plants found in India. 
If the Question is not related to plants in India answer 'I do not have any information on this.'") + if not user_input: return jsonify({"error": "No query provided"}), 400 try: - response = generate_response(user_input) - return jsonify({"response": response}), 200 + user_01 = HumanMessage(content=user_input) + inputs = {"messages": [system_message,user_01]} + response = generate_response(inputs) + return str(response), 200 except Exception as e: return jsonify({"error": str(e)}), 500 @@ -516,7 +521,8 @@ def configure(binder: Binder) -> None: binder.bind(SQLAlchemyDatabase, to=db, scope=SingletonScope) sql_bound = True break - + # if os.getenv(POI_SQL_DB_NAME): # type: ignore + # binder.bind(SQLAlchemyDatabase, to=POISQLDatabase, scope=SingletonScope) # Conditionally bind databases based on the availability of their respective secrets if all(os.getenv(key) for key in ["QDRANT_URL", "QDRANT_API_KEY", "QDRANT_COLLECTION_NAME"]) or any(os.getenv(key) for key in ["PINECONE_API_KEY", "PINECONE_PROJECT_NAME"]): logging.info("Binding to Qdrant database") diff --git a/ai_ta_backend/service/poi_agent_service.py b/ai_ta_backend/service/poi_agent_service.py index 17cf8b62..b249a5f7 100644 --- a/ai_ta_backend/service/poi_agent_service.py +++ b/ai_ta_backend/service/poi_agent_service.py @@ -9,215 +9,474 @@ PromptTemplate, SystemMessagePromptTemplate, ) +import os +import logging +from flask_sqlalchemy import SQLAlchemy -def initialize_agent(): - load_dotenv() - - # Set up the Database - db = SQLDatabase.from_uri("sqlite:///plants_of_India_demo.db") - - # Define the system prefixes and examples - system_prefix_asmita = """Given the following user question, corresponding SQL query, and SQL result, answer the user question. - Return the SQL query and SQL result along with the final answer. - Limit your search to the "Plants" table. - Restrict to data found in the "plants" table only. If the SQL result is empty, mention that the query returned 0 rows. Do not come up with answers from your own knowledge. - If the SQL result is empty, mention that the query returned 0 rows. - Any SQL queries related to string comparison like location, author name, etc. should use LIKE statement. - Refer to the description below to obtain additional information about the columns. - Column Description: - 1. recordType: This column contains text codes which indicate what information is present in the current row about the plant. - FA = Family Name, Genus Name, Sciecific Name. - TY = Type - GE = Genus Name - AN = Family Name (contains Accepted Name), Genus Name, Scientific Name, Author Name, Publication, Volume:Page, and Year of Publication - HB = Habit - DB = Distribution or location of the plant - RE = Remarks - SN = Family Name (contains Synonym), Genus Name, Sciecific Name, Author Name, Publication, Volume:Page, and Year of Publication - - 2. familyName: This columns contains the Family Name of the plant. A family is made of several genera and genus is made of several species (scientific names). - 3. genusName: This columns contains the Genus Name of the plant. - 4. scientificName: This columns contains the Scientific Name of the plant species. - 5. publication: Name of the journal or book where the plant discovery information is published. Use LIKE statement when creating SQL queries related to publication name. - 6. volumePage: The volume and page number of the publication containing the plant information. - 7. 
yearOfpublication: The year in which the plant discovery information was published or the year the plant was found / discovered. - 8. author: This column may contain multiple authors separated by &. When creating SQL queries related to authors or names, always use a LIKE statement. - 9. additionalDetail2: This column contains 4 types of information - type, habit, distribution and remarks. Use LIKE statement when creating SQL queries related to any of the three fields of additional details. - - Type mentions information about location. - - Remarks mentions location information about where the plant is cultivated or native to. - - Distribution mentions the locations where the plant is distributed. It may contain multiple locations, so always use a LIKE statement when creating SQL queries related to distribution or plant location. - - 10. groups: This column contains one of the two values - Gymnosperms and Angiosperms. - 11. familyNumber: This column contains a number which is assigned to the family. - 12. genusNumber: This coulumn contains a number which is assigned to a genus within a family. - 13. acceptedNameNumber: This coulumn contains a number which is assigned to a accepted name of a genus within a family - 14. synonymNumber: This coulumn contains a number which is assigned to a synonym name associated with a an accepted name of a genus within a family -DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. - -If the question does not seem related to the database, just return "I don't know" as the answer. - -Here are some examples of user inputs and their corresponding SQL queries: -""" - - - examples = [ - {"input": "List all the accepted names under the family 'Gnetaceae'.", +from langchain_openai import ChatOpenAI + +from operator import itemgetter + +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import RunnablePassthrough +from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool +from langchain_community.agent_toolkits import create_sql_agent +from langchain.tools import BaseTool, StructuredTool, Tool, tool +import random +from langgraph.prebuilt.tool_executor import ToolExecutor +from langchain.tools.render import format_tool_to_openai_function + + +from typing import TypedDict, Annotated, Sequence +import operator +from langchain_core.messages import BaseMessage + +from langchain_core.agents import AgentFinish +from langgraph.prebuilt import ToolInvocation +import json +from langchain_core.messages import FunctionMessage +from langchain_community.utilities import SQLDatabase +from langchain_community.vectorstores import FAISS +from langchain_core.example_selectors import SemanticSimilarityExampleSelector +from langchain_openai import OpenAIEmbeddings + +load_dotenv() + + +def get_dynamic_prompt_template(): + + examples = [ + { + "input": "How many accepted names are only distributed in Karnataka?", + "query": 'SELECT COUNT(*) as unique_pairs_count FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "Karnataka%"));' + }, + { + "input": "How many names were authored by Roxb?", + "query": 'SELECT COUNT(*) as unique_pairs_count FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Author_Name" LIKE "%Roxb%" AND "Record_Type_Code" IN ("AN", "SN"));' + }, + { + "input": "How many species have 
distributions in Myanmar, Meghalaya and Andhra Pradesh?", + "query": 'SELECT COUNT(*) FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Meghalaya%" AND "Additional_Details_2" LIKE "%Andhra Pradesh%"));' + }, + { + "input": "List the accepted names common to Myanmar, Meghalaya, Odisha, Andhra Pradesh.", + "query": 'SELECT DISTINCT Scientific_Name FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Meghalaya%" AND "Additional_Details_2" LIKE "%Odisha%" AND "Additional_Details_2" LIKE "%Andhra Pradesh%"));' + }, + { + "input": "List the accepted names that represent 'tree'.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "HB" AND "Additional_Details_2" LIKE "%tree%");' + }, + { + "input": "List the accepted names linked with Endemic tag.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND "Additional_Details_2" LIKE "%Endemic%");' + }, + { + "input": "List the accepted names published in Fl. Brit. India [J. D. Hooker].", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" in ("AN", "SN") AND ("Publication" LIKE "%Fl. Brit. India [J. D. Hooker]%" OR "Publication" LIKE "%[J. D. Hooker]%" OR "Publication" LIKE "%Fl. Brit. 
India%");' + }, + { + "input": "How many accepted names have ‘Silhet’/ ‘Sylhet’ in their Type?", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "TY" AND ("Additional_Details_2" LIKE "%Silhet%" OR "Additional_Details_2" LIKE "%Sylhet%"));' + }, + { + "input": "How many species were distributed in Sikkim and Meghalaya?", + "query": 'SELECT COUNT(*) AS unique_pairs FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Sikkim%" AND Additional_Details_2 LIKE "%Meghalaya%"));' + }, + { + "input": "List the accepted names common to Kerala, Tamil Nadu, Andhra Pradesh, Karnataka, Maharashtra, Odisha, Meghalaya and Myanmar.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Meghalaya%" AND "Additional_Details_2" LIKE "%Odisha%" AND "Additional_Details_2" LIKE "%Andhra Pradesh%" AND "Additional_Details_2" LIKE "%Kerala%" AND "Additional_Details_2" LIKE "%Tamil Nadu%" AND "Additional_Details_2" LIKE "%Karnataka%" AND "Additional_Details_2" LIKE "%Maharashtra%"));' + }, + { + "input": "List the accepted names common to Europe, Afghanistan, Jammu & Kashmir, Himachal, Nepal, Sikkim, Bhutan, Arunachal Pradesh and China.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Europe%" AND Additional_Details_2 LIKE "%Afghanistan%" AND "Additional_Details_2" LIKE "%Jammu & Kashmir%" AND "Additional_Details_2" LIKE "%Himachal%" AND "Additional_Details_2" LIKE "%Nepal%" AND "Additional_Details_2" LIKE "%Sikkim%" AND "Additional_Details_2" LIKE "%Bhutan%" AND "Additional_Details_2" LIKE "%Arunachal Pradesh%" AND "Additional_Details_2" LIKE "%China%"));' + }, + { + "input": "List the accepted names common to Europe, Afghanistan, Austria, Belgium, Czechoslovakia, Denmark, France, Greece, Hungary, Italy, Moldava, Netherlands, Poland, Romania, Spain, Switzerland, Jammu & Kashmir, Himachal, Nepal, and China.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Europe%" AND Additional_Details_2 LIKE "%Afghanistan%" AND "Additional_Details_2" LIKE "%Jammu & Kashmir%" AND "Additional_Details_2" LIKE "%Himachal%" AND "Additional_Details_2" LIKE "%Nepal%" AND "Additional_Details_2" LIKE "%Austria%" AND "Additional_Details_2" LIKE "%Belgium%" AND "Additional_Details_2" LIKE "%Czechoslovakia%" AND "Additional_Details_2" LIKE "%China%" AND "Additional_Details_2" LIKE "%Denmark%" AND "Additional_Details_2" LIKE "%Greece%" AND "Additional_Details_2" LIKE "%France%" AND "Additional_Details_2" LIKE "%Hungary%" AND "Additional_Details_2" LIKE "%Italy%" AND "Additional_Details_2" LIKE "%Moldava%" AND "Additional_Details_2" LIKE 
"%Netherlands%" AND "Additional_Details_2" LIKE "%Poland%" AND "Additional_Details_2" LIKE "%Poland%" AND "Additional_Details_2" LIKE "%Romania%" AND "Additional_Details_2" LIKE "%Spain%" AND "Additional_Details_2" LIKE "%Switzerland%"));' + }, + { + "input": "List the species which are distributed in Sikkim and Meghalaya.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Sikkim%" AND Additional_Details_2 LIKE "%Meghalaya%"));' + }, + { + "input": "How many species are common to America, Europe, Africa, Asia, and Australia?", + "query": 'SELECT COUNT(*) AS unique_pairs IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%America%" AND Additional_Details_2 LIKE "%Europe%" AND "Additional_Details_2" LIKE "%Africa%" AND "Additional_Details_2" LIKE "%Asia%" AND "Additional_Details_2" LIKE "%Australia%"));' + }, + { + "input": "List the species names common to India and Myanmar, Malaysia, Indonesia, and Australia.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number","Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%India%" AND Additional_Details_2 LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Malaysia%" AND Additional_Details_2 LIKE "%Indonesia%" AND Additional_Details_2 LIKE "%Australia%"));' + }, + { + "input": "List all plants which are tagged as urban.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Urban" = "YES";' + }, + { + "input": "List all plants which are tagged as fruit.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Fruit" = "YES";' + }, + { + "input": "List all plants which are tagged as medicinal.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Medicinal" = "YES";' + }, + { + "input": "List all family names which are gymnosperms.", + "query": 'SELECT DISTINCT "Family_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Groups" = "Gymnosperms";' + }, + { + "input": "How many accepted names are tagged as angiosperms?", + "query": 'SELECT COUNT(DISTINCT "Scientific_Name") FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Groups" = "Angiosperms";' + }, + { + "input": "How many accepted names belong to the 'Saxifraga' genus?", + "query": 'SELECT COUNT(DISTINCT "Scientific_Name") FROM plants WHERE "Genus_Name" = "Saxifraga";' + }, + { + "input": "List the accepted names tagged as 'perennial herb' or 'climber'.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "HB" AND ("Additional_Details_2" LIKE "%perennial herb%" OR "Additional_Details_2" LIKE "%climber%"));' + }, + { + "input": "How many accepted names are native to South Africa?", + "query": 'SELECT COUNT(*) FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" 
FROM plants WHERE "Record_Type_Code" = "RE" AND "Additional_Details_2" LIKE "%native%" AND "Additional_Details_2" LIKE "%south%" AND "Additional_Details_2" LIKE "%africa%");' + + }, + { + "input": "List the accepted names which were introduced and naturalized.", + "query": 'SELECT DISTINCT "Scientific_Name FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "RE" AND "Additional_Details_2" LIKE "%introduced%" AND "Additional_Details_2" LIKE "%naturalized%");' + }, + { + "input": "List all ornamental plants.", + "query": 'SELECT DISTINCT "Scientific_Name FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "RE" AND "Additional_Details_2" LIKE "%ornamental%");' + }, + { + "input": "How many plants from the 'Leguminosae' family have a altitudinal range up to 1000 m?", + "query": 'SELECT COUNT(*) FROM plants WHERE "Record_Type_Code" = "AL" AND "Family_Name" = "Leguminosae" AND "Additional_Details_2" LIKE "%1000%";' + }, + { + "input": "List the accepted names linked with the 'endemic' tag for Karnataka.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND "Additional_Details_2" LIKE "%Endemic%" AND "Additional_Details_2" LIKE "%Karnataka%");' + }, + {"input": "List all the accepted names under the family 'Gnetaceae'.", "query": """ -SELECT DISTINCT scientificName FROM plants +SELECT DISTINCT "Scientific_Name" FROM plants JOIN ( - SELECT familyNumber, genusNumber, acceptedNameNumber + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants - WHERE familyNumber IN ( - SELECT DISTINCT familyNumber FROM plants WHERE familyName = 'Gnetaceae' + WHERE "Family_Number" IN ( + SELECT DISTINCT "Family_Number" FROM plants WHERE "Family_Name" = "Gnetaceae" ) ) b -ON plants.genusNumber = b.genusNumber AND plants.acceptedNameNumber = b.acceptedNameNumber AND plants.familyNumber = b.familyNumber -WHERE plants.recordType = 'AN'; +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; """}, { "input": "List all the accepted species that are introduced.", "query": """ -SELECT DISTINCT scientificName FROM plants +SELECT DISTINCT "Scientific_Name" FROM plants JOIN ( - SELECT familyNumber, genusNumber, acceptedNameNumber + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants - WHERE recordType = 'RE'and additionalDetail2 LIKE '%cultivated%' - + WHERE "Record_Type_Code" = 'RE'and "Additional_Details_2" LIKE '%cultivated%' ) b -ON plants.genusNumber = b.genusNumber AND plants.acceptedNameNumber = b.acceptedNameNumber AND plants.familyNumber = b.familyNumber -WHERE plants.recordType = 'AN'; +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; """, }, { "input": "List all the accepted names with type 'Cycad'", "query": """ -SELECT DISTINCT scientificName FROM plants +SELECT DISTINCT "Scientific_Name" FROM plants JOIN ( - SELECT familyNumber, 
genusNumber, acceptedNameNumber + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants - WHERE recordType = 'HB'and additionalDetail2 LIKE '%Cycad%' + WHERE "Record_Type_Code" = 'HB'and "Additional_Details_2" LIKE '%Cycad%' ) b -ON plants.genusNumber = b.genusNumber AND plants.acceptedNameNumber = b.acceptedNameNumber AND plants.familyNumber = b.familyNumber -WHERE plants.recordType = 'AN'; +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; """, }, { "input": "List all the accepted names under the genus 'Cycas' with more than two synonyms.", "query": """ -SELECT DISTINCT scientificName FROM plants +SELECT DISTINCT "Scientific_Name" FROM plants JOIN ( - SELECT familyNumber, genusNumber, acceptedNameNumber + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants - WHERE genusNumber IN ( - SELECT DISTINCT genusNumber FROM plants WHERE GenusName = 'Cycas' + WHERE "Genus_Number" IN ( + SELECT DISTINCT "Genus_Number" FROM plants WHERE "Genus_Name" = 'Cycas' ) - AND familyNumber IN ( - SELECT DISTINCT familyNumber FROM plants WHERE GenusName = 'Cycas' + AND "Family_Number" IN ( + SELECT DISTINCT "Family_Number" FROM plants WHERE "Genus_Name" = 'Cycas' ) - AND synonymNumber > 2 + AND "Synonym_Number" > 2 ) b -ON plants.genusNumber = b.genusNumber AND plants.acceptedNameNumber = b.acceptedNameNumber AND plants.familyNumber = b.familyNumber -WHERE plants.recordType = 'AN'; +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; """, }, { "input":'List all the accepted names published in Asian J. Conservation Biol.', "query": """ - SELECT DISTINCT scientificName + SELECT DISTINCT "Scientific_Name" FROM plants - WHERE recordType = 'AN' AND publication LIKE '%Asian J. Conservation Biol%'; + WHERE "Record_Type_Code" = 'AN' AND "Publication" LIKE '%Asian J. Conservation Biol%'; """, }, { "input": 'List all the accepted names linked with endemic tag.', "query": """ -SELECT DISTINCT scientificName FROM plants +SELECT DISTINCT "Scientific_Name" FROM plants JOIN ( - SELECT familyNumber, genusNumber, acceptedNameNumber + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants - WHERE recordType = 'DB'and additionalDetail2 LIKE '%Endemic%' + WHERE "Record_Type_Code" = 'DB'and "Additional_Details_2" LIKE '%Endemic%' ) b -ON plants.genusNumber = b.genusNumber AND plants.acceptedNameNumber = b.acceptedNameNumber AND plants.familyNumber = b.familyNumber -WHERE plants.recordType = 'AN'; +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; """, }, { "input": 'List all the accepted names that have no synonyms.' 
, "query": """ -SELECT DISTINCT a.scientificName FROM plants a -group by a.familyNumber,a.genusNumber,a.acceptedNameNumber -HAVING SUM(a.synonymNumber) = 0 AND a.acceptedNameNumber > 0; +SELECT DISTINCT a."Scientific_Name" FROM plants a +group by a."Family_Number",a."Genus_Number",a."Accepted_name_number" +HAVING SUM(a."Synonym_Number") = 0 AND a."Accepted_name_number" > 0; """, }, { "input": 'List all the accepted names authored by Roxb.', "query": """ -SELECT scientificName +SELECT "Scientific_Name" FROM plants -WHERE recordType = 'AN'and actualAuthorName LIKE '%Roxb%'; +WHERE "Record_Type_Code" = 'AN'AND "Author_Name" LIKE '%Roxb%'; """, }, { "input": 'List all genera within each family', "query": """ -SELECT familyName, genusName +SELECT "Family_Name", "Genus_Name" FROM plants -WHERE recordType = 'GE'; +WHERE "Record_Type_Code" = 'GE'; """, }, - { + { "input": 'Did Minq. discovered Cycas ryumphii?', - "query": """SELECT - CASE + "query": """SELECT + CASE WHEN EXISTS ( SELECT 1 FROM plants as a - WHERE a.scientificName = 'Cycas rumphii' - AND a.author = 'Miq.' + WHERE a."Scientific_Name" = 'Cycas rumphii' + AND a."Author_Name" = 'Miq.' ) THEN 'TRUE' ELSE 'FALSE' END AS ExistsCheck; -""", - }, -] +"""}, + + ] + + + example_selector = SemanticSimilarityExampleSelector.from_examples( + examples, + OpenAIEmbeddings(), + FAISS, + k=5, + input_keys=["input"], + ) + + + prefix_prompt = """ + You are an agent designed to interact with a SQL database. + Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. + You can order the results by a relevant column to return the most interesting examples in the database. + Never query for all the columns from a specific table, only ask for the relevant columns given the question. + You have access to tools for interacting with the database. + Only use the given tools. Only use the information returned by the tools to construct your final answer. + You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. + + - Restrict your queries to the "plants" table. + - Do not return more than {top_k} rows unless specified otherwise. + - Add a limit of 25 at the end of SQL query. + - If the SQLite query returns zero rows, return a message indicating the same. + - Only refer to the data contained in the {table_info} table. Do not fabricate any data. + - For filtering based on string comparison, always use the LIKE operator and enclose the string in `%`. + - Queries on the `Additional_Details_2` column should use sub-queries involving `Family_Number`, `Genus_Number` and `Accepted_name_number`. + + Refer to the table description below for more details on the columns: + 1. **Record_Type_Code**: Contains text codes indicating the type of information in the row. + - FA: Family Name, Genus Name, Scientific Name + - TY: Type + - GE: Genus Name + - AN: Family Name (Accepted Name), Genus Name, Scientific Name, Author Name, Publication, Volume:Page, Year of Publication + - HB: Habit + - DB: Distribution/location of the plant + - RE: Remarks + - SN: Family Name (Synonym), Genus Name, Scientific Name, Author Name, Publication, Volume:Page, Year of Publication + 2. **Family_Name**: Contains the Family Name of the plant. + 3. **Genus_Name**: Contains the Genus Name of the plant. + 4. **Scientific_Name**: Contains the Scientific Name of the plant species. + 5. 
**Publication_Name**: Name of the journal or book where the plant discovery information is published. Use LIKE for queries. + 6. **Volume:_Page**: The volume and page number of the publication. + 7. **Year_of_Publication**: The year in which the plant information was published. + 8. **Author_Name**: May contain multiple authors separated by `&`. Use LIKE for queries. + 9. **Additional_Details**: Contains type, habit, distribution, and remarks. Use LIKE for queries. + - Type: General location information. + - Remarks: Location information about cultivation or native area. + - Distribution: Locations where the plant is common. May contain multiple locations, use LIKE for queries. + 10. **Groups**: Contains either "Gymnosperms" or "Angiosperms". + 11. **Urban**: Contains either "YES" or "NO". Specifies whether the plant is urban. + 12. **Fruit**: Contains either "YES" or "NO". Specifies whether the plant is a fruit plant. + 13. **Medicinal**: Contains either "YES" or "NO". Specifies whether the plant is medicinal. + 14. **Genus_Number**: Contains the Genus Number of the plant. + 15. **Accepted_name_number**: Contains the Accepted Name Number of the plant. + + Below are examples of questions and their corresponding SQL queries. + """ + - # Define the prompt template - example_prompt = PromptTemplate.from_template("User input: {input}\nSQL query: {query}") - prompt = FewShotPromptTemplate( - examples=examples[:5], - example_prompt=example_prompt, - prefix=system_prefix_asmita, + + agent_prompt = PromptTemplate.from_template("User input: {input}\nSQL Query: {query}") + agent_prompt_obj = FewShotPromptTemplate( + example_selector=example_selector, + example_prompt=agent_prompt, + prefix=prefix_prompt, suffix="", - input_variables=["input", "top_k", "table_info"], + input_variables=["input"], ) - # Define the full prompt structure full_prompt = ChatPromptTemplate.from_messages( [ - SystemMessagePromptTemplate(prompt=prompt), + SystemMessagePromptTemplate(prompt=agent_prompt_obj), ("human", "{input}"), MessagesPlaceholder("agent_scratchpad"), ] ) + return full_prompt - # Initialize the OpenAI models - llm = ChatOpenAI(model="gpt-4-turbo", temperature=0) +def initalize_sql_agent(): + + ### LLM + llm = ChatOpenAI(model="gpt-4o", temperature=0) - # Create the SQL agent - agent = create_sql_agent( - llm=llm, - db=db, - prompt=full_prompt, - verbose=True, - agent_type="openai-tools", - ) + ### DATABASE + db = SQLDatabase.from_uri("sqlite:///C:/Users/rohan/OneDrive/Desktop/NCSA_self_hostable_chatbot/self-hostable-ai-ta-backend/ai_ta_backend/service/plants_of_India_demo.db") + + dynamic_few_shot_prompt = get_dynamic_prompt_template() + + agent = create_sql_agent(llm, db=db, prompt=dynamic_few_shot_prompt, agent_type="openai-tools", verbose=True) return agent +def generate_response_agent(agent,user_question): + response = agent.invoke({"input": user_question}) + return response + +##### Setting up the Graph Nodes, Edges and message communication + +class AgentState(TypedDict): + messages: Annotated[Sequence[BaseMessage], operator.add] + + +@tool("plants_sql_tool", return_direct=True) +def generate_sql_query(input:str) -> str: + """Given a query looks for the three most relevant SQL sample queries""" + user_question = input + sql_agent = initalize_sql_agent() + response = generate_response_agent(sql_agent,user_question) + return response + +model = ChatOpenAI(model="gpt-4o", temperature=0) + +tools = [generate_sql_query] +tool_executor = ToolExecutor(tools) +functions = [format_tool_to_openai_function(t) 
for t in tools]
+model = model.bind_functions(functions)
+
+# Define the function that determines whether to continue or not
+def should_continue(state):
+  messages = state['messages']
+  last_message = messages[-1]
+  # If there is no function call, then we finish
+  if "function_call" not in last_message.additional_kwargs:
+    return "end"
+  # Otherwise if there is, we continue
+  else:
+    return "continue"
+
+# Define the function that calls the model
+def call_model(state):
+  messages = state['messages']
+  response = model.invoke(messages)
+  # We return a list, because this will get added to the existing list
+  return {"messages": [response]}
+
+# Define the function to execute tools
+def call_tool(state):
+  messages = state['messages']
+  # Based on the continue condition
+  # we know the last message involves a function call
+  last_message = messages[-1]
+  # We construct a ToolInvocation from the function_call
+  action = ToolInvocation(
+    tool=last_message.additional_kwargs["function_call"]["name"],
+    tool_input=json.loads(last_message.additional_kwargs["function_call"]["arguments"]),
+  )
+  print(f"The agent action is {action}")
+  # We call the tool_executor and get back a response
+  response = tool_executor.invoke(action)
+  print(f"The tool result is: {response}")
+  # We use the response to create a FunctionMessage
+  function_message = FunctionMessage(content=str(response), name=action.tool)
+  # We return a list, because this will get added to the existing list
+  return {"messages": [function_message]}
+
+
+from langgraph.graph import StateGraph, END
+# Define a new graph
+workflow = StateGraph(AgentState)
+
+# Define the two nodes we will cycle between
+workflow.add_node("agent", call_model)
+workflow.add_node("action", call_tool)
+
+# Set the entrypoint as `agent` where we start
+workflow.set_entry_point("agent")
+
+# We now add a conditional edge
+workflow.add_conditional_edges(
+  # First, we define the start node. We use `agent`.
+  # This means these are the edges taken after the `agent` node is called.
+  "agent",
+  # Next, we pass in the function that will determine which node is called next.
+  should_continue,
+  # Finally we pass in a mapping.
+  # The keys are strings, and the values are other nodes.
+  # END is a special node marking that the graph should finish.
+  # What will happen is we will call `should_continue`, and then the output of that
+  # will be matched against the keys in this mapping.
+  # Based on which one it matches, that node will then be called.
+  {
+    # If `tools`, then we call the tool node.
+    "continue": "action",
+    # Otherwise we finish.
+    "end": END
+  }
+)
+
+# We now add a normal edge from `tools` to `agent`.
+# This means that after `tools` is called, `agent` node is called next.
+workflow.add_edge('action', 'agent')
+
+# Finally, we compile it!
+# This compiles it into a LangChain Runnable, +# meaning you can use it as you would any other runnable +app = workflow.compile() + + def generate_response(user_input): - agent = initialize_agent() - output = agent.invoke(user_input) + #agent = initialize_agent() + output = app.invoke(user_input) return output From 79bba538f1b5a8800cd6605b48c2db58f1a91859 Mon Sep 17 00:00:00 2001 From: Kastan Day Date: Thu, 20 Jun 2024 12:07:27 -0700 Subject: [PATCH 06/11] Docker improvements, trying to debug SQLAlchamy not binding to DB properly --- Dockerfile | 19 +++++++------- README.md | 7 ++++++ ai_ta_backend/database/aws.py | 25 ++++++++++++++----- ai_ta_backend/main.py | 13 ++++++++-- .../requirements.txt | 0 ai_ta_backend/service/retrieval_service.py | 9 +++++++ docker-compose.yaml | 2 +- 7 files changed, 57 insertions(+), 18 deletions(-) rename requirements.txt => ai_ta_backend/requirements.txt (100%) diff --git a/Dockerfile b/Dockerfile index 5eda60c9..b96ac3d8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,23 +4,24 @@ FROM python:3.10-slim # Set the working directory in the container WORKDIR /usr/src/app -# Copy the local directory contents into the container -COPY . . + +# Copy the requirements file first to leverage Docker cache +COPY ai_ta_backend/requirements.txt . # Install any needed packages specified in requirements.txt -RUN pip install --no-cache-dir -r requirements.txt +RUN pip install -r requirements.txt + +# Mkdir for sqlite db +RUN mkdir -p /usr/src/app/db + +# Copy the rest of the local directory contents into the container +COPY . . # Set the Python path to include the ai_ta_backend directory ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/ai_ta_backend" -RUN echo $PYTHONPATH -RUN ls -la /usr/src/app/ - # Make port 8000 available to the world outside this container EXPOSE 8000 -# Define environment variable for Gunicorn to bind to 0.0.0.0:8000 -ENV GUNICORN_CMD_ARGS="--bind=0.0.0.0:8000" - # Run the application using Gunicorn with specified configuration CMD ["gunicorn", "--workers=1", "--threads=100", "--worker-class=gthread", "ai_ta_backend.main:app", "--timeout=1800", "--bind=0.0.0.0:8000"] diff --git a/README.md b/README.md index 149e65ee..a049e2b1 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,12 @@ Architecture diagram of Flask + Next.js & React hosted on Vercel. Automatic [API Reference](https://uiuc-chatbot.github.io/ai-ta-backend/reference/) +## Docker Deployment + +1. Build flask image `docker build -t kastanday/ai-ta-backend:gunicorn .` +2. Push flask image `docker push kastanday/ai-ta-backend:gunicorn` +3. Run docker compose `docker compose up` + ## 📣 Development 1. Rename `.env.template` to `.env` and fill in the required variables @@ -36,3 +42,4 @@ The docs are auto-built and deployed to [our docs website](https://uiuc-chatbot. 'url': doc.metadata.get('url'), # wouldn't this error out? 
'base_url': doc.metadata.get('base_url'), ``` + diff --git a/ai_ta_backend/database/aws.py b/ai_ta_backend/database/aws.py index 1e6f397d..c0fbc450 100644 --- a/ai_ta_backend/database/aws.py +++ b/ai_ta_backend/database/aws.py @@ -8,12 +8,25 @@ class AWSStorage(): @inject def __init__(self): - # S3 - self.s3_client = boto3.client( - 's3', - aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'], - aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'], - ) + if all(os.getenv(key) for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]): + print("Using AWS for storage") + self.s3_client = boto3.client( + 's3', + aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'), + aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'), + ) + elif all(os.getenv(key) for key in ["MINIO_ACCESS_KEY", "MINIO_SECRET_KEY", "MINIO_URL"]): + print("Using Minio for storage") + self.s3_client = boto3.client( + 's3', + endpoint_url=os.getenv('MINIO_URL'), + aws_access_key_id=os.getenv('MINIO_ACCESS_KEY'), + aws_secret_access_key=os.getenv('MINIO_SECRET_KEY'), + config=boto3.session.Config(signature_version='s3v4'), + region_name='us-east-1' + ) + else: + raise ValueError("No valid storage credentials found.") def upload_file(self, file_path: str, bucket_name: str, object_name: str): self.s3_client.upload_file(file_path, bucket_name, object_name) diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 5175ee70..cb290308 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -1,5 +1,6 @@ import json import logging +import sys import os import time from typing import List @@ -44,14 +45,20 @@ from ai_ta_backend.service.workflow_service import WorkflowService from ai_ta_backend.extensions import db + +# Make docker log our prints() -- Set PYTHONUNBUFFERED to ensure no output buffering +os.environ['PYTHONUNBUFFERED'] = '1' +sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 1) +sys.stderr = os.fdopen(sys.stderr.fileno(), 'w', 1) + app = Flask(__name__) CORS(app) executor = Executor(app) # app.config['EXECUTOR_MAX_WORKERS'] = 5 nothing == picks defaults for me -#app.config['SERVER_TIMEOUT'] = 1000 # seconds +# app.config['SERVER_TIMEOUT'] = 1000 # seconds # load API keys from globally-availabe .env file -load_dotenv() +load_dotenv(override=True) @app.route('/') def index() -> Response: @@ -131,6 +138,7 @@ def getTopContexts(service: RetrievalService) -> Response: def getAll(service: RetrievalService) -> Response: """Get all course materials based on the course_name """ + print("In getAll()") course_name: List[str] | str = request.args.get('course_name', default='', type=str) if course_name == '': @@ -494,6 +502,7 @@ def configure(binder: Binder) -> None: db.init_app(app) db.create_all() binder.bind(SQLAlchemyDatabase, to=db, scope=SingletonScope) + print("Bound to SQL DB!") sql_bound = True break diff --git a/requirements.txt b/ai_ta_backend/requirements.txt similarity index 100% rename from requirements.txt rename to ai_ta_backend/requirements.txt diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index beb87a89..822ccaaa 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -33,6 +33,14 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLAlchemyDatabase, aws: AWSStora self.posthog = posthog self.nomicService = nomicService + print("self.sqlDb", self.sqlDb) + print("Attributes of sqlDb:", dir(self.sqlDb)) # Print all attributes of sqlDb + + try: + assert hasattr(self.sqlDb, 'getAllMaterialsForCourse'), 
"BAD BAD -- sqlDb does not have the method getAllMaterialsForCourse" + except Exception as e: + print("Error loading getAllMaterialsForCourse: ", e) + openai.api_key = os.environ["OPENAI_API_KEY"] self.embeddings = OpenAIEmbeddings( @@ -138,6 +146,7 @@ def getAll( Returns: list of dictionaries with distinct s3 path, readable_filename and course_name, url, base_url. """ + print("Inside retrieval Service getAll()") response = self.sqlDb.getAllMaterialsForCourse(course_name) diff --git a/docker-compose.yaml b/docker-compose.yaml index 33ae37ee..9aa5d1e2 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -41,7 +41,7 @@ services: ports: - "8000:8000" volumes: - - ./db:/app/db # Mount local directory to store SQLite database + - ./db:/usr/src/app/db # Mount local directory to store SQLite database depends_on: # - sqlite - redis From 88b37c6e20cc7e7b01d1455f89c7933ae766f487 Mon Sep 17 00:00:00 2001 From: Kastan Day Date: Thu, 20 Jun 2024 12:17:07 -0700 Subject: [PATCH 07/11] Fix SQLAlchemy error, proper bind --- ai_ta_backend/main.py | 88 ++++++++++++---------- ai_ta_backend/service/retrieval_service.py | 17 +---- 2 files changed, 53 insertions(+), 52 deletions(-) diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index cb290308..4a302515 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -1,7 +1,7 @@ import json import logging -import sys import os +import sys import time from typing import List @@ -19,10 +19,10 @@ from flask_executor import Executor from flask_injector import FlaskInjector, RequestScope from injector import Binder, SingletonScope -from ai_ta_backend.database.sql import SQLAlchemyDatabase + from ai_ta_backend.database.aws import AWSStorage from ai_ta_backend.database.qdrant import VectorDatabase - +from ai_ta_backend.database.sql import SQLAlchemyDatabase from ai_ta_backend.executors.flask_executor import ( ExecutorInterface, FlaskExecutorAdapter, @@ -35,16 +35,13 @@ ThreadPoolExecutorAdapter, ThreadPoolExecutorInterface, ) +from ai_ta_backend.extensions import db from ai_ta_backend.service.export_service import ExportService from ai_ta_backend.service.nomic_service import NomicService from ai_ta_backend.service.posthog_service import PosthogService from ai_ta_backend.service.retrieval_service import RetrievalService from ai_ta_backend.service.sentry_service import SentryService - -# from ai_ta_backend.beam.nomic_logging import create_document_map from ai_ta_backend.service.workflow_service import WorkflowService -from ai_ta_backend.extensions import db - # Make docker log our prints() -- Set PYTHONUNBUFFERED to ensure no output buffering os.environ['PYTHONUNBUFFERED'] = '1' @@ -60,6 +57,7 @@ # load API keys from globally-availabe .env file load_dotenv(override=True) + @app.route('/') def index() -> Response: """_summary_ @@ -214,6 +212,7 @@ def nomic_map(service: NomicService): # response.headers.add('Access-Control-Allow-Origin', '*') # return response + @app.route('/createConversationMap', methods=['GET']) def createConversationMap(service: NomicService): course_name: str = request.args.get('course_name', default='', type=str) @@ -228,6 +227,7 @@ def createConversationMap(service: NomicService): response.headers.add('Access-Control-Allow-Origin', '*') return response + @app.route('/logToConversationMap', methods=['GET']) def logToConversationMap(service: NomicService, flaskExecutor: ExecutorInterface): course_name: str = request.args.get('course_name', default='', type=str) @@ -297,6 +297,7 @@ def export_convo_history(service: ExportService): 
return response + @app.route('/export-conversations-custom', methods=['GET']) def export_conversations_custom(service: ExportService): course_name: str = request.args.get('course_name', default='', type=str) @@ -389,6 +390,7 @@ def getTopContextsWithMQR(service: RetrievalService, posthog_service: PosthogSer response.headers.add('Access-Control-Allow-Origin', '*') return response + @app.route('/getworkflows', methods=['GET']) def get_all_workflows(service: WorkflowService) -> Response: """ @@ -404,7 +406,6 @@ def get_all_workflows(service: WorkflowService) -> Response: print("In get_all_workflows.. api_key: ", api_key) - # if no API Key, return empty set. # if api_key == '': # # proper web error "400 Bad request" @@ -488,61 +489,69 @@ def configure(binder: Binder) -> None: # Define database URLs with conditional checks for environment variables DB_URLS = { - 'supabase': f"postgresql://{os.getenv('SUPABASE_KEY')}@{os.getenv('SUPABASE_URL')}" if os.getenv('SUPABASE_KEY') and os.getenv('SUPABASE_URL') else None, - 'sqlite': f"sqlite:///{os.getenv('SQLITE_DB_NAME')}" if os.getenv('SQLITE_DB_NAME') else None, - 'postgres': f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@{os.getenv('POSTGRES_URL')}" if os.getenv('POSTGRES_USER') and os.getenv('POSTGRES_PASSWORD') and os.getenv('POSTGRES_URL') else None + 'supabase': + f"postgresql://{os.getenv('SUPABASE_KEY')}@{os.getenv('SUPABASE_URL')}" + if os.getenv('SUPABASE_KEY') and os.getenv('SUPABASE_URL') else None, + 'sqlite': + f"sqlite:///{os.getenv('SQLITE_DB_NAME')}" if os.getenv('SQLITE_DB_NAME') else None, + 'postgres': + f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@{os.getenv('POSTGRES_URL')}" + if os.getenv('POSTGRES_USER') and os.getenv('POSTGRES_PASSWORD') and os.getenv('POSTGRES_URL') else None } # Bind to the first available SQL database configuration for db_type, url in DB_URLS.items(): - if url: - logging.info(f"Binding to {db_type} database with URL: {url}") - with app.app_context(): - app.config['SQLALCHEMY_DATABASE_URI'] = url - db.init_app(app) - db.create_all() - binder.bind(SQLAlchemyDatabase, to=db, scope=SingletonScope) - print("Bound to SQL DB!") - sql_bound = True - break - + if url: + logging.info(f"Binding to {db_type} database with URL: {url}") + with app.app_context(): + app.config['SQLALCHEMY_DATABASE_URI'] = url + db.init_app(app) + db.create_all() + binder.bind(SQLAlchemyDatabase, to=SQLAlchemyDatabase(db), scope=SingletonScope) + print("Bound to SQL DB!") + sql_bound = True + break + # Conditionally bind databases based on the availability of their respective secrets - if all(os.getenv(key) for key in ["QDRANT_URL", "QDRANT_API_KEY", "QDRANT_COLLECTION_NAME"]) or any(os.getenv(key) for key in ["PINECONE_API_KEY", "PINECONE_PROJECT_NAME"]): + if all(os.getenv(key) for key in ["QDRANT_URL", "QDRANT_API_KEY", "QDRANT_COLLECTION_NAME"]) or any( + os.getenv(key) for key in ["PINECONE_API_KEY", "PINECONE_PROJECT_NAME"]): logging.info("Binding to Qdrant database") binder.bind(VectorDatabase, to=VectorDatabase, scope=SingletonScope) vector_bound = True - if all(os.getenv(key) for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "S3_BUCKET_NAME"]) or any(os.getenv(key) for key in ["MINIO_ACCESS_KEY", "MINIO_SECRET_KEY", "MINIO_URL"]): - if os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY"): logging.info("Binding to AWS storage") - elif os.getenv("MINIO_ACCESS_KEY") and os.getenv("MINIO_SECRET_KEY"): logging.info("Binding to Minio storage") + if 
all(os.getenv(key) for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "S3_BUCKET_NAME"]) or any( + os.getenv(key) for key in ["MINIO_ACCESS_KEY", "MINIO_SECRET_KEY", "MINIO_URL"]): + if os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY"): + logging.info("Binding to AWS storage") + elif os.getenv("MINIO_ACCESS_KEY") and os.getenv("MINIO_SECRET_KEY"): + logging.info("Binding to Minio storage") binder.bind(AWSStorage, to=AWSStorage, scope=SingletonScope) storage_bound = True - # Conditionally bind services based on the availability of their respective secrets if os.getenv("NOMIC_API_KEY"): - logging.info("Binding to Nomic service") - binder.bind(NomicService, to=NomicService, scope=SingletonScope) + logging.info("Binding to Nomic service") + binder.bind(NomicService, to=NomicService, scope=SingletonScope) if os.getenv("POSTHOG_API_KEY"): - logging.info("Binding to Posthog service") - binder.bind(PosthogService, to=PosthogService, scope=SingletonScope) + logging.info("Binding to Posthog service") + binder.bind(PosthogService, to=PosthogService, scope=SingletonScope) if os.getenv("SENTRY_DSN"): - logging.info("Binding to Sentry service") - binder.bind(SentryService, to=SentryService, scope=SingletonScope) + logging.info("Binding to Sentry service") + binder.bind(SentryService, to=SentryService, scope=SingletonScope) if os.getenv("EMAIL_SENDER"): - logging.info("Binding to Export service") - binder.bind(ExportService, to=ExportService, scope=SingletonScope) + logging.info("Binding to Export service") + binder.bind(ExportService, to=ExportService, scope=SingletonScope) if os.getenv("N8N_URL"): - logging.info("Binding to Workflow service") - binder.bind(WorkflowService, to=WorkflowService, scope=SingletonScope) + logging.info("Binding to Workflow service") + binder.bind(WorkflowService, to=WorkflowService, scope=SingletonScope) if vector_bound and sql_bound and storage_bound: - logging.info("Binding to Retrieval service") - binder.bind(RetrievalService, to=RetrievalService, scope=RequestScope) + logging.info("Binding to Retrieval service") + binder.bind(RetrievalService, to=RetrievalService, scope=RequestScope) # Always bind the executor and its adapters binder.bind(ExecutorInterface, to=FlaskExecutorAdapter(executor), scope=SingletonScope) @@ -550,6 +559,7 @@ def configure(binder: Binder) -> None: binder.bind(ProcessPoolExecutorInterface, to=ProcessPoolExecutorAdapter, scope=SingletonScope) logging.info("Configured all services and adapters", binder._bindings) + FlaskInjector(app=app, modules=[configure]) if __name__ == '__main__': diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index 822ccaaa..3ff0f5a5 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -9,9 +9,10 @@ from langchain.chat_models import AzureChatOpenAI from langchain.embeddings.openai import OpenAIEmbeddings from langchain.schema import Document -from ai_ta_backend.database.sql import SQLAlchemyDatabase + from ai_ta_backend.database.aws import AWSStorage from ai_ta_backend.database.qdrant import VectorDatabase +from ai_ta_backend.database.sql import SQLAlchemyDatabase from ai_ta_backend.service.nomic_service import NomicService from ai_ta_backend.service.posthog_service import PosthogService from ai_ta_backend.service.sentry_service import SentryService @@ -33,14 +34,6 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLAlchemyDatabase, aws: AWSStora self.posthog = posthog self.nomicService = 
nomicService - print("self.sqlDb", self.sqlDb) - print("Attributes of sqlDb:", dir(self.sqlDb)) # Print all attributes of sqlDb - - try: - assert hasattr(self.sqlDb, 'getAllMaterialsForCourse'), "BAD BAD -- sqlDb does not have the method getAllMaterialsForCourse" - except Exception as e: - print("Error loading getAllMaterialsForCourse: ", e) - openai.api_key = os.environ["OPENAI_API_KEY"] self.embeddings = OpenAIEmbeddings( @@ -146,8 +139,6 @@ def getAll( Returns: list of dictionaries with distinct s3 path, readable_filename and course_name, url, base_url. """ - print("Inside retrieval Service getAll()") - response = self.sqlDb.getAllMaterialsForCourse(course_name) data = response.data @@ -347,7 +338,7 @@ def delete_from_nomic_and_supabase(self, course_name: str, identifier_key: str, # delete from Nomic response = self.sqlDb.getProjectsMapForCourse(course_name) - data, count = response.data, response.count + data, _count = response.data, response.count if not data: raise Exception(f"No document map found for this course: {course_name}") project_id = str(data[0].doc_map_id) @@ -441,7 +432,7 @@ def _capture_search_succeeded_event(self, search_query, course_name, search_resu "max_vector_score": max_vector_score, "min_vector_score": min_vector_score, "avg_vector_score": avg_vector_score, - "vector_score_calculation_latency_sec": time.monotonic() - vector_score_calc_latency_sec, + "vector_score_calculation_latency_sec": time.monotonic() - vector_score_calc_latency_sec, }, ) From cd297d3d784b12612585b21c40d68bd9490c9005 Mon Sep 17 00:00:00 2001 From: Kastan Day Date: Thu, 20 Jun 2024 12:20:59 -0700 Subject: [PATCH 08/11] FULL TRUNK FORMAT --- .trunk/configs/.isort.cfg | 2 +- .trunk/configs/.style.yapf | 2 +- .trunk/trunk.yaml | 1 - ai_ta_backend/beam/OpenaiEmbeddings.py | 15 +- ai_ta_backend/beam/ingest.py | 7 +- ai_ta_backend/beam/nomic_logging.py | 47 +- ai_ta_backend/database/aws.py | 37 +- ai_ta_backend/database/qdrant.py | 18 +- ai_ta_backend/database/sql.py | 447 +++++++++--------- ai_ta_backend/database/supabase.py | 89 ++-- .../executors/process_pool_executor.py | 3 +- ai_ta_backend/extensions.py | 3 +- ai_ta_backend/main.py | 65 ++- ai_ta_backend/modal/pest_detection.py | 20 +- ai_ta_backend/model/models.py | 148 +++--- ai_ta_backend/model/response.py | 11 +- ai_ta_backend/service/export_service.py | 32 +- ai_ta_backend/service/nomic_service.py | 47 +- ai_ta_backend/service/retrieval_service.py | 11 +- ai_ta_backend/service/sentry_service.py | 2 +- ai_ta_backend/service/workflow_service.py | 16 +- .../utils/context_parent_doc_padding.py | 11 +- ai_ta_backend/utils/emails.py | 4 +- ai_ta_backend/utils/utils_tokenization.py | 11 +- docker-compose.yaml | 6 +- railway.json | 7 +- 26 files changed, 507 insertions(+), 555 deletions(-) diff --git a/.trunk/configs/.isort.cfg b/.trunk/configs/.isort.cfg index b9fb3f3e..5225d7a2 100644 --- a/.trunk/configs/.isort.cfg +++ b/.trunk/configs/.isort.cfg @@ -1,2 +1,2 @@ [settings] -profile=black +profile=google diff --git a/.trunk/configs/.style.yapf b/.trunk/configs/.style.yapf index 3d4d13b2..3e0faa55 100644 --- a/.trunk/configs/.style.yapf +++ b/.trunk/configs/.style.yapf @@ -1,4 +1,4 @@ [style] based_on_style = google -column_limit = 120 +column_limit = 140 indent_width = 2 diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml index 4186a1e2..292c526c 100644 --- a/.trunk/trunk.yaml +++ b/.trunk/trunk.yaml @@ -43,7 +43,6 @@ lint: paths: - .github/**/* - .trunk/**/* - - mkdocs.yml - .DS_Store - .vscode/**/* - README.md diff --git 
a/ai_ta_backend/beam/OpenaiEmbeddings.py b/ai_ta_backend/beam/OpenaiEmbeddings.py index 6c8239f2..b3a088c2 100644 --- a/ai_ta_backend/beam/OpenaiEmbeddings.py +++ b/ai_ta_backend/beam/OpenaiEmbeddings.py @@ -114,7 +114,6 @@ # # from langchain.vectorstores import Qdrant # # from qdrant_client import QdrantClient, models - # class OpenAIAPIProcessor: # def __init__(self, input_prompts_list, request_url, api_key, max_requests_per_minute, max_tokens_per_minute, @@ -263,7 +262,6 @@ # self.cleaned_results: List[str] = extract_context_from_results(self.results) - # def extract_context_from_results(results: List[Any]) -> List[str]: # assistant_contents = [] # total_prompt_tokens = 0 @@ -282,10 +280,8 @@ # return assistant_contents - # # dataclasses - # @dataclass # class StatusTracker: # """Stores metadata about the script's progress. Only one instance is created.""" @@ -299,7 +295,6 @@ # num_other_errors: int = 0 # time_of_last_rate_limit_error: float = 0 # used to cool off after hitting rate limits - # @dataclass # class APIRequest: # """Stores an API request's inputs, outputs, and other metadata. Contains a method to make an API call.""" @@ -360,10 +355,8 @@ # return data - # # functions - # def api_endpoint_from_url(request_url: str): # """Extract the API endpoint from the request URL.""" # if 'text-embedding-ada-002' in request_url: @@ -372,14 +365,12 @@ # match = re.search('^https://[^/]+/v\\d+/(.+)$', request_url) # return match[1] # type: ignore - # def append_to_jsonl(data, filename: str) -> None: # """Append a json payload to the end of a jsonl file.""" # json_string = json.dumps(data) # with open(filename, "a") as f: # f.write(json_string + "\n") - # def num_tokens_consumed_from_request( # request_json: dict, # api_endpoint: str, @@ -432,7 +423,6 @@ # else: # raise NotImplementedError(f'API endpoint "{api_endpoint}" not implemented in this script') - # def task_id_generator_function(): # """Generate integers 0, 1, 2, and so on.""" # task_id = 0 @@ -440,7 +430,6 @@ # yield task_id # task_id += 1 - # if __name__ == '__main__': # pass @@ -520,10 +509,10 @@ # # print("GPT-4 cost premium: ", (gpt4_total_cost / turbo_total_cost), "x") # ''' # Pricing: -# GPT4: +# GPT4: # * $0.03 prompt # * $0.06 completions -# 3.5-turbo: +# 3.5-turbo: # * $0.0015 prompt # * $0.002 completions # ''' diff --git a/ai_ta_backend/beam/ingest.py b/ai_ta_backend/beam/ingest.py index b31bab87..aafe814d 100644 --- a/ai_ta_backend/beam/ingest.py +++ b/ai_ta_backend/beam/ingest.py @@ -94,7 +94,6 @@ # # MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation") # OPENAI_API_TYPE = "azure" # "openai" or "azure" - # def loader(): # """ # The loader function will run once for each worker that starts up. 
https://docs.beam.cloud/deployment/loaders @@ -141,11 +140,9 @@ # return qdrant_client, vectorstore, s3_client, supabase_client, posthog - # # autoscaler = RequestLatencyAutoscaler(desired_latency=30, max_replicas=2) # autoscaler = QueueDepthAutoscaler(max_tasks_per_replica=300, max_replicas=3) - # # Triggers determine how your app is deployed # # @app.rest_api( # @app.task_queue( @@ -231,7 +228,6 @@ # print(f"Final success_fail_dict: {success_fail_dict}") # return json.dumps(success_fail_dict) - # class Ingest(): # def __init__(self, qdrant_client, vectorstore, s3_client, supabase_client, posthog): @@ -843,7 +839,7 @@ # text = pytesseract.image_to_string(im.original) # print("Page number: ", i, "Text: ", text[:100]) # pdf_pages_OCRed.append(dict(text=text, page_number=i, readable_filename=Path(s3_path).name[37:])) - + # metadatas: List[Dict[str, Any]] = [ # { # 'course_name': course_name, @@ -1364,7 +1360,6 @@ # # return all_files - # if __name__ == "__main__": # raise NotImplementedError("This file is not meant to be run directly") # text = "Testing 123" diff --git a/ai_ta_backend/beam/nomic_logging.py b/ai_ta_backend/beam/nomic_logging.py index d15c616e..2b008e99 100644 --- a/ai_ta_backend/beam/nomic_logging.py +++ b/ai_ta_backend/beam/nomic_logging.py @@ -35,7 +35,7 @@ # """ # print("in create_document_map()") # nomic.login(os.getenv('NOMIC_API_KEY')) - + # try: # # check if map exists # response = SUPABASE_CLIENT.table("projects").select("doc_map_id").eq("course_name", course_name).execute() @@ -50,7 +50,7 @@ # desc=False).execute() # if not response.count: # return "No documents found for this course." - + # total_doc_count = response.count # print("Total number of documents in Supabase: ", total_doc_count) @@ -59,7 +59,7 @@ # return "Cannot create a map because there are less than 20 documents in the course." 
# first_id = response.data[0]['id'] - + # combined_dfs = [] # curr_total_doc_count = 0 # doc_count = 0 @@ -67,7 +67,7 @@ # # iteratively query in batches of 25 # while curr_total_doc_count < total_doc_count: - + # response = SUPABASE_CLIENT.table("documents").select( # "id, created_at, s3_path, url, base_url, readable_filename, contexts").eq("course_name", course_name).gte( # 'id', first_id).order('id', desc=False).limit(25).execute() @@ -93,7 +93,7 @@ # topic_label_field = "text" # colorable_fields = ["readable_filename", "text", "base_url", "created_at"] # result = create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) - + # if result == "success": # # update flag # first_batch = False @@ -109,7 +109,6 @@ # else: # insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() # print("Insert Response from supabase: ", insert_response) - # else: # # append to existing map @@ -123,7 +122,7 @@ # info = {'last_uploaded_doc_id': last_id} # update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() # print("Response from supabase: ", update_response) - + # # reset variables # combined_dfs = [] # doc_count = 0 @@ -160,11 +159,10 @@ # else: # insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() # print("Insert Response from supabase: ", insert_response) - - + # # rebuild the map # rebuild_map(course_name, "document") - + # except Exception as e: # print(e) # sentry_sdk.capture_exception(e) @@ -201,10 +199,9 @@ # sentry_sdk.capture_exception(e) # return "Error in deleting from document map: {e}" - # def log_to_document_map(course_name: str): # """ -# This is a function which appends new documents to an existing document map. It's called +# This is a function which appends new documents to an existing document map. It's called # at the end of split_and_upload() after inserting data to Supabase. # Args: # data: dict - the response data from Supabase insertion @@ -227,14 +224,14 @@ # # create a map # create_document_map(course_name) # return "Document map not present, triggering map creation." - + # project = AtlasProject(project_id=project_id, add_datums_if_exists=True) # project_name = "Document Map for " + course_name - + # # check if project is LOCKED, if yes -> skip logging # if not project.is_accepting_data: # return "Skipping Nomic logging because project is locked." 
- + # # fetch count of records greater than last_uploaded_doc_id # print("last uploaded doc id: ", last_uploaded_doc_id) # response = SUPABASE_CLIENT.table("documents").select("id", count="exact").eq("course_name", course_name).gt("id", last_uploaded_doc_id).execute() @@ -262,7 +259,7 @@ # # append to existing map # print("Appending data to existing map...") - + # result = append_to_map(embeddings, metadata, project_name) # if result == "success": # # update the last uploaded id in supabase @@ -270,15 +267,15 @@ # info = {'last_uploaded_doc_id': last_id} # update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() # print("Response from supabase: ", update_response) - + # # reset variables # combined_dfs = [] # doc_count = 0 # print("Records uploaded: ", current_doc_count) - + # # set first_id for next iteration # first_id = response.data[-1]['id'] + 1 - + # # upload last set of docs # if doc_count > 0: # final_df = pd.concat(combined_dfs, ignore_index=True) @@ -292,13 +289,12 @@ # project_info = {'last_uploaded_doc_id': last_id} # update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() # print("Response from supabase: ", update_response) - + # return "success" # except Exception as e: # print(e) -# return "failed" - - +# return "failed" + # def create_map(embeddings, metadata, map_name, index_name, topic_label_field, colorable_fields): # """ # Generic function to create a Nomic map from given parameters. @@ -380,7 +376,7 @@ # "created_at": created_at, # "s3_path": row['s3_path'], # "url": row['url'], -# "base_url": row['base_url'], +# "base_url": row['base_url'], # "readable_filename": row['readable_filename'], # "modified_at": current_time, # "text": text_row @@ -431,8 +427,5 @@ # sentry_sdk.capture_exception(e) # return "Error in rebuilding map: {e}" - - # if __name__ == '__main__': # pass - diff --git a/ai_ta_backend/database/aws.py b/ai_ta_backend/database/aws.py index c0fbc450..58fb1bb3 100644 --- a/ai_ta_backend/database/aws.py +++ b/ai_ta_backend/database/aws.py @@ -9,24 +9,22 @@ class AWSStorage(): @inject def __init__(self): if all(os.getenv(key) for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]): - print("Using AWS for storage") - self.s3_client = boto3.client( - 's3', - aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'), - aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'), - ) + print("Using AWS for storage") + self.s3_client = boto3.client( + 's3', + aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'), + aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'), + ) elif all(os.getenv(key) for key in ["MINIO_ACCESS_KEY", "MINIO_SECRET_KEY", "MINIO_URL"]): - print("Using Minio for storage") - self.s3_client = boto3.client( - 's3', - endpoint_url=os.getenv('MINIO_URL'), - aws_access_key_id=os.getenv('MINIO_ACCESS_KEY'), - aws_secret_access_key=os.getenv('MINIO_SECRET_KEY'), - config=boto3.session.Config(signature_version='s3v4'), - region_name='us-east-1' - ) + print("Using Minio for storage") + self.s3_client = boto3.client('s3', + endpoint_url=os.getenv('MINIO_URL'), + aws_access_key_id=os.getenv('MINIO_ACCESS_KEY'), + aws_secret_access_key=os.getenv('MINIO_SECRET_KEY'), + config=boto3.session.Config(signature_version='s3v4'), + region_name='us-east-1') else: - raise ValueError("No valid storage credentials found.") + raise ValueError("No valid storage credentials found.") def upload_file(self, file_path: str, bucket_name: str, object_name: str): 
self.s3_client.upload_file(file_path, bucket_name, object_name) @@ -39,9 +37,4 @@ def delete_file(self, bucket_name: str, s3_path: str): def generatePresignedUrl(self, object: str, bucket_name: str, s3_path: str, expiration: int = 3600): # generate presigned URL - return self.s3_client.generate_presigned_url('get_object', - Params={ - 'Bucket': bucket_name, - 'Key': s3_path - }, - ExpiresIn=expiration) + return self.s3_client.generate_presigned_url('get_object', Params={'Bucket': bucket_name, 'Key': s3_path}, ExpiresIn=expiration) diff --git a/ai_ta_backend/database/qdrant.py b/ai_ta_backend/database/qdrant.py index 725cb8fc..f2826e68 100644 --- a/ai_ta_backend/database/qdrant.py +++ b/ai_ta_backend/database/qdrant.py @@ -4,8 +4,8 @@ from injector import inject from langchain.embeddings.openai import OpenAIEmbeddings from langchain.vectorstores import Qdrant -from qdrant_client import QdrantClient, models - +from qdrant_client import models +from qdrant_client import QdrantClient OPENAI_API_TYPE = "azure" # "openai" or "azure" @@ -38,11 +38,11 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q Search the vector database for a given query. """ must_conditions = self._create_search_conditions(course_name, doc_groups) - + # Filter for the must_conditions myfilter = models.Filter(must=must_conditions) print(f"Filter: {myfilter}") - + # Search the vector database search_results = self.qdrant_client.search( collection_name=os.environ['QDRANT_COLLECTION_NAME'], @@ -58,20 +58,18 @@ def _create_search_conditions(self, course_name, doc_groups: List[str]): """ Create search conditions for the vector search. """ - must_conditions: list[models.Condition] = [ - models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name)) - ] - + must_conditions: list[models.Condition] = [models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name))] + if doc_groups and 'All Documents' not in doc_groups: # Final combined condition combined_condition = None # Condition for matching any of the specified doc_groups match_any_condition = models.FieldCondition(key='doc_groups', match=models.MatchAny(any=doc_groups)) combined_condition = models.Filter(should=[match_any_condition]) - + # Add the combined condition to the must_conditions list must_conditions.append(combined_condition) - + return must_conditions def delete_data(self, collection_name: str, key: str, value: str): diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/sql.py index f1474a9f..d58dc618 100644 --- a/ai_ta_backend/database/sql.py +++ b/ai_ta_backend/database/sql.py @@ -1,230 +1,231 @@ +import logging from typing import List -from injector import inject + from flask_sqlalchemy import SQLAlchemy -import ai_ta_backend.model.models as models -import logging +from injector import inject +import ai_ta_backend.model.models as models from ai_ta_backend.model.response import DatabaseResponse + class SQLAlchemyDatabase: - @inject - def __init__(self, db: SQLAlchemy): - logging.info("Initializing SQLAlchemyDatabase") - self.db = db - # Ensure an app context is pushed (Flask-Injector will handle this) - # with current_app.app_context(): - # db.init_app(current_app) - # db.create_all() # Create tables - - def getAllMaterialsForCourse(self, course_name: str): - try: - query = self.db.select(models.Document).where(models.Document.course_name == course_name) - result = self.db.session.execute(query).scalars().all() - documents: List[models.Document] = [doc for doc in result] - 
return DatabaseResponse[models.Document](data=documents, count=len(result)) - finally: - self.db.session.close() - - def getMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): - try: - query = self.db.select(models.Document).where(models.Document.course_name == course_name, models.Document.s3_path == s3_path) - result = self.db.session.execute(query).scalars().all() - documents: List[models.Document] = [doc for doc in result] - return DatabaseResponse[models.Document](data=documents, count=len(result)) - finally: - self.db.session.close() - - def getMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): - try: - query = self.db.select(models.Document).where(models.Document.course_name == course_name, getattr(models.Document, key) == value) - result = self.db.session.execute(query).scalars().all() - documents: List[models.Document] = [doc for doc in result] - return DatabaseResponse[models.Document](data=documents, count=len(result)) - finally: - self.db.session.close() - - def deleteMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): - try: - query = self.db.delete(models.Document).where(models.Document.course_name == course_name, getattr(models.Document, key) == value) - self.db.session.execute(query) - self.db.session.commit() - finally: - self.db.session.close() - - def deleteMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): - try: - query = self.db.delete(models.Document).where(models.Document.course_name == course_name, models.Document.s3_path == s3_path) - self.db.session.execute(query) - self.db.session.commit() - finally: - self.db.session.close() - - def getProjectsMapForCourse(self, course_name: str): - try: - query = self.db.select(models.Project.doc_map_id).where(models.Project.course_name == course_name) - result = self.db.session.execute(query).scalars().all() - projects: List[models.Project] = [doc for doc in result] - return DatabaseResponse[models.Project](data=projects, count=len(result)) - finally: - self.db.session.close() - - def getDocumentsBetweenDates(self, course_name: str, from_date: str, to_date: str): - try: - query = self.db.select(models.Document).where(models.Document.course_name == course_name) - if from_date: - query = query.filter(models.Document.created_at >= from_date) - if to_date: - query = query.filter(models.Document.created_at <= to_date) - result = self.db.session.execute(query).scalars().all() - documents: List[models.Document] = [doc for doc in result] - return DatabaseResponse[models.Document](data=documents, count=len(result)) - finally: - self.db.session.close() - - def getConversationsBetweenDates(self, course_name: str, from_date: str, to_date: str): - try: - query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name) - if from_date: - query = query.filter(models.LlmConvoMonitor.created_at >= from_date) - if to_date: - query = query.filter(models.LlmConvoMonitor.created_at <= to_date) - result = self.db.session.execute(query).scalars().all() - documents: List[models.LlmConvoMonitor] = [doc for doc in result] - return DatabaseResponse[models.LlmConvoMonitor](data=documents, count=len(result)) - finally: - self.db.session.close() - - def getAllDocumentsForDownload(self, course_name: str, first_id: int): - try: - query = self.db.select(models.Document).where(models.Document.course_name == course_name, models.Document.id >= first_id) - result = self.db.session.execute(query).scalars().all() - documents: 
List[models.Document] = [doc for doc in result] - return DatabaseResponse[models.Document](data=documents, count=len(result)) - finally: - self.db.session.close() - - def getAllConversationsForDownload(self, course_name: str, first_id: int): - try: - query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name, models.LlmConvoMonitor.id >= first_id) - result = self.db.session.execute(query).scalars().all() - conversations: List[models.LlmConvoMonitor] = [doc for doc in result] - return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result)) - finally: - self.db.session.close() - - def getAllConversationsBetweenIds(self, course_name: str, first_id: int, last_id: int, limit: int = 50): - try: - query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name, models.LlmConvoMonitor.id > first_id) - if last_id != 0: - query = query.filter(models.LlmConvoMonitor.id <= last_id) - query = query.limit(limit) - result = self.db.session.execute(query).scalars().all() - conversations: List[models.LlmConvoMonitor] = [doc for doc in result] - return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result)) - finally: - self.db.session.close() - - def getDocsForIdsGte(self, course_name: str, first_id: int, fields: str = "*", limit: int = 100): - try: - fields_to_select = [getattr(models.Document, field) for field in fields.split(", ")] - query = self.db.select(*fields_to_select).where(models.Document.course_name == course_name, models.Document.id >= first_id).limit(limit) - result = self.db.session.execute(query).scalars().all() - documents: List[models.Document] = [doc for doc in result] - return DatabaseResponse[models.Document](data=documents, count=len(result)) - finally: - self.db.session.close() - - def insertProjectInfo(self, project_info): - try: - self.db.session.execute(self.db.insert(models.Project).values(**project_info)) - self.db.session.commit() - finally: - self.db.session.close() - - def getAllFromLLMConvoMonitor(self, course_name: str): - try: - query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name) - result = self.db.session.execute(query).scalars().all() - conversations: List[models.LlmConvoMonitor] = [doc for doc in result] - return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result)) - finally: - self.db.session.close() - - def getCountFromLLMConvoMonitor(self, course_name: str, last_id: int): - try: - query = self.db.select(models.LlmConvoMonitor.id).where(models.LlmConvoMonitor.course_name == course_name) - if last_id != 0: - query = query.filter(models.LlmConvoMonitor.id > last_id) - count_query = self.db.select(self.db.func.count()).select_from(query.subquery()) - count = self.db.session.execute(count_query).scalar() - return DatabaseResponse[models.LlmConvoMonitor](data=[], count=1) - finally: - self.db.session.close() - - def getDocMapFromProjects(self, course_name: str): - try: - query = self.db.select(models.Project.doc_map_id).where(models.Project.course_name == course_name) - result = self.db.session.execute(query).scalars().all() - documents: List[models.Project] = [doc for doc in result] - return DatabaseResponse[models.Project](data=documents, count=len(result)) - finally: - self.db.session.close() - - def getConvoMapFromProjects(self, course_name: str): - try: - query = self.db.select(models.Project).where(models.Project.course_name == course_name) - result = 
self.db.session.execute(query).scalars().all() - conversations: List[models.Project] = [doc for doc in result] - return DatabaseResponse[models.Project](data=conversations, count=len(result)) - finally: - self.db.session.close() - - def updateProjects(self, course_name: str, data: dict): - try: - query = self.db.update(models.Project).where(models.Project.course_name == course_name).values(**data) - self.db.session.execute(query) - self.db.session.commit() - finally: - self.db.session.close() - - def getLatestWorkflowId(self): - try: - query = self.db.select(models.N8nWorkflows.latest_workflow_id) - result = self.db.session.execute(query).fetchone() - return result - finally: - self.db.session.close() - - def lockWorkflow(self, id: str): - try: - new_workflow = models.N8nWorkflows(is_locked=True) - self.db.session.add(new_workflow) - self.db.session.commit() - finally: - self.db.session.close() - - def deleteLatestWorkflowId(self, id: str): - try: - query = self.db.delete(models.N8nWorkflows).where(models.N8nWorkflows.latest_workflow_id == id) - self.db.session.execute(query) - self.db.session.commit() - finally: - self.db.session.close() - - def unlockWorkflow(self, id: str): - try: - query = self.db.update(models.N8nWorkflows).where(models.N8nWorkflows.latest_workflow_id == id).values(is_locked=False) - self.db.session.execute(query) - self.db.session.commit() - finally: - self.db.session.close() - - def getConversation(self, course_name: str, key: str, value: str): - try: - query = self.db.select(models.LlmConvoMonitor).where(getattr(models.LlmConvoMonitor, key) == value) - result = self.db.session.execute(query).scalars().all() - conversations: List[models.LlmConvoMonitor] = [doc for doc in result] - return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result)) - finally: - self.db.session.close() \ No newline at end of file + @inject + def __init__(self, db: SQLAlchemy): + logging.info("Initializing SQLAlchemyDatabase") + self.db = db + + def getAllMaterialsForCourse(self, course_name: str): + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name, models.Document.s3_path == s3_path) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name, getattr(models.Document, key) == value) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def deleteMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): + try: + query = self.db.delete(models.Document).where(models.Document.course_name == course_name, getattr(models.Document, key) == value) + 
self.db.session.execute(query) + self.db.session.commit() + finally: + self.db.session.close() + + def deleteMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): + try: + query = self.db.delete(models.Document).where(models.Document.course_name == course_name, models.Document.s3_path == s3_path) + self.db.session.execute(query) + self.db.session.commit() + finally: + self.db.session.close() + + def getProjectsMapForCourse(self, course_name: str): + try: + query = self.db.select(models.Project.doc_map_id).where(models.Project.course_name == course_name) + result = self.db.session.execute(query).scalars().all() + projects: List[models.Project] = [doc for doc in result] + return DatabaseResponse[models.Project](data=projects, count=len(result)) + finally: + self.db.session.close() + + def getDocumentsBetweenDates(self, course_name: str, from_date: str, to_date: str): + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name) + if from_date: + query = query.filter(models.Document.created_at >= from_date) + if to_date: + query = query.filter(models.Document.created_at <= to_date) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getConversationsBetweenDates(self, course_name: str, from_date: str, to_date: str): + try: + query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name) + if from_date: + query = query.filter(models.LlmConvoMonitor.created_at >= from_date) + if to_date: + query = query.filter(models.LlmConvoMonitor.created_at <= to_date) + result = self.db.session.execute(query).scalars().all() + documents: List[models.LlmConvoMonitor] = [doc for doc in result] + return DatabaseResponse[models.LlmConvoMonitor](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getAllDocumentsForDownload(self, course_name: str, first_id: int): + try: + query = self.db.select(models.Document).where(models.Document.course_name == course_name, models.Document.id >= first_id) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getAllConversationsForDownload(self, course_name: str, first_id: int): + try: + query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name, models.LlmConvoMonitor.id + >= first_id) + result = self.db.session.execute(query).scalars().all() + conversations: List[models.LlmConvoMonitor] = [doc for doc in result] + return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result)) + finally: + self.db.session.close() + + def getAllConversationsBetweenIds(self, course_name: str, first_id: int, last_id: int, limit: int = 50): + try: + query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name, models.LlmConvoMonitor.id + > first_id) + if last_id != 0: + query = query.filter(models.LlmConvoMonitor.id <= last_id) + query = query.limit(limit) + result = self.db.session.execute(query).scalars().all() + conversations: List[models.LlmConvoMonitor] = [doc for doc in result] + return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result)) + finally: + 
self.db.session.close() + + def getDocsForIdsGte(self, course_name: str, first_id: int, fields: str = "*", limit: int = 100): + try: + fields_to_select = [getattr(models.Document, field) for field in fields.split(", ")] + query = self.db.select(*fields_to_select).where(models.Document.course_name == course_name, models.Document.id + >= first_id).limit(limit) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Document] = [doc for doc in result] + return DatabaseResponse[models.Document](data=documents, count=len(result)) + finally: + self.db.session.close() + + def insertProjectInfo(self, project_info): + try: + self.db.session.execute(self.db.insert(models.Project).values(**project_info)) + self.db.session.commit() + finally: + self.db.session.close() + + def getAllFromLLMConvoMonitor(self, course_name: str): + try: + query = self.db.select(models.LlmConvoMonitor).where(models.LlmConvoMonitor.course_name == course_name) + result = self.db.session.execute(query).scalars().all() + conversations: List[models.LlmConvoMonitor] = [doc for doc in result] + return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result)) + finally: + self.db.session.close() + + def getCountFromLLMConvoMonitor(self, course_name: str, last_id: int): + try: + query = self.db.select(models.LlmConvoMonitor.id).where(models.LlmConvoMonitor.course_name == course_name) + if last_id != 0: + query = query.filter(models.LlmConvoMonitor.id > last_id) + count_query = self.db.select(self.db.func.count()).select_from(query.subquery()) + self.db.session.execute(count_query).scalar() + return DatabaseResponse[models.LlmConvoMonitor](data=[], count=1) + finally: + self.db.session.close() + + def getDocMapFromProjects(self, course_name: str): + try: + query = self.db.select(models.Project.doc_map_id).where(models.Project.course_name == course_name) + result = self.db.session.execute(query).scalars().all() + documents: List[models.Project] = [doc for doc in result] + return DatabaseResponse[models.Project](data=documents, count=len(result)) + finally: + self.db.session.close() + + def getConvoMapFromProjects(self, course_name: str): + try: + query = self.db.select(models.Project).where(models.Project.course_name == course_name) + result = self.db.session.execute(query).scalars().all() + conversations: List[models.Project] = [doc for doc in result] + return DatabaseResponse[models.Project](data=conversations, count=len(result)) + finally: + self.db.session.close() + + def updateProjects(self, course_name: str, data: dict): + try: + query = self.db.update(models.Project).where(models.Project.course_name == course_name).values(**data) + self.db.session.execute(query) + self.db.session.commit() + finally: + self.db.session.close() + + def getLatestWorkflowId(self): + try: + query = self.db.select(models.N8nWorkflows.latest_workflow_id) + result = self.db.session.execute(query).fetchone() + return result + finally: + self.db.session.close() + + def lockWorkflow(self, id: str): + try: + new_workflow = models.N8nWorkflows(is_locked=True) + self.db.session.add(new_workflow) + self.db.session.commit() + finally: + self.db.session.close() + + def deleteLatestWorkflowId(self, id: str): + try: + query = self.db.delete(models.N8nWorkflows).where(models.N8nWorkflows.latest_workflow_id == id) + self.db.session.execute(query) + self.db.session.commit() + finally: + self.db.session.close() + + def unlockWorkflow(self, id: str): + try: + query = 
self.db.update(models.N8nWorkflows).where(models.N8nWorkflows.latest_workflow_id == id).values(is_locked=False) + self.db.session.execute(query) + self.db.session.commit() + finally: + self.db.session.close() + + def getConversation(self, course_name: str, key: str, value: str): + try: + query = self.db.select(models.LlmConvoMonitor).where(getattr(models.LlmConvoMonitor, key) == value) + result = self.db.session.execute(query).scalars().all() + conversations: List[models.LlmConvoMonitor] = [doc for doc in result] + return DatabaseResponse[models.LlmConvoMonitor](data=conversations, count=len(result)) + finally: + self.db.session.close() diff --git a/ai_ta_backend/database/supabase.py b/ai_ta_backend/database/supabase.py index b0bbbdb8..38f7252d 100644 --- a/ai_ta_backend/database/supabase.py +++ b/ai_ta_backend/database/supabase.py @@ -1,7 +1,8 @@ import os -import supabase from injector import inject +import supabase + class SQLDatabase(): @@ -17,20 +18,20 @@ def getAllMaterialsForCourse(self, course_name: str): 'course_name', course_name).execute() def getMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): - return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq( - 's3_path', s3_path).eq('course_name', course_name).execute() + return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq('s3_path', s3_path).eq( + 'course_name', course_name).execute() def getMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): - return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq( - key, value).eq('course_name', course_name).execute() + return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).select("id, s3_path, contexts").eq(key, value).eq( + 'course_name', course_name).execute() def deleteMaterialsForCourseAndKeyAndValue(self, course_name: str, key: str, value: str): - return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq(key, value).eq( - 'course_name', course_name).execute() + return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq(key, value).eq('course_name', + course_name).execute() def deleteMaterialsForCourseAndS3Path(self, course_name: str, s3_path: str): - return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('s3_path', s3_path).eq( - 'course_name', course_name).execute() + return self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('s3_path', + s3_path).eq('course_name', course_name).execute() def getProjectsMapForCourse(self, course_name: str): return self.supabase_client.table("projects").select("doc_map_id").eq("course_name", course_name).execute() @@ -46,79 +47,91 @@ def getDocumentsBetweenDates(self, course_name: str, from_date: str, to_date: st elif from_date != '' and to_date == '': # query from from_date to now print("only from_date") - response = self.supabase_client.table(table_name).select("id", count='exact').eq("course_name", course_name).gte( - 'created_at', from_date).order('id', desc=False).execute() + response = self.supabase_client.table(table_name).select("id", + count='exact').eq("course_name", + course_name).gte('created_at', + from_date).order('id', + desc=False).execute() elif from_date == '' and to_date != '': # query from beginning to to_date print("only to_date") - response = self.supabase_client.table(table_name).select("id", 
count='exact').eq("course_name", course_name).lte( - 'created_at', to_date).order('id', desc=False).execute() + response = self.supabase_client.table(table_name).select("id", + count='exact').eq("course_name", + course_name).lte('created_at', + to_date).order('id', + desc=False).execute() else: # query all data print("No dates") - response = self.supabase_client.table(table_name).select("id", count='exact').eq( - "course_name", course_name).order('id', desc=False).execute() + response = self.supabase_client.table(table_name).select("id", count='exact').eq("course_name", + course_name).order('id', desc=False).execute() return response def getAllFromTableForDownloadType(self, course_name: str, download_type: str, first_id: int): if download_type == 'documents': - response = self.supabase_client.table("documents").select("*").eq("course_name", course_name).gte( - 'id', first_id).order('id', desc=False).limit(100).execute() + response = self.supabase_client.table("documents").select("*").eq("course_name", + course_name).gte('id', + first_id).order('id', + desc=False).limit(100).execute() else: - response = self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).gte( - 'id', first_id).order('id', desc=False).limit(100).execute() + response = self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).gte('id', first_id).order( + 'id', desc=False).limit(100).execute() return response def getAllConversationsBetweenIds(self, course_name: str, first_id: int, last_id: int, limit: int = 50): if last_id == 0: - return self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).gt( - 'id', first_id).order('id', desc=False).limit(limit).execute() + return self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).gt('id', first_id).order( + 'id', desc=False).limit(limit).execute() else: - return self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).gte( - 'id', first_id).lte('id', last_id).order('id', desc=False).limit(limit).execute() - + return self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).gte('id', first_id).lte( + 'id', last_id).order('id', desc=False).limit(limit).execute() def getDocsForIdsGte(self, course_name: str, first_id: int, fields: str = "*", limit: int = 100): - return self.supabase_client.table("documents").select(fields).eq("course_name", course_name).gte( - 'id', first_id).order('id', desc=False).limit(limit).execute() + return self.supabase_client.table("documents").select(fields).eq("course_name", + course_name).gte('id', + first_id).order('id', + desc=False).limit(limit).execute() def insertProjectInfo(self, project_info): return self.supabase_client.table("projects").insert(project_info).execute() def getAllFromLLMConvoMonitor(self, course_name: str): return self.supabase_client.table("llm-convo-monitor").select("*").eq("course_name", course_name).order('id', desc=False).execute() - + def getCountFromLLMConvoMonitor(self, course_name: str, last_id: int): if last_id == 0: - return self.supabase_client.table("llm-convo-monitor").select("id", count='exact').eq("course_name", course_name).order('id', desc=False).execute() + return self.supabase_client.table("llm-convo-monitor").select("id", count='exact').eq("course_name", + course_name).order('id', desc=False).execute() else: - return self.supabase_client.table("llm-convo-monitor").select("id", count='exact').eq("course_name", 
course_name).gt("id", last_id).order('id', desc=False).execute() - + return self.supabase_client.table("llm-convo-monitor").select("id", + count='exact').eq("course_name", + course_name).gt("id", + last_id).order('id', + desc=False).execute() + def getDocMapFromProjects(self, course_name: str): return self.supabase_client.table("projects").select("doc_map_id").eq("course_name", course_name).execute() - + def getConvoMapFromProjects(self, course_name: str): return self.supabase_client.table("projects").select("*").eq("course_name", course_name).execute() - + def updateProjects(self, course_name: str, data: dict): return self.supabase_client.table("projects").update(data).eq("course_name", course_name).execute() - + def getLatestWorkflowId(self): return self.supabase_client.table('n8n_workflows').select("latest_workflow_id").execute() - + def lockWorkflow(self, id: str): return self.supabase_client.table('n8n_workflows').insert({"latest_workflow_id": id, "is_locked": True}).execute() - + def deleteLatestWorkflowId(self, id: str): return self.supabase_client.table('n8n_workflows').delete().eq('latest_workflow_id', id).execute() - + def unlockWorkflow(self, id: str): return self.supabase_client.table('n8n_workflows').update({"is_locked": False}).eq('latest_workflow_id', id).execute() def getConversation(self, course_name: str, key: str, value: str): return self.supabase_client.table("llm-convo-monitor").select("*").eq(key, value).eq("course_name", course_name).execute() - - diff --git a/ai_ta_backend/executors/process_pool_executor.py b/ai_ta_backend/executors/process_pool_executor.py index 81b4860c..33dc21aa 100644 --- a/ai_ta_backend/executors/process_pool_executor.py +++ b/ai_ta_backend/executors/process_pool_executor.py @@ -24,8 +24,7 @@ def __init__(self, max_workers=None): def submit(self, fn, *args, **kwargs): raise NotImplementedError( - "ProcessPoolExecutorAdapter does not support 'submit' directly due to its nature. Use 'map' or other methods as needed." - ) + "ProcessPoolExecutorAdapter does not support 'submit' directly due to its nature. 
Use 'map' or other methods as needed.") def map(self, fn, *iterables, timeout=None, chunksize=1): return self.executor.map(fn, *iterables, timeout=timeout, chunksize=chunksize) diff --git a/ai_ta_backend/extensions.py b/ai_ta_backend/extensions.py index 589c64fc..f0b13d6f 100644 --- a/ai_ta_backend/extensions.py +++ b/ai_ta_backend/extensions.py @@ -1,2 +1,3 @@ from flask_sqlalchemy import SQLAlchemy -db = SQLAlchemy() \ No newline at end of file + +db = SQLAlchemy() diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 4a302515..44ddc874 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -1,4 +1,3 @@ -import json import logging import os import sys @@ -6,35 +5,33 @@ from typing import List from dotenv import load_dotenv -from flask import ( - Flask, - Response, - abort, - jsonify, - make_response, - request, - send_from_directory, -) +from flask import abort +from flask import Flask +from flask import jsonify +from flask import make_response +from flask import request +from flask import Response +from flask import send_from_directory from flask_cors import CORS from flask_executor import Executor -from flask_injector import FlaskInjector, RequestScope -from injector import Binder, SingletonScope +from flask_injector import FlaskInjector +from flask_injector import RequestScope +from injector import Binder +from injector import SingletonScope from ai_ta_backend.database.aws import AWSStorage from ai_ta_backend.database.qdrant import VectorDatabase from ai_ta_backend.database.sql import SQLAlchemyDatabase -from ai_ta_backend.executors.flask_executor import ( - ExecutorInterface, - FlaskExecutorAdapter, -) -from ai_ta_backend.executors.process_pool_executor import ( - ProcessPoolExecutorAdapter, - ProcessPoolExecutorInterface, -) -from ai_ta_backend.executors.thread_pool_executor import ( - ThreadPoolExecutorAdapter, - ThreadPoolExecutorInterface, -) +from ai_ta_backend.executors.flask_executor import ExecutorInterface +from ai_ta_backend.executors.flask_executor import FlaskExecutorAdapter +from ai_ta_backend.executors.process_pool_executor import \ + ProcessPoolExecutorAdapter +from ai_ta_backend.executors.process_pool_executor import \ + ProcessPoolExecutorInterface +from ai_ta_backend.executors.thread_pool_executor import \ + ThreadPoolExecutorAdapter +from ai_ta_backend.executors.thread_pool_executor import \ + ThreadPoolExecutorInterface from ai_ta_backend.extensions import db from ai_ta_backend.service.export_service import ExportService from ai_ta_backend.service.nomic_service import NomicService @@ -68,8 +65,7 @@ def index() -> Response: Returns: JSON: _description_ """ - response = jsonify( - {"hi there, this is a 404": "Welcome to UIUC.chat backend 🚅 Read the docs here: https://docs.uiuc.chat/ "}) + response = jsonify({"hi there, this is a 404": "Welcome to UIUC.chat backend 🚅 Read the docs here: https://docs.uiuc.chat/ "}) response.headers.add('Access-Control-Allow-Origin', '*') return response @@ -141,9 +137,7 @@ def getAll(service: RetrievalService) -> Response: if course_name == '': # proper web error "400 Bad request" - abort( - 400, - description=f"Missing the one required parameter: 'course_name' must be provided. Course name: `{course_name}`") + abort(400, description=f"Missing the one required parameter: 'course_name' must be provided. Course name: `{course_name}`") distinct_dicts = service.getAll(course_name) @@ -261,7 +255,7 @@ def logToNomic(service: NomicService, flaskExecutor: ExecutorInterface): # background execution of tasks!! 
#response = flaskExecutor.submit(service.log_convo_to_nomic, course_name, data) - result = flaskExecutor.submit(service.log_to_conversation_map, course_name, conversation).result() + flaskExecutor.submit(service.log_to_conversation_map, course_name, conversation).result() response = jsonify({'outcome': 'success'}) response.headers.add('Access-Control-Allow-Origin', '*') return response @@ -289,8 +283,7 @@ def export_convo_history(service: ExportService): response.headers.add('Access-Control-Allow-Origin', '*') else: - response = make_response( - send_from_directory(export_status['response'][2], export_status['response'][1], as_attachment=True)) + response = make_response(send_from_directory(export_status['response'][2], export_status['response'][1], as_attachment=True)) response.headers.add('Access-Control-Allow-Origin', '*') response.headers["Content-Disposition"] = f"attachment; filename={export_status['response'][1]}" os.remove(export_status['response'][0]) @@ -307,7 +300,7 @@ def export_conversations_custom(service: ExportService): if course_name == '' and emails == []: # proper web error "400 Bad request" - abort(400, description=f"Missing required parameter: 'course_name' and 'destination_email_ids' must be provided.") + abort(400, description="Missing required parameter: 'course_name' and 'destination_email_ids' must be provided.") export_status = service.export_conversations(course_name, from_date, to_date, emails) print("EXPORT FILE LINKS: ", export_status) @@ -321,8 +314,7 @@ def export_conversations_custom(service: ExportService): response.headers.add('Access-Control-Allow-Origin', '*') else: - response = make_response( - send_from_directory(export_status['response'][2], export_status['response'][1], as_attachment=True)) + response = make_response(send_from_directory(export_status['response'][2], export_status['response'][1], as_attachment=True)) response.headers.add('Access-Control-Allow-Origin', '*') response.headers["Content-Disposition"] = f"attachment; filename={export_status['response'][1]}" os.remove(export_status['response'][0]) @@ -352,8 +344,7 @@ def exportDocuments(service: ExportService): response.headers.add('Access-Control-Allow-Origin', '*') else: - response = make_response( - send_from_directory(export_status['response'][2], export_status['response'][1], as_attachment=True)) + response = make_response(send_from_directory(export_status['response'][2], export_status['response'][1], as_attachment=True)) response.headers.add('Access-Control-Allow-Origin', '*') response.headers["Content-Disposition"] = f"attachment; filename={export_status['response'][1]}" os.remove(export_status['response'][0]) diff --git a/ai_ta_backend/modal/pest_detection.py b/ai_ta_backend/modal/pest_detection.py index 1500a891..8f514efa 100644 --- a/ai_ta_backend/modal/pest_detection.py +++ b/ai_ta_backend/modal/pest_detection.py @@ -16,14 +16,18 @@ import inspect import json import os -import traceback -import uuid from tempfile import NamedTemporaryFile +import traceback from typing import List +import uuid -import modal from fastapi import Request -from modal import Secret, Stub, build, enter, web_endpoint +import modal +from modal import build +from modal import enter +from modal import Secret +from modal import Stub +from modal import web_endpoint # Simpler image, but slower cold starts: modal.Image.from_registry('ultralytics/ultralytics:latest-cpu') image = ( @@ -42,16 +46,10 @@ # Imports needed inside the image with image.imports(): - import inspect - import os - import traceback - import 
uuid - from tempfile import NamedTemporaryFile - from typing import List import boto3 - import requests from PIL import Image + import requests from ultralytics import YOLO diff --git a/ai_ta_backend/model/models.py b/ai_ta_backend/model/models.py index 6b35ab44..cac15cb7 100644 --- a/ai_ta_backend/model/models.py +++ b/ai_ta_backend/model/models.py @@ -1,89 +1,95 @@ -from sqlalchemy import Column, BigInteger, Text, DateTime, Boolean, ForeignKey, Index, JSON +from sqlalchemy import BigInteger +from sqlalchemy import Boolean +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import ForeignKey +from sqlalchemy import Index +from sqlalchemy import JSON +from sqlalchemy import Text from sqlalchemy.sql import func + from ai_ta_backend.extensions import db + class Base(db.Model): - __abstract__ = True + __abstract__ = True + class Document(Base): - __tablename__ = 'documents' - id = Column(BigInteger, primary_key=True, autoincrement=True) - created_at = Column(DateTime, default=func.now()) - s3_path = Column(Text) - readable_filename = Column(Text) - course_name = Column(Text) - url = Column(Text) - contexts = Column(JSON, default=lambda: [ - { - "text": "", - "timestamp": "", - "embedding": "", - "pagenumber": "" - } - ]) - base_url = Column(Text) - - __table_args__ = ( - Index('documents_course_name_idx', 'course_name', postgresql_using='hash'), - Index('documents_created_at_idx', 'created_at', postgresql_using='btree'), - Index('idx_doc_s3_path', 's3_path', postgresql_using='btree'), - ) + __tablename__ = 'documents' + id = Column(BigInteger, primary_key=True, autoincrement=True) + created_at = Column(DateTime, default=func.now()) + s3_path = Column(Text) + readable_filename = Column(Text) + course_name = Column(Text) + url = Column(Text) + contexts = Column(JSON, default=lambda: [{"text": "", "timestamp": "", "embedding": "", "pagenumber": ""}]) + base_url = Column(Text) + + __table_args__ = ( + Index('documents_course_name_idx', 'course_name', postgresql_using='hash'), + Index('documents_created_at_idx', 'created_at', postgresql_using='btree'), + Index('idx_doc_s3_path', 's3_path', postgresql_using='btree'), + ) + class DocumentDocGroup(Base): - __tablename__ = 'documents_doc_groups' - document_id = Column(BigInteger, primary_key=True) - doc_group_id = Column(BigInteger, ForeignKey('doc_groups.id', ondelete='CASCADE'), primary_key=True) - created_at = Column(DateTime, default=func.now()) + __tablename__ = 'documents_doc_groups' + document_id = Column(BigInteger, primary_key=True) + doc_group_id = Column(BigInteger, ForeignKey('doc_groups.id', ondelete='CASCADE'), primary_key=True) + created_at = Column(DateTime, default=func.now()) + + __table_args__ = ( + Index('documents_doc_groups_doc_group_id_idx', 'doc_group_id', postgresql_using='btree'), + Index('documents_doc_groups_document_id_idx', 'document_id', postgresql_using='btree'), + ) - __table_args__ = ( - Index('documents_doc_groups_doc_group_id_idx', 'doc_group_id', postgresql_using='btree'), - Index('documents_doc_groups_document_id_idx', 'document_id', postgresql_using='btree'), - ) class DocGroup(Base): - __tablename__ = 'doc_groups' - id = Column(BigInteger, primary_key=True, autoincrement=True) - name = Column(Text, nullable=False) - course_name = Column(Text, nullable=False) - created_at = Column(DateTime, default=func.now()) - enabled = Column(Boolean, default=True) - private = Column(Boolean, default=True) - doc_count = Column(BigInteger) - - __table_args__ = ( - 
Index('doc_groups_enabled_course_name_idx', 'enabled', 'course_name', postgresql_using='btree'), - ) + __tablename__ = 'doc_groups' + id = Column(BigInteger, primary_key=True, autoincrement=True) + name = Column(Text, nullable=False) + course_name = Column(Text, nullable=False) + created_at = Column(DateTime, default=func.now()) + enabled = Column(Boolean, default=True) + private = Column(Boolean, default=True) + doc_count = Column(BigInteger) + + __table_args__ = (Index('doc_groups_enabled_course_name_idx', 'enabled', 'course_name', postgresql_using='btree'),) + class Project(Base): - __tablename__ = 'projects' - id = Column(BigInteger, primary_key=True, autoincrement=True) - created_at = Column(DateTime, default=func.now()) - course_name = Column(Text) - doc_map_id = Column(Text) - convo_map_id = Column(Text) - n8n_api_key = Column(Text) - last_uploaded_doc_id = Column(BigInteger) - last_uploaded_convo_id = Column(BigInteger) - subscribed = Column(BigInteger, ForeignKey('doc_groups.id', onupdate='CASCADE', ondelete='SET NULL')) + __tablename__ = 'projects' + id = Column(BigInteger, primary_key=True, autoincrement=True) + created_at = Column(DateTime, default=func.now()) + course_name = Column(Text) + doc_map_id = Column(Text) + convo_map_id = Column(Text) + n8n_api_key = Column(Text) + last_uploaded_doc_id = Column(BigInteger) + last_uploaded_convo_id = Column(BigInteger) + subscribed = Column(BigInteger, ForeignKey('doc_groups.id', onupdate='CASCADE', ondelete='SET NULL')) + class N8nWorkflows(Base): - __tablename__ = 'n8n_workflows' - latest_workflow_id = Column(BigInteger, primary_key=True, autoincrement=True) - is_locked = Column(Boolean, nullable=False) + __tablename__ = 'n8n_workflows' + latest_workflow_id = Column(BigInteger, primary_key=True, autoincrement=True) + is_locked = Column(Boolean, nullable=False) + + def __init__(self, is_locked: bool): + self.is_locked = is_locked - def __init__(self, is_locked: bool): - self.is_locked = is_locked class LlmConvoMonitor(Base): - __tablename__ = 'llm_convo_monitor' - id = Column(BigInteger, primary_key=True, autoincrement=True) - created_at = Column(DateTime, default=func.now()) - convo = Column(JSON) - convo_id = Column(Text, unique=True) - course_name = Column(Text) - user_email = Column(Text) - - __table_args__ = ( - Index('llm_convo_monitor_course_name_idx', 'course_name', postgresql_using='hash'), - Index('llm_convo_monitor_convo_id_idx', 'convo_id', postgresql_using='hash'), - ) \ No newline at end of file + __tablename__ = 'llm_convo_monitor' + id = Column(BigInteger, primary_key=True, autoincrement=True) + created_at = Column(DateTime, default=func.now()) + convo = Column(JSON) + convo_id = Column(Text, unique=True) + course_name = Column(Text) + user_email = Column(Text) + + __table_args__ = ( + Index('llm_convo_monitor_course_name_idx', 'course_name', postgresql_using='hash'), + Index('llm_convo_monitor_convo_id_idx', 'convo_id', postgresql_using='hash'), + ) diff --git a/ai_ta_backend/model/response.py b/ai_ta_backend/model/response.py index 2263d8e3..f0ae5f07 100644 --- a/ai_ta_backend/model/response.py +++ b/ai_ta_backend/model/response.py @@ -1,9 +1,12 @@ -from typing import List, TypeVar, Generic +from typing import Generic, List, TypeVar + from flask_sqlalchemy.model import Model T = TypeVar('T', bound=Model) + class DatabaseResponse(Generic[T]): - def __init__(self, data: List[T], count: int): - self.data = data - self.count = count \ No newline at end of file + + def __init__(self, data: List[T], count: int): + 
self.data = data + self.count = count diff --git a/ai_ta_backend/service/export_service.py b/ai_ta_backend/service/export_service.py index 30a959c6..c998d001 100644 --- a/ai_ta_backend/service/export_service.py +++ b/ai_ta_backend/service/export_service.py @@ -1,18 +1,18 @@ +from concurrent.futures import ProcessPoolExecutor import json import os import uuid import zipfile -from concurrent.futures import ProcessPoolExecutor +from injector import inject import pandas as pd import requests -from injector import inject from ai_ta_backend.database.aws import AWSStorage from ai_ta_backend.database.sql import SQLAlchemyDatabase +from ai_ta_backend.extensions import db from ai_ta_backend.service.sentry_service import SentryService from ai_ta_backend.utils.emails import send_email -from ai_ta_backend.extensions import db class ExportService: @@ -35,7 +35,7 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''): """ response = self.sql.getDocumentsBetweenDates(course_name, from_date, to_date) - + # add a condition to route to direct download or s3 download if response.count and response.count > 500: # call background task to upload to s3 @@ -228,7 +228,7 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e return {"response": "Error downloading file!"} else: return {"response": "No data found between the given dates."} - + # Encountered pickling error while running the background task. So, moved the function outside the class. @@ -240,10 +240,10 @@ def export_data_in_bg(response, download_type, course_name, s3_path): 3. send an email to the course admins with the pre-signed URL. Args: - response (dict): The response from the Supabase query. - download_type (str): The type of download - 'documents' or 'conversations'. - course_name (str): The name of the course. - s3_path (str): The S3 path where the file will be uploaded. + response (dict): The response from the Supabase query. + download_type (str): The type of download - 'documents' or 'conversations'. + course_name (str): The name of the course. + s3_path (str): The S3 path where the file will be uploaded. """ s3 = AWSStorage() sql = SQLAlchemyDatabase(db) @@ -301,7 +301,6 @@ def export_data_in_bg(response, download_type, course_name, s3_path): # generate presigned URL s3_url = s3.generatePresignedUrl('get_object', os.environ['S3_BUCKET_NAME'], s3_path, 172800) - # get admin email IDs headers = {"Authorization": f"Bearer {os.environ['VERCEL_READ_ONLY_API_KEY']}", "Content-Type": "application/json"} @@ -344,6 +343,7 @@ def export_data_in_bg(response, download_type, course_name, s3_path): print(e) return "Error: " + str(e) + def export_data_in_bg_emails(response, download_type, course_name, s3_path, emails): """ This function is called in export_documents_csv() to upload the documents to S3. @@ -352,10 +352,10 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai 3. send an email to the course admins with the pre-signed URL. Args: - response (dict): The response from the Supabase query. - download_type (str): The type of download - 'documents' or 'conversations'. - course_name (str): The name of the course. - s3_path (str): The S3 path where the file will be uploaded. + response (dict): The response from the Supabase query. + download_type (str): The type of download - 'documents' or 'conversations'. + course_name (str): The name of the course. + s3_path (str): The S3 path where the file will be uploaded. 
""" s3 = AWSStorage() sql = SQLAlchemyDatabase(db) @@ -415,7 +415,7 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai admin_emails = emails bcc_emails = [] - + print("admin_emails: ", admin_emails) print("bcc_emails: ", bcc_emails) @@ -438,4 +438,4 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai except Exception as e: print(e) - return "Error: " + str(e) \ No newline at end of file + return "Error: " + str(e) diff --git a/ai_ta_backend/service/nomic_service.py b/ai_ta_backend/service/nomic_service.py index 1900279c..c7011c88 100644 --- a/ai_ta_backend/service/nomic_service.py +++ b/ai_ta_backend/service/nomic_service.py @@ -1,17 +1,16 @@ import datetime import os import time -from typing import Union -import backoff +from injector import inject +from langchain.embeddings.openai import OpenAIEmbeddings import nomic +from nomic import atlas +from nomic import AtlasProject import numpy as np import pandas as pd -from injector import inject -from langchain.embeddings.openai import OpenAIEmbeddings -from nomic import AtlasProject, atlas -from ai_ta_backend.database.sql import SQLAlchemyDatabase +from ai_ta_backend.database.sql import SQLAlchemyDatabase from ai_ta_backend.service.sentry_service import SentryService LOCK_EXCEPTIONS = [ @@ -55,15 +54,12 @@ def get_nomic_map(self, course_name: str, type: str): except Exception as e: # Error: ValueError: You must specify a unique_id_field when creating a new project. if str(e) == 'You must specify a unique_id_field when creating a new project.': # type: ignore - print( - "Nomic map does not exist yet, probably because you have less than 20 queries/documents on your project: ", - e) + print("Nomic map does not exist yet, probably because you have less than 20 queries/documents on your project: ", e) else: print("ERROR in get_nomic_map():", e) self.sentry.capture_exception(e) return {"map_id": None, "map_link": None} - def log_to_conversation_map(self, course_name: str, conversation): """ This function logs new conversations to existing nomic maps. @@ -82,12 +78,12 @@ def log_to_conversation_map(self, course_name: str, conversation): if not response.data: print("Map does not exist for this course. Redirecting to map creation...") return self.create_conversation_map(course_name) - + # entry present for doc map, but not convo map - elif not response.data[0].convo_map_id is None: + elif response.data[0].convo_map_id is not None: print("Map does not exist for this course. Redirecting to map creation...") return self.create_conversation_map(course_name) - + project_id = response.data[0].convo_map_id last_uploaded_convo_id: int = int(str(response.data[0].last_uploaded_convo_id)) @@ -154,17 +150,16 @@ def log_to_conversation_map(self, course_name: str, conversation): project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} project_response = self.sql.updateProjects(course_name, project_info) print("Update response from supabase: ", project_response) - + # rebuild the map self.rebuild_map(course_name, "conversation") return "success" - + except Exception as e: print(e) self.sentry.capture_exception(e) return "Error in logging to conversation map: {e}" - - + def log_to_existing_conversation(self, course_name: str, conversation): """ This function logs follow-up questions to existing conversations in the map. 
@@ -176,18 +171,18 @@ def log_to_existing_conversation(self, course_name: str, conversation): # fetch id from supabase incoming_id_response = self.sql.getConversation(course_name, key="convo_id", value=conversation_id) - + project_name = 'Conversation Map for ' + course_name project = AtlasProject(name=project_name, add_datums_if_exists=True) prev_id = str(incoming_id_response.data[0].id) - uploaded_data = project.get_data(ids=[prev_id]) # fetch data point from nomic + uploaded_data = project.get_data(ids=[prev_id]) # fetch data point from nomic prev_convo = uploaded_data[0]['conversation'] # update conversation messages = conversation['messages'] messages_to_be_logged = messages[-2:] - + for message in messages_to_be_logged: if message['role'] == 'user': emoji = "🙋 " @@ -200,7 +195,7 @@ def log_to_existing_conversation(self, course_name: str, conversation): text = message['content'] prev_convo += "\n>>> " + emoji + message['role'] + ": " + text + "\n" - + # create embeddings of first query embeddings_model = OpenAIEmbeddings(openai_api_type="openai", openai_api_base="https://api.openai.com/v1/", @@ -228,7 +223,7 @@ def log_to_existing_conversation(self, course_name: str, conversation): # re-insert updated conversation result = self.append_to_map(embeddings, metadata, project_name) print("Result of appending to existing map:", result) - + return "success" except Exception as e: @@ -236,7 +231,6 @@ def log_to_existing_conversation(self, course_name: str, conversation): self.sentry.capture_exception(e) return "Error in logging to existing conversation: {e}" - def create_conversation_map(self, course_name: str): """ This function creates a conversation map for a given course from scratch. @@ -295,8 +289,7 @@ def create_conversation_map(self, course_name: str): index_name = course_name + "_convo_index" topic_label_field = "first_query" colorable_fields = ["user_email", "first_query", "conversation_id", "created_at"] - result = self.create_map(embeddings, metadata, project_name, index_name, topic_label_field, - colorable_fields) + result = self.create_map(embeddings, metadata, project_name, index_name, topic_label_field, colorable_fields) if result == "success": # update flag @@ -333,7 +326,7 @@ def create_conversation_map(self, course_name: str): # set first_id for next iteration try: print("response: ", response.data[-1].id) - except: + except Exception as e: print("response: ", response.data) first_id = int(str(response.data[-1].id)) + 1 @@ -370,7 +363,6 @@ def create_conversation_map(self, course_name: str): project_response = self.sql.insertProjectInfo(project_info) print("Response from supabase: ", project_response) - # rebuild the map self.rebuild_map(course_name, "conversation") return "success" @@ -471,7 +463,6 @@ def data_prep_for_convo_map(self, df: pd.DataFrame): for _index, row in df.iterrows(): current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") created_at = datetime.datetime.strptime(row['created_at'], "%Y-%m-%dT%H:%M:%S.%f%z").strftime("%Y-%m-%d %H:%M:%S") - conversation_exists = False conversation = "" emoji = "" diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index 3ff0f5a5..f660b037 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -4,11 +4,11 @@ import traceback from typing import Dict, List, Optional, Union -import openai from injector import inject from langchain.chat_models import AzureChatOpenAI from langchain.embeddings.openai import 
OpenAIEmbeddings from langchain.schema import Document +import openai from ai_ta_backend.database.aws import AWSStorage from ai_ta_backend.database.qdrant import VectorDatabase @@ -73,9 +73,7 @@ def getTopContexts(self, try: start_time_overall = time.monotonic() - found_docs: list[Document] = self.vector_search(search_query=search_query, - course_name=course_name, - doc_groups=doc_groups) + found_docs: list[Document] = self.vector_search(search_query=search_query, course_name=course_name, doc_groups=doc_groups) pre_prompt = "Please answer the following question. Use the context below, called your documents, only if it's helpful and don't use parts that are very irrelevant. It's good to quote from your documents directly, when you do always use Markdown footnotes for citations. Use react-markdown superscript to number the sources at the end of sentences (1, 2, 3...) and use react-markdown Footnotes to list the full document names for each number. Use ReactMarkdown aka 'react-markdown' formatting for super script citations, use semi-formal style. Feel free to say you don't know. \nHere's a few passages of the high quality documents:\n" # count tokens at start and end, then also count each context. @@ -208,10 +206,7 @@ def delete_from_qdrant(self, identifier_key: str, identifier_value: str): if self.sentry is not None: self.sentry.capture_exception(e) - def getTopContextsWithMQR(self, - search_query: str, - course_name: str, - token_limit: int = 4_000) -> Union[List[Dict], str]: + def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]: """ New info-retrieval pipeline that uses multi-query retrieval + filtering + reciprocal rank fusion + context padding. 1. Generate multiple queries based on the input search query. 
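For readers unfamiliar with the fusion step named in the getTopContextsWithMQR docstring above, here is a minimal sketch of reciprocal rank fusion; the helper name, the k = 60 constant, and the toy document ids are illustrative assumptions, not part of this codebase.

from collections import defaultdict
from typing import Dict, List


def reciprocal_rank_fusion(ranked_lists: List[List[str]], k: int = 60) -> List[str]:
  """Fuse several ranked lists of document ids into a single ranking.

  Each document's score is the sum of 1 / (k + rank) over every list it
  appears in, so items ranked highly by multiple generated queries rise to the top.
  """
  scores: Dict[str, float] = defaultdict(float)
  for ranking in ranked_lists:
    for rank, doc_id in enumerate(ranking, start=1):
      scores[doc_id] += 1.0 / (k + rank)
  return sorted(scores, key=lambda doc_id: scores[doc_id], reverse=True)


if __name__ == '__main__':
  # Two hypothetical result lists, one per generated query, best hit first.
  print(reciprocal_rank_fusion([['doc_a', 'doc_b', 'doc_c'], ['doc_b', 'doc_d']]))
  # -> ['doc_b', 'doc_a', 'doc_d', 'doc_c']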
diff --git a/ai_ta_backend/service/sentry_service.py b/ai_ta_backend/service/sentry_service.py index 53b780b0..6c35b066 100644 --- a/ai_ta_backend/service/sentry_service.py +++ b/ai_ta_backend/service/sentry_service.py @@ -1,7 +1,7 @@ import os -import sentry_sdk from injector import inject +import sentry_sdk class SentryService: diff --git a/ai_ta_backend/service/workflow_service.py b/ai_ta_backend/service/workflow_service.py index 0627b53a..42671707 100644 --- a/ai_ta_backend/service/workflow_service.py +++ b/ai_ta_backend/service/workflow_service.py @@ -1,10 +1,11 @@ -import requests -import time +import json import os -import supabase +import time from urllib.parse import quote -import json + from injector import inject +import requests + from ai_ta_backend.database.supabase import SQLDatabase @@ -78,12 +79,7 @@ def get_executions(self, limit, id=None, pagination: bool = True, api_key: str = else: return all_executions - def get_workflows(self, - limit, - pagination: bool = True, - api_key: str = "", - active: bool = False, - workflow_name: str = ''): + def get_workflows(self, limit, pagination: bool = True, api_key: str = "", active: bool = False, workflow_name: str = ''): if not api_key: raise ValueError('api_key is required') headers = {"X-N8N-API-KEY": api_key, "Accept": "application/json"} diff --git a/ai_ta_backend/utils/context_parent_doc_padding.py b/ai_ta_backend/utils/context_parent_doc_padding.py index fc0ba19c..25553248 100644 --- a/ai_ta_backend/utils/context_parent_doc_padding.py +++ b/ai_ta_backend/utils/context_parent_doc_padding.py @@ -1,8 +1,8 @@ -import os -import time from concurrent.futures import ProcessPoolExecutor from functools import partial from multiprocessing import Manager +import os +import time DOCUMENTS_TABLE = os.environ['SUPABASE_DOCUMENTS_TABLE'] # SUPABASE_CLIENT = supabase.create_client(supabase_url=os.environ['SUPABASE_URL'], @@ -68,14 +68,11 @@ def supabase_context_padding(doc, course_name, result_docs): # query by url or s3_path if 'url' in doc.metadata.keys() and doc.metadata['url']: parent_doc_id = doc.metadata['url'] - response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name', - course_name).eq('url', parent_doc_id).execute() + response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name', course_name).eq('url', parent_doc_id).execute() else: parent_doc_id = doc.metadata['s3_path'] - response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name', - course_name).eq('s3_path', - parent_doc_id).execute() + response = SUPABASE_CLIENT.table(DOCUMENTS_TABLE).select('*').eq('course_name', course_name).eq('s3_path', parent_doc_id).execute() data = response.data diff --git a/ai_ta_backend/utils/emails.py b/ai_ta_backend/utils/emails.py index 4312a35d..1d001bfa 100644 --- a/ai_ta_backend/utils/emails.py +++ b/ai_ta_backend/utils/emails.py @@ -1,7 +1,7 @@ -import os -import smtplib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +import os +import smtplib def send_email(subject: str, body_text: str, sender: str, receipients: list, bcc_receipients: list): diff --git a/ai_ta_backend/utils/utils_tokenization.py b/ai_ta_backend/utils/utils_tokenization.py index 956cc196..db766e50 100644 --- a/ai_ta_backend/utils/utils_tokenization.py +++ b/ai_ta_backend/utils/utils_tokenization.py @@ -4,10 +4,9 @@ import tiktoken -def count_tokens_and_cost( - prompt: str, - completion: str = '', - openai_model_name: str = "gpt-3.5-turbo"): # -> tuple[int, float] | tuple[int, float, 
int, float]: +def count_tokens_and_cost(prompt: str, + completion: str = '', + openai_model_name: str = "gpt-3.5-turbo"): # -> tuple[int, float] | tuple[int, float, int, float]: """ # TODO: improve w/ extra tokens used by model: https://github.com/openai/openai-cookbook/blob/d00e9a48a63739f5b038797594c81c8bb494fc09/examples/How_to_count_tokens_with_tiktoken.ipynb Returns the number of tokens in a text string. @@ -126,9 +125,7 @@ def analyze_conversations(supabase_client: Any = None): # If the message is from the assistant, it's a completion elif role == 'assistant': - num_tokens_completion, cost_completion = count_tokens_and_cost(prompt='', - completion=content, - openai_model_name=model_name) + num_tokens_completion, cost_completion = count_tokens_and_cost(prompt='', completion=content, openai_model_name=model_name) total_completion_cost += cost_completion print(f'Assistant Completion: {content}\nTokens: {num_tokens_completion}, cost: {cost_completion}') return total_convos, total_messages, total_prompt_cost, total_completion_cost diff --git a/docker-compose.yaml b/docker-compose.yaml index 9aa5d1e2..9a5686f0 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,4 +1,4 @@ -version: '3.8' +version: "3.8" services: # sqlite: @@ -26,8 +26,8 @@ services: minio: image: minio/minio environment: - MINIO_ROOT_USER: 'minioadmin' # Customize access key - MINIO_ROOT_PASSWORD: 'minioadmin' # Customize secret key + MINIO_ROOT_USER: "minioadmin" # Customize access key + MINIO_ROOT_PASSWORD: "minioadmin" # Customize secret key command: server /data ports: - "9000:9000" # Console access diff --git a/railway.json b/railway.json index 4147197e..9810ca27 100644 --- a/railway.json +++ b/railway.json @@ -13,10 +13,7 @@ ] }, "setup": { - "nixPkgs": [ - "python310", - "gcc" - ] + "nixPkgs": ["python310", "gcc"] } } } @@ -26,4 +23,4 @@ "restartPolicyType": "ON_FAILURE", "restartPolicyMaxRetries": 1 } -} \ No newline at end of file +} From 89d4abc9565cd6005bb803470ea559b2d2d8abea Mon Sep 17 00:00:00 2001 From: Kastan Day Date: Thu, 20 Jun 2024 13:23:11 -0700 Subject: [PATCH 09/11] Fix print -> logging.info. Fix Qdrant connection refused w/ DockerCompose network --- .gitignore | 1 + README.md | 8 +- ai_ta_backend/beam/OpenaiEmbeddings.py | 20 +- ai_ta_backend/beam/ingest.py | 130 +++++------ ai_ta_backend/beam/nomic_logging.py | 60 +++--- ai_ta_backend/database/aws.py | 5 +- ai_ta_backend/database/qdrant.py | 7 +- ai_ta_backend/database/supabase.py | 9 +- ai_ta_backend/main.py | 45 ++-- ai_ta_backend/modal/pest_detection.py | 11 +- ai_ta_backend/service/export_service.py | 67 +++--- ai_ta_backend/service/nomic_service.py | 109 +++++----- ai_ta_backend/service/retrieval_service.py | 65 +++--- ai_ta_backend/service/workflow_service.py | 23 +- .../utils/context_parent_doc_padding.py | 13 +- ai_ta_backend/utils/filtering_contexts.py | 28 +-- ai_ta_backend/utils/utils_tokenization.py | 19 +- docker-compose.yaml | 75 +++++-- qdrant_config.yaml | 204 ++++++++++++++++++ 19 files changed, 577 insertions(+), 322 deletions(-) create mode 100644 qdrant_config.yaml diff --git a/.gitignore b/.gitignore index b0391b88..ee82f76d 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ coursera-dl/ wandb *.ipynb *.pem +qdrant_data/* # don't expose env files .env diff --git a/README.md b/README.md index a049e2b1..c2e2b895 100644 --- a/README.md +++ b/README.md @@ -16,9 +16,11 @@ Automatic [API Reference](https://uiuc-chatbot.github.io/ai-ta-backend/reference ## Docker Deployment -1. 
Build flask image `docker build -t kastanday/ai-ta-backend:gunicorn .` -2. Push flask image `docker push kastanday/ai-ta-backend:gunicorn` -3. Run docker compose `docker compose up` +1. Just run Docker Compose `docker compose up --build` + +Works on version: `Docker Compose version v2.27.1-desktop.1` + +Works on Apple Silicon M1 `aarch64`, and `x86`. ## 📣 Development diff --git a/ai_ta_backend/beam/OpenaiEmbeddings.py b/ai_ta_backend/beam/OpenaiEmbeddings.py index b3a088c2..eb7532db 100644 --- a/ai_ta_backend/beam/OpenaiEmbeddings.py +++ b/ai_ta_backend/beam/OpenaiEmbeddings.py @@ -224,9 +224,9 @@ # task_list.append(task) # next_request = None # reset next_request to empty -# # print("status_tracker.num_tasks_in_progress", status_tracker.num_tasks_in_progress) +# # logging.info("status_tracker.num_tasks_in_progress", status_tracker.num_tasks_in_progress) # # one_task_result = task.result() -# # print("one_task_result", one_task_result) +# # logging.info("one_task_result", one_task_result) # # if all tasks are finished, break # if status_tracker.num_tasks_in_progress == 0: @@ -485,8 +485,8 @@ # # total_prompt_tokens = 0 # # total_completion_tokens = 0 -# # print("Results, end of main: ", oai.results) -# # print("-"*50) +# # logging.info("Results, end of main: ", oai.results) +# # logging.info("-"*50) # # # jsonObject = json.loads(oai.results) # # for element in oai.results: @@ -498,15 +498,15 @@ # # total_prompt_tokens += item['usage']['prompt_tokens'] # # total_completion_tokens += item['usage']['completion_tokens'] -# # print("Assistant Contents:", assistant_contents) -# # print("Total Prompt Tokens:", total_prompt_tokens) -# # print("Total Completion Tokens:", total_completion_tokens) +# # logging.info("Assistant Contents:", assistant_contents) +# # logging.info("Total Prompt Tokens:", total_prompt_tokens) +# # logging.info("Total Completion Tokens:", total_completion_tokens) # # turbo_total_cost = (total_prompt_tokens * 0.0015) + (total_completion_tokens * 0.002) -# # print("Total cost (3.5-turbo):", (total_prompt_tokens * 0.0015), " + Completions: ", (total_completion_tokens * 0.002), " = ", turbo_total_cost) +# # logging.info("Total cost (3.5-turbo):", (total_prompt_tokens * 0.0015), " + Completions: ", (total_completion_tokens * 0.002), " = ", turbo_total_cost) # # gpt4_total_cost = (total_prompt_tokens * 0.03) + (total_completion_tokens * 0.06) -# # print("Hypothetical cost for GPT-4:", (total_prompt_tokens * 0.03), " + Completions: ", (total_completion_tokens * 0.06), " = ", gpt4_total_cost) -# # print("GPT-4 cost premium: ", (gpt4_total_cost / turbo_total_cost), "x") +# # logging.info("Hypothetical cost for GPT-4:", (total_prompt_tokens * 0.03), " + Completions: ", (total_completion_tokens * 0.06), " = ", gpt4_total_cost) +# # logging.info("GPT-4 cost premium: ", (gpt4_total_cost / turbo_total_cost), "x") # ''' # Pricing: # GPT4: diff --git a/ai_ta_backend/beam/ingest.py b/ai_ta_backend/beam/ingest.py index aafe814d..42174f27 100644 --- a/ai_ta_backend/beam/ingest.py +++ b/ai_ta_backend/beam/ingest.py @@ -164,7 +164,7 @@ # readable_filename: List[str] | str = inputs.get('readable_filename', '') # content: str | None = inputs.get('content', None) # is webtext if content exists -# print( +# logging.info( # f"In top of /ingest route. 
course: {course_name}, s3paths: {s3_paths}, readable_filename: {readable_filename}, base_url: {base_url}, url: {url}, content: {content}" # ) @@ -189,11 +189,11 @@ # num_retires = 5 # for retry_num in range(1, num_retires): # if isinstance(success_fail_dict, str): -# print(f"STRING ERROR: {success_fail_dict = }") +# logging.info(f"STRING ERROR: {success_fail_dict = }") # success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) # time.sleep(13 * retry_num) # max is 65 # elif success_fail_dict['failure_ingest']: -# print(f"Ingest failure -- Retry attempt {retry_num}. File: {success_fail_dict}") +# logging.info(f"Ingest failure -- Retry attempt {retry_num}. File: {success_fail_dict}") # # s3_paths = success_fail_dict['failure_ingest'] # retry only failed paths.... what if this is a URL instead? # success_fail_dict = run_ingest(course_name, s3_paths, base_url, url, readable_filename, content) # time.sleep(13 * retry_num) # max is 65 @@ -202,7 +202,7 @@ # # Final failure / success check # if success_fail_dict['failure_ingest']: -# print(f"INGEST FAILURE -- About to send to supabase. success_fail_dict: {success_fail_dict}") +# logging.info(f"INGEST FAILURE -- About to send to supabase. success_fail_dict: {success_fail_dict}") # document = { # "course_name": # course_name, @@ -219,13 +219,13 @@ # if isinstance(success_fail_dict['failure_ingest'], dict) else success_fail_dict['failure_ingest'] # } # response = supabase_client.table('documents_failed').insert(document).execute() # type: ignore -# print(f"Supabase ingest failure response: {response}") +# logging.info(f"Supabase ingest failure response: {response}") # else: # # Success case: rebuild nomic document map after all ingests are done # # rebuild_status = rebuild_map(str(course_name), map_type='document') # pass -# print(f"Final success_fail_dict: {success_fail_dict}") +# logging.info(f"Final success_fail_dict: {success_fail_dict}") # return json.dumps(success_fail_dict) # class Ingest(): @@ -281,7 +281,7 @@ # } # # 👆👆👆👆 ADD NEW INGEST METHODhe 👆👆👆👆🎉 -# print(f"Top of ingest, Course_name {course_name}. S3 paths {s3_paths}") +# logging.info(f"Top of ingest, Course_name {course_name}. S3 paths {s3_paths}") # success_status: Dict[str, None | str | Dict[str, str]] = {"success_ingest": None, "failure_ingest": None} # try: # if isinstance(s3_paths, str): @@ -300,7 +300,7 @@ # _ingest_single(ingest_method, s3_path, course_name, **kwargs) # elif mime_category in mimetype_ingest_methods: # # fallback to MimeType -# print("mime category", mime_category) +# logging.info("mime category", mime_category) # ingest_method = mimetype_ingest_methods[mime_category] # _ingest_single(ingest_method, s3_path, course_name, **kwargs) # else: @@ -308,9 +308,9 @@ # try: # self._ingest_single_txt(s3_path, course_name) # success_status['success_ingest'] = s3_path -# print(f"No ingest methods -- Falling back to UTF-8 INGEST... s3_path = {s3_path}") +# logging.info(f"No ingest methods -- Falling back to UTF-8 INGEST... s3_path = {s3_path}") # except Exception as e: -# print( +# logging.info( # f"We don't have a ingest method for this filetype: {file_extension}. As a last-ditch effort, we tried to ingest the file as utf-8 text, but that failed too. File is unsupported: {s3_path}. 
UTF-8 ingest error: {e}" # ) # success_status['failure_ingest'] = { @@ -349,7 +349,7 @@ # }) # sentry_sdk.capture_exception(e) -# print(f"MAJOR ERROR IN /bulk_ingest: {str(e)}") +# logging.info(f"MAJOR ERROR IN /bulk_ingest: {str(e)}") # return success_status # def ingest_single_web_text(self, course_name: str, base_url: str, url: str, content: str, readable_filename: str): @@ -392,7 +392,7 @@ # except Exception as e: # err = f"❌❌ Error in (web text ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( # ) # type: ignore -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # success_or_failure['failure_ingest'] = {'url': url, 'error': str(err)} # return success_or_failure @@ -419,17 +419,17 @@ # 'url': '', # 'base_url': '', # } for doc in documents] -# #print(texts) +# #logging.info(texts) # os.remove(file_path) # success_or_failure = self.split_and_upload(texts=texts, metadatas=metadatas) -# print("Python ingest: ", success_or_failure) +# logging.info("Python ingest: ", success_or_failure) # return success_or_failure # except Exception as e: # err = f"❌❌ Error in (Python ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( # ) -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return err @@ -461,12 +461,12 @@ # except Exception as e: # err = f"❌❌ Error in (VTT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( # ) -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return err # def _ingest_html(self, s3_path: str, course_name: str, **kwargs) -> str: -# print(f"IN _ingest_html s3_path `{s3_path}` kwargs: {kwargs}") +# logging.info(f"IN _ingest_html s3_path `{s3_path}` kwargs: {kwargs}") # try: # response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) # raw_html = response['Body'].read().decode('utf-8') @@ -491,11 +491,11 @@ # }] # success_or_failure = self.split_and_upload(text, metadata) -# print(f"_ingest_html: {success_or_failure}") +# logging.info(f"_ingest_html: {success_or_failure}") # return success_or_failure # except Exception as e: # err: str = f"ERROR IN _ingest_html: {e}\nTraceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return err @@ -503,7 +503,7 @@ # """ # Ingest a single video file from S3. 
# """ -# print("Starting ingest video or audio") +# logging.info("Starting ingest video or audio") # try: # # Ensure the media directory exists # media_dir = "media" @@ -578,7 +578,7 @@ # except Exception as e: # err = f"❌❌ Error in (VIDEO ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( # ) -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return str(err) @@ -607,7 +607,7 @@ # except Exception as e: # err = f"❌❌ Error in (DOCX ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( # ) -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return str(err) @@ -619,10 +619,10 @@ # response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) # raw_text = response['Body'].read().decode('utf-8') -# print("UTF-8 text to ingest as SRT:", raw_text) +# logging.info("UTF-8 text to ingest as SRT:", raw_text) # parsed_info = pysrt.from_string(raw_text) # text = " ".join([t.text for t in parsed_info]) # type: ignore -# print(f"Final SRT ingest: {text}") +# logging.info(f"Final SRT ingest: {text}") # texts = [text] # metadatas: List[Dict[str, Any]] = [{ @@ -643,7 +643,7 @@ # except Exception as e: # err = f"❌❌ Error in (SRT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( # ) -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return str(err) @@ -674,7 +674,7 @@ # except Exception as e: # err = f"❌❌ Error in (Excel/xlsx ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( # ) -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return str(err) @@ -691,7 +691,7 @@ # """ # res_str = pytesseract.image_to_string(Image.open(tmpfile.name)) -# print("IMAGE PARSING RESULT:", res_str) +# logging.info("IMAGE PARSING RESULT:", res_str) # documents = [Document(page_content=res_str)] # texts = [doc.page_content for doc in documents] @@ -711,7 +711,7 @@ # except Exception as e: # err = f"❌❌ Error in (png/jpg ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( # ) -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return str(err) @@ -741,7 +741,7 @@ # except Exception as e: # err = f"❌❌ Error in (CSV ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( # ) -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return str(err) @@ -751,7 +751,7 @@ # LangChain `Documents` have .metadata and .page_content attributes. # Be sure to use TemporaryFile() to avoid memory leaks! # """ -# print("IN PDF ingest: s3_path: ", s3_path, "and kwargs:", kwargs) +# logging.info("IN PDF ingest: s3_path: ", s3_path, "and kwargs:", kwargs) # try: # with NamedTemporaryFile() as pdf_tmpfile: @@ -761,7 +761,7 @@ # try: # doc = fitz.open(pdf_tmpfile.name) # type: ignore # except fitz.fitz.EmptyFileError as e: -# print(f"Empty PDF file: {s3_path}") +# logging.info(f"Empty PDF file: {s3_path}") # return "Failed ingest: Could not detect ANY text in the PDF. OCR did not help. PDF appears empty of text." 
# # improve quality of the image @@ -781,7 +781,7 @@ # s3_upload_path = str(Path(s3_path)).rsplit('.pdf')[0] + "-pg1-thumb.png" # first_page_png.seek(0) # Seek the file pointer back to the beginning # with open(first_page_png.name, 'rb') as f: -# print("Uploading image png to S3") +# logging.info("Uploading image png to S3") # self.s3_client.upload_fileobj(f, os.getenv('S3_BUCKET_NAME'), s3_upload_path) # # Extract text @@ -806,14 +806,14 @@ # if has_words: # success_or_failure = self.split_and_upload(texts=pdf_texts, metadatas=metadatas) # else: -# print("⚠️ PDF IS EMPTY -- OCR-ing the PDF.") +# logging.info("⚠️ PDF IS EMPTY -- OCR-ing the PDF.") # success_or_failure = self._ocr_pdf(s3_path=s3_path, course_name=course_name, **kwargs) # return success_or_failure # except Exception as e: # err = f"❌❌ Error in PDF ingest (no OCR): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( # ) # type: ignore -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return err # return "Success" @@ -837,7 +837,7 @@ # for i, page in enumerate(pdf.pages): # im = page.to_image() # text = pytesseract.image_to_string(im.original) -# print("Page number: ", i, "Text: ", text[:100]) +# logging.info("Page number: ", i, "Text: ", text[:100]) # pdf_pages_OCRed.append(dict(text=text, page_number=i, readable_filename=Path(s3_path).name[37:])) # metadatas: List[Dict[str, Any]] = [ @@ -867,7 +867,7 @@ # return success_or_failure # except Exception as e: # err = f"❌❌ Error in PDF ingest (with OCR): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc() -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return err @@ -879,12 +879,12 @@ # Returns: # str: "Success" or an error message # """ -# print("In text ingest, UTF-8") +# logging.info("In text ingest, UTF-8") # try: # # NOTE: slightly different method for .txt files, no need for download. 
It's part of the 'body' # response = self.s3_client.get_object(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path) # text = response['Body'].read().decode('utf-8') -# print("UTF-8 text to ignest (from s3)", text) +# logging.info("UTF-8 text to ignest (from s3)", text) # text = [text] # metadatas: List[Dict[str, Any]] = [{ @@ -897,14 +897,14 @@ # 'url': '', # 'base_url': '', # }] -# print("Prior to ingest", metadatas) +# logging.info("Prior to ingest", metadatas) # success_or_failure = self.split_and_upload(texts=text, metadatas=metadatas) # return success_or_failure # except Exception as e: # err = f"❌❌ Error in (TXT ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( # ) -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return str(err) @@ -915,7 +915,7 @@ # try: # with NamedTemporaryFile() as tmpfile: # # download from S3 into pdf_tmpfile -# #print("in ingest PPTX") +# #logging.info("in ingest PPTX") # self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile) # loader = UnstructuredPowerPointLoader(tmpfile.name) @@ -938,7 +938,7 @@ # except Exception as e: # err = f"❌❌ Error in (PPTX ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n", traceback.format_exc( # ) -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return str(err) @@ -979,7 +979,7 @@ # return "Success" # except Exception as e: # err = f"❌❌ Error in (GITHUB ingest): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n{traceback.format_exc()}" -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return err @@ -1004,8 +1004,8 @@ # 'base_url': metadatas[0].get('base_url', None), # }) -# print(f"In split and upload. Metadatas: {metadatas}") -# print(f"Texts: {texts}") +# logging.info(f"In split and upload. Metadatas: {metadatas}") +# logging.info(f"Texts: {texts}") # assert len(texts) == len( # metadatas # ), f'must have equal number of text strings and metadata dicts. len(texts) is {len(texts)}. len(metadatas) is {len(metadatas)}' @@ -1104,11 +1104,11 @@ # 'url': metadatas[0].get('url', None), # 'base_url': metadatas[0].get('base_url', None), # }) -# print("successful END OF split_and_upload") +# logging.info("successful END OF split_and_upload") # return "Success" # except Exception as e: # err: str = f"ERROR IN split_and_upload(): Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return err @@ -1160,26 +1160,26 @@ # current_whole_text += text['input'] # if supabase_whole_text == current_whole_text: # matches the previous file -# print(f"Duplicate ingested! 📄 s3_path: {filename}.") +# logging.info(f"Duplicate ingested! 📄 s3_path: {filename}.") # return True # else: # the file is updated -# print(f"Updated file detected! Same filename, new contents. 📄 s3_path: {filename}") +# logging.info(f"Updated file detected! Same filename, new contents. 
📄 s3_path: {filename}") # # call the delete function on older docs # for content in supabase_contents: -# print("older s3_path to be deleted: ", content['s3_path']) +# logging.info("older s3_path to be deleted: ", content['s3_path']) # delete_status = self.delete_data(course_name, content['s3_path'], '') -# print("delete_status: ", delete_status) +# logging.info("delete_status: ", delete_status) # return False # else: # filename does not already exist in Supabase, so its a brand new file -# print(f"NOT a duplicate! 📄s3_path: {filename}") +# logging.info(f"NOT a duplicate! 📄s3_path: {filename}") # return False # def delete_data(self, course_name: str, s3_path: str, source_url: str): # """Delete file from S3, Qdrant, and Supabase.""" -# print(f"Deleting {s3_path} from S3, Qdrant, and Supabase for course {course_name}") +# logging.info(f"Deleting {s3_path} from S3, Qdrant, and Supabase for course {course_name}") # # add delete from doc map logic here # try: # # Delete file from S3 @@ -1190,7 +1190,7 @@ # try: # self.s3_client.delete_object(Bucket=bucket_name, Key=s3_path) # except Exception as e: -# print("Error in deleting file from s3:", e) +# logging.info("Error in deleting file from s3:", e) # sentry_sdk.capture_exception(e) # # Delete from Qdrant # # docs for nested keys: https://qdrant.tech/documentation/concepts/filtering/#nested-key @@ -1211,7 +1211,7 @@ # # https://github.com/qdrant/qdrant/issues/3654#issuecomment-1955074525 # pass # else: -# print("Error in deleting file from Qdrant:", e) +# logging.info("Error in deleting file from Qdrant:", e) # sentry_sdk.capture_exception(e) # try: # # delete from Nomic @@ -1227,14 +1227,14 @@ # # delete from Nomic # delete_from_document_map(course_name, nomic_ids_to_delete) # except Exception as e: -# print("Error in deleting file from Nomic:", e) +# logging.info("Error in deleting file from Nomic:", e) # sentry_sdk.capture_exception(e) # try: # self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('s3_path', s3_path).eq( # 'course_name', course_name).execute() # except Exception as e: -# print("Error in deleting file from supabase:", e) +# logging.info("Error in deleting file from supabase:", e) # sentry_sdk.capture_exception(e) # # Delete files by their URL identifier @@ -1256,7 +1256,7 @@ # # https://github.com/qdrant/qdrant/issues/3654#issuecomment-1955074525 # pass # else: -# print("Error in deleting file from Qdrant:", e) +# logging.info("Error in deleting file from Qdrant:", e) # sentry_sdk.capture_exception(e) # try: # # delete from Nomic @@ -1271,7 +1271,7 @@ # # delete from Nomic # delete_from_document_map(course_name, nomic_ids_to_delete) # except Exception as e: -# print("Error in deleting file from Nomic:", e) +# logging.info("Error in deleting file from Nomic:", e) # sentry_sdk.capture_exception(e) # try: @@ -1279,14 +1279,14 @@ # self.supabase_client.from_(os.environ['SUPABASE_DOCUMENTS_TABLE']).delete().eq('url', source_url).eq( # 'course_name', course_name).execute() # except Exception as e: -# print("Error in deleting file from supabase:", e) +# logging.info("Error in deleting file from supabase:", e) # sentry_sdk.capture_exception(e) # # Delete from Supabase # return "Success" # except Exception as e: # err: str = f"ERROR IN delete_data: Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore -# print(err) +# logging.info(err) # sentry_sdk.capture_exception(e) # return err @@ -1320,11 +1320,11 @@ # # if s3_paths is None: # # return 
"Error: No files found in the coursera-dl directory" -# # print("starting bulk ingest") +# # logging.info("starting bulk ingest") # # start_time = time.monotonic() # # self.bulk_ingest(s3_paths, course_name) -# # print("completed bulk ingest") -# # print(f"⏰ Runtime: {(time.monotonic() - start_time):.2f} seconds") +# # logging.info("completed bulk ingest") +# # logging.info(f"⏰ Runtime: {(time.monotonic() - start_time):.2f} seconds") # # # Cleanup the coursera downloads # # shutil.rmtree(dl_results_path) @@ -1332,7 +1332,7 @@ # # return "Success" # # except Exception as e: # # err: str = f"Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore -# # print(err) +# # logging.info(err) # # return err # # def list_files_recursively(self, bucket, prefix): diff --git a/ai_ta_backend/beam/nomic_logging.py b/ai_ta_backend/beam/nomic_logging.py index 2b008e99..7396450a 100644 --- a/ai_ta_backend/beam/nomic_logging.py +++ b/ai_ta_backend/beam/nomic_logging.py @@ -33,7 +33,7 @@ # Returns: # str: success or failed # """ -# print("in create_document_map()") +# logging.info("in create_document_map()") # nomic.login(os.getenv('NOMIC_API_KEY')) # try: @@ -52,7 +52,7 @@ # return "No documents found for this course." # total_doc_count = response.count -# print("Total number of documents in Supabase: ", total_doc_count) +# logging.info("Total number of documents in Supabase: ", total_doc_count) # # minimum 20 docs needed to create map # if total_doc_count < 20: @@ -87,7 +87,7 @@ # if first_batch: # # create a new map -# print("Creating new map...") +# logging.info("Creating new map...") # project_name = NOMIC_MAP_NAME_PREFIX + course_name # index_name = course_name + "_doc_index" # topic_label_field = "text" @@ -105,14 +105,14 @@ # project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() # if project_response.data: # update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() -# print("Response from supabase: ", update_response) +# logging.info("Response from supabase: ", update_response) # else: # insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() -# print("Insert Response from supabase: ", insert_response) +# logging.info("Insert Response from supabase: ", insert_response) # else: # # append to existing map -# print("Appending data to existing map...") +# logging.info("Appending data to existing map...") # project_name = NOMIC_MAP_NAME_PREFIX + course_name # # add project lock logic here # result = append_to_map(embeddings, metadata, project_name) @@ -121,12 +121,12 @@ # last_id = int(final_df['id'].iloc[-1]) # info = {'last_uploaded_doc_id': last_id} # update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() -# print("Response from supabase: ", update_response) +# logging.info("Response from supabase: ", update_response) # # reset variables # combined_dfs = [] # doc_count = 0 -# print("Records uploaded: ", curr_total_doc_count) +# logging.info("Records uploaded: ", curr_total_doc_count) # # set first_id for next iteration # first_id = response.data[-1]['id'] + 1 @@ -151,20 +151,20 @@ # project = AtlasProject(name=project_name, add_datums_if_exists=True) # project_id = project.id # project_info = {'course_name': course_name, 'doc_map_id': project_id, 'last_uploaded_doc_id': last_id} -# print("project_info: ", project_info) +# logging.info("project_info: ", 
project_info) # project_response = SUPABASE_CLIENT.table("projects").select("*").eq("course_name", course_name).execute() # if project_response.data: # update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() -# print("Response from supabase: ", update_response) +# logging.info("Response from supabase: ", update_response) # else: # insert_response = SUPABASE_CLIENT.table("projects").insert(project_info).execute() -# print("Insert Response from supabase: ", insert_response) +# logging.info("Insert Response from supabase: ", insert_response) # # rebuild the map # rebuild_map(course_name, "document") # except Exception as e: -# print(e) +# logging.info(e) # sentry_sdk.capture_exception(e) # return "failed" @@ -176,7 +176,7 @@ # course_name: str # ids: list of str # """ -# print("in delete_from_document_map()") +# logging.info("in delete_from_document_map()") # try: # # check if project exists @@ -190,12 +190,12 @@ # project = AtlasProject(project_id=project_id, add_datums_if_exists=True) # # delete the ids from Nomic -# print("Deleting point from document map:", project.delete_data(ids)) +# logging.info("Deleting point from document map:", project.delete_data(ids)) # with project.wait_for_project_lock(): # project.rebuild_maps() # return "success" # except Exception as e: -# print(e) +# logging.info(e) # sentry_sdk.capture_exception(e) # return "Error in deleting from document map: {e}" @@ -206,7 +206,7 @@ # Args: # data: dict - the response data from Supabase insertion # """ -# print("in add_to_document_map()") +# logging.info("in add_to_document_map()") # try: # # check if map exists @@ -233,9 +233,9 @@ # return "Skipping Nomic logging because project is locked." # # fetch count of records greater than last_uploaded_doc_id -# print("last uploaded doc id: ", last_uploaded_doc_id) +# logging.info("last uploaded doc id: ", last_uploaded_doc_id) # response = SUPABASE_CLIENT.table("documents").select("id", count="exact").eq("course_name", course_name).gt("id", last_uploaded_doc_id).execute() -# print("Number of new documents: ", response.count) +# logging.info("Number of new documents: ", response.count) # total_doc_count = response.count # current_doc_count = 0 @@ -258,7 +258,7 @@ # embeddings, metadata = data_prep_for_doc_map(final_df) # # append to existing map -# print("Appending data to existing map...") +# logging.info("Appending data to existing map...") # result = append_to_map(embeddings, metadata, project_name) # if result == "success": @@ -266,12 +266,12 @@ # last_id = int(final_df['id'].iloc[-1]) # info = {'last_uploaded_doc_id': last_id} # update_response = SUPABASE_CLIENT.table("projects").update(info).eq("course_name", course_name).execute() -# print("Response from supabase: ", update_response) +# logging.info("Response from supabase: ", update_response) # # reset variables # combined_dfs = [] # doc_count = 0 -# print("Records uploaded: ", current_doc_count) +# logging.info("Records uploaded: ", current_doc_count) # # set first_id for next iteration # first_id = response.data[-1]['id'] + 1 @@ -288,11 +288,11 @@ # last_id = int(final_df['id'].iloc[-1]) # project_info = {'last_uploaded_doc_id': last_id} # update_response = SUPABASE_CLIENT.table("projects").update(project_info).eq("course_name", course_name).execute() -# print("Response from supabase: ", update_response) +# logging.info("Response from supabase: ", update_response) # return "success" # except Exception as e: -# print(e) +# logging.info(e) # return "failed" # def 
create_map(embeddings, metadata, map_name, index_name, topic_label_field, colorable_fields): @@ -319,7 +319,7 @@ # project.create_index(name=index_name, build_topic_model=True) # return "success" # except Exception as e: -# print(e) +# logging.info(e) # return "Error in creating map: {e}" # def append_to_map(embeddings, metadata, map_name): @@ -338,7 +338,7 @@ # project.add_embeddings(embeddings=embeddings, data=metadata) # return "success" # except Exception as e: -# print(e) +# logging.info(e) # return "Error in appending to map: {e}" # def data_prep_for_doc_map(df: pd.DataFrame): @@ -350,7 +350,7 @@ # embeddings: np.array of embeddings # metadata: pd.DataFrame of metadata # """ -# print("in data_prep_for_doc_map()") +# logging.info("in data_prep_for_doc_map()") # metadata = [] # embeddings = [] @@ -387,11 +387,11 @@ # texts.append(text_row) # embeddings_np = np.array(embeddings, dtype=object) -# print("Shape of embeddings: ", embeddings_np.shape) +# logging.info("Shape of embeddings: ", embeddings_np.shape) # # check dimension if embeddings_np is (n, 1536) # if len(embeddings_np.shape) < 2: -# print("Creating new embeddings...") +# logging.info("Creating new embeddings...") # embeddings_model = OpenAIEmbeddings(openai_api_type="openai", # openai_api_base="https://api.openai.com/v1/", @@ -407,7 +407,7 @@ # """ # This function rebuilds a given map in Nomic. # """ -# print("in rebuild_map()") +# logging.info("in rebuild_map()") # nomic.login(os.getenv('NOMIC_API_KEY')) # if map_type.lower() == 'document': # NOMIC_MAP_NAME_PREFIX = 'Document Map for ' @@ -423,7 +423,7 @@ # project.rebuild_maps() # return "success" # except Exception as e: -# print(e) +# logging.info(e) # sentry_sdk.capture_exception(e) # return "Error in rebuilding map: {e}" diff --git a/ai_ta_backend/database/aws.py b/ai_ta_backend/database/aws.py index 58fb1bb3..047974cd 100644 --- a/ai_ta_backend/database/aws.py +++ b/ai_ta_backend/database/aws.py @@ -1,3 +1,4 @@ +import logging import os import boto3 @@ -9,14 +10,14 @@ class AWSStorage(): @inject def __init__(self): if all(os.getenv(key) for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]): - print("Using AWS for storage") + logging.info("Using AWS for storage") self.s3_client = boto3.client( 's3', aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'), aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'), ) elif all(os.getenv(key) for key in ["MINIO_ACCESS_KEY", "MINIO_SECRET_KEY", "MINIO_URL"]): - print("Using Minio for storage") + logging.info("Using Minio for storage") self.s3_client = boto3.client('s3', endpoint_url=os.getenv('MINIO_URL'), aws_access_key_id=os.getenv('MINIO_ACCESS_KEY'), diff --git a/ai_ta_backend/database/qdrant.py b/ai_ta_backend/database/qdrant.py index f2826e68..18a792ff 100644 --- a/ai_ta_backend/database/qdrant.py +++ b/ai_ta_backend/database/qdrant.py @@ -1,3 +1,4 @@ +import logging import os from typing import List @@ -22,7 +23,8 @@ def __init__(self): """ # vector DB self.qdrant_client = QdrantClient( - url=os.environ['QDRANT_URL'], + url='http://qdrant:6333', + https=False, api_key=os.environ['QDRANT_API_KEY'], timeout=20, # default is 5 seconds. Getting timeout errors w/ document groups. 
) @@ -41,7 +43,7 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q # Filter for the must_conditions myfilter = models.Filter(must=must_conditions) - print(f"Filter: {myfilter}") + logging.info(f"Qdrant search Filter: {myfilter}") # Search the vector database search_results = self.qdrant_client.search( @@ -52,6 +54,7 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q limit=top_n, # Return n closest points # In a system with high disk latency, the re-scoring step may become a bottleneck: https://qdrant.tech/documentation/guides/quantization/ search_params=models.SearchParams(quantization=models.QuantizationSearchParams(rescore=False))) + return search_results def _create_search_conditions(self, course_name, doc_groups: List[str]): diff --git a/ai_ta_backend/database/supabase.py b/ai_ta_backend/database/supabase.py index 38f7252d..5a1cf39c 100644 --- a/ai_ta_backend/database/supabase.py +++ b/ai_ta_backend/database/supabase.py @@ -1,3 +1,4 @@ +import logging import os from injector import inject @@ -39,14 +40,14 @@ def getProjectsMapForCourse(self, course_name: str): def getDocumentsBetweenDates(self, course_name: str, from_date: str, to_date: str, table_name: str): if from_date != '' and to_date != '': # query between the dates - print("from_date and to_date") + logging.info("from_date and to_date") response = self.supabase_client.table(table_name).select("id", count='exact').eq("course_name", course_name).gte( 'created_at', from_date).lte('created_at', to_date).order('id', desc=False).execute() elif from_date != '' and to_date == '': # query from from_date to now - print("only from_date") + logging.info("only from_date") response = self.supabase_client.table(table_name).select("id", count='exact').eq("course_name", course_name).gte('created_at', @@ -55,7 +56,7 @@ def getDocumentsBetweenDates(self, course_name: str, from_date: str, to_date: st elif from_date == '' and to_date != '': # query from beginning to to_date - print("only to_date") + logging.info("only to_date") response = self.supabase_client.table(table_name).select("id", count='exact').eq("course_name", course_name).lte('created_at', @@ -64,7 +65,7 @@ def getDocumentsBetweenDates(self, course_name: str, from_date: str, to_date: st else: # query all data - print("No dates") + logging.info("No dates") response = self.supabase_client.table(table_name).select("id", count='exact').eq("course_name", course_name).order('id', desc=False).execute() return response diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 44ddc874..56470ea7 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -1,6 +1,5 @@ import logging import os -import sys import time from typing import List @@ -40,11 +39,6 @@ from ai_ta_backend.service.sentry_service import SentryService from ai_ta_backend.service.workflow_service import WorkflowService -# Make docker log our prints() -- Set PYTHONUNBUFFERED to ensure no output buffering -os.environ['PYTHONUNBUFFERED'] = '1' -sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 1) -sys.stderr = os.fdopen(sys.stderr.fileno(), 'w', 1) - app = Flask(__name__) CORS(app) executor = Executor(app) @@ -113,6 +107,9 @@ def getTopContexts(service: RetrievalService) -> Response: token_limit: int = data.get('token_limit', 3000) doc_groups: List[str] = data.get('doc_groups', []) + logging.info(f"QDRANT URL {os.environ['QDRANT_URL']}") + logging.info(f"QDRANT_API_KEY {os.environ['QDRANT_API_KEY']}") + if search_query == '' or course_name == '': # proper
web error "400 Bad request" abort( @@ -132,7 +129,7 @@ def getTopContexts(service: RetrievalService) -> Response: def getAll(service: RetrievalService) -> Response: """Get all course materials based on the course_name """ - print("In getAll()") + logging.info("In getAll()") course_name: List[str] | str = request.args.get('course_name', default='', type=str) if course_name == '': @@ -167,8 +164,8 @@ def delete(service: RetrievalService, flaskExecutor: ExecutorInterface): start_time = time.monotonic() # background execution of tasks!! flaskExecutor.submit(service.delete_data, course_name, s3_path, source_url) - print(f"From {course_name}, deleted file: {s3_path}") - print(f"⏰ Runtime of FULL delete func: {(time.monotonic() - start_time):.2f} seconds") + logging.info(f"From {course_name}, deleted file: {s3_path}") + logging.info(f"⏰ Runtime of FULL delete func: {(time.monotonic() - start_time):.2f} seconds") # we need instant return. Delets are "best effort" assume always successful... sigh :( response = jsonify({"outcome": 'success'}) response.headers.add('Access-Control-Allow-Origin', '*') @@ -185,7 +182,7 @@ def nomic_map(service: NomicService): abort(400, description=f"Missing required parameter: 'course_name' must be provided. Course name: `{course_name}`") map_id = service.get_nomic_map(course_name, map_type) - print("nomic map\n", map_id) + logging.info("nomic map\n", map_id) response = jsonify(map_id) response.headers.add('Access-Control-Allow-Origin', '*') @@ -251,7 +248,7 @@ def logToNomic(service: NomicService, flaskExecutor: ExecutorInterface): description= f"Missing one or more required parameters: 'course_name' and 'conversation' must be provided. Course name: `{course_name}`, Conversation: `{conversation}`" ) - print(f"In /onResponseCompletion for course: {course_name}") + logging.info(f"In /onResponseCompletion for course: {course_name}") # background execution of tasks!! #response = flaskExecutor.submit(service.log_convo_to_nomic, course_name, data) @@ -272,7 +269,7 @@ def export_convo_history(service: ExportService): abort(400, description=f"Missing required parameter: 'course_name' must be provided. Course name: `{course_name}`") export_status = service.export_convo_history_json(course_name, from_date, to_date) - print("EXPORT FILE LINKS: ", export_status) + logging.info("EXPORT FILE LINKS: ", export_status) if export_status['response'] == "No data found between the given dates.": response = Response(status=204) @@ -303,7 +300,7 @@ def export_conversations_custom(service: ExportService): abort(400, description="Missing required parameter: 'course_name' and 'destination_email_ids' must be provided.") export_status = service.export_conversations(course_name, from_date, to_date, emails) - print("EXPORT FILE LINKS: ", export_status) + logging.info("EXPORT FILE LINKS: ", export_status) if export_status['response'] == "No data found between the given dates.": response = Response(status=204) @@ -333,7 +330,7 @@ def exportDocuments(service: ExportService): abort(400, description=f"Missing required parameter: 'course_name' must be provided. 
Course name: `{course_name}`") export_status = service.export_documents_json(course_name, from_date, to_date) - print("EXPORT FILE LINKS: ", export_status) + logging.info("EXPORT FILE LINKS: ", export_status) if export_status['response'] == "No data found between the given dates.": response = Response(status=204) @@ -393,9 +390,9 @@ def get_all_workflows(service: WorkflowService) -> Response: pagination = request.args.get('pagination', default=True, type=bool) active = request.args.get('active', default=False, type=bool) name = request.args.get('workflow_name', default='', type=str) - print(request.args) + logging.info(request.args) - print("In get_all_workflows.. api_key: ", api_key) + logging.info("In get_all_workflows.. api_key: ", api_key) # if no API Key, return empty set. # if api_key == '': @@ -409,10 +406,10 @@ def get_all_workflows(service: WorkflowService) -> Response: return response except Exception as e: if "unauthorized" in str(e).lower(): - print("Unauthorized error in get_all_workflows: ", e) + logging.info("Unauthorized error in get_all_workflows: ", e) abort(401, description=f"Unauthorized: 'api_key' is invalid. Search query: `{api_key}`") else: - print("Error in get_all_workflows: ", e) + logging.info("Error in get_all_workflows: ", e) abort(500, description=f"Failed to fetch n8n workflows: {e}") @@ -426,14 +423,14 @@ def switch_workflow(service: WorkflowService) -> Response: activate = request.args.get('activate', default='', type=str) id = request.args.get('id', default='', type=str) - print(request.args) + logging.info(request.args) if api_key == '': # proper web error "400 Bad request" abort(400, description=f"Missing N8N API_KEY: 'api_key' must be provided. Search query: `{api_key}`") try: - print("activation!!!!!!!!!!!", activate) + logging.info("activation!!!!!!!!!!!", activate) response = service.switch_workflow(id, api_key, activate) response = jsonify(response) response.headers.add('Access-Control-Allow-Origin', '*') @@ -455,7 +452,7 @@ def run_flow(service: WorkflowService) -> Response: name = request.json.get('name', '') data = request.json.get('data', '') - print("Got /run_flow request:", request.json) + logging.info("Got /run_flow request:", request.json) if api_key == '': # proper web error "400 Bad request" @@ -499,7 +496,7 @@ def configure(binder: Binder) -> None: db.init_app(app) db.create_all() binder.bind(SQLAlchemyDatabase, to=SQLAlchemyDatabase(db), scope=SingletonScope) - print("Bound to SQL DB!") + logging.info("Bound to SQL DB!") sql_bound = True break @@ -507,6 +504,10 @@ def configure(binder: Binder) -> None: if all(os.getenv(key) for key in ["QDRANT_URL", "QDRANT_API_KEY", "QDRANT_COLLECTION_NAME"]) or any( os.getenv(key) for key in ["PINECONE_API_KEY", "PINECONE_PROJECT_NAME"]): logging.info("Binding to Qdrant database") + + logging.info(f"Qdrant Collection Name: {os.environ['QDRANT_COLLECTION_NAME']}") + logging.info(f"Qdrant URL: {os.environ['QDRANT_URL']}") + logging.info(f"Qdrant API Key: {os.environ['QDRANT_API_KEY']}") binder.bind(VectorDatabase, to=VectorDatabase, scope=SingletonScope) vector_bound = True diff --git a/ai_ta_backend/modal/pest_detection.py b/ai_ta_backend/modal/pest_detection.py index 8f514efa..0a528acd 100644 --- a/ai_ta_backend/modal/pest_detection.py +++ b/ai_ta_backend/modal/pest_detection.py @@ -15,6 +15,7 @@ """ import inspect import json +import logging import os from tempfile import NamedTemporaryFile import traceback @@ -88,20 +89,20 @@ async def predict(self, request: Request): This used to use the method 
decorator Run the pest detection plugin on an image. """ - print("Inside predict() endpoint") + logging.info("Inside predict() endpoint") input = await request.json() - print("Request.json(): ", input) + logging.info("Request.json(): ", input) image_urls = input.get('image_urls', []) if image_urls and isinstance(image_urls, str): image_urls = json.loads(image_urls) - print(f"Final image URLs: {image_urls}") + logging.info(f"Final image URLs: {image_urls}") try: # Run the plugin annotated_images = self._detect_pests(image_urls) - print(f"annotated_images found: {len(annotated_images)}") + logging.info(f"annotated_images found: {len(annotated_images)}") results = [] # Generate a unique ID for the request unique_id = uuid.uuid4() @@ -130,7 +131,7 @@ async def predict(self, request: Request): return results except Exception as e: err = f"❌❌ Error in (pest_detection): `{inspect.currentframe().f_code.co_name}`: {e}\nTraceback:\n{traceback.format_exc()}" # type: ignore - print(err) + logging.info(err) # sentry_sdk.capture_exception(e) return err diff --git a/ai_ta_backend/service/export_service.py b/ai_ta_backend/service/export_service.py index c998d001..a945453d 100644 --- a/ai_ta_backend/service/export_service.py +++ b/ai_ta_backend/service/export_service.py @@ -1,5 +1,6 @@ from concurrent.futures import ProcessPoolExecutor import json +import logging import os import uuid import zipfile @@ -55,16 +56,16 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''): first_id = int(str(response.data[0].id)) last_id = int(str(response.data[-1].id)) - print("total_doc_count: ", total_doc_count) - print("first_id: ", first_id) - print("last_id: ", last_id) + logging.info("total_doc_count: ", total_doc_count) + logging.info("first_id: ", first_id) + logging.info("last_id: ", last_id) curr_doc_count = 0 filename = course_name + '_' + str(uuid.uuid4()) + '_documents.jsonl' file_path = os.path.join(os.getcwd(), filename) while curr_doc_count < total_doc_count: - print("Fetching data from id: ", first_id) + logging.info("Fetching data from id: ", first_id) response = self.sql.getDocsForIdsGte(course_name, first_id) df = pd.DataFrame(response.data) @@ -91,7 +92,7 @@ def export_documents_json(self, course_name: str, from_date='', to_date=''): os.remove(file_path) return {"response": (zip_file_path, zip_filename, os.getcwd())} except Exception as e: - print(e) + logging.info(e) self.sentry.capture_exception(e) return {"response": "Error downloading file."} else: @@ -105,7 +106,7 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): from_date (str, optional): The start date for the data export. Defaults to ''. to_date (str, optional): The end date for the data export. Defaults to ''. 
""" - print("Exporting conversation history to json file...") + logging.info("Exporting conversation history to json file...") response = self.sql.getConversationsBetweenDates(course_name, from_date, to_date) @@ -120,7 +121,7 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): # Fetch data if response.count > 0: - print("id count greater than zero") + logging.info("id count greater than zero") first_id = int(str(response.data[0].id)) last_id = int(str(response.data[-1].id)) total_count = response.count @@ -130,7 +131,7 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): curr_count = 0 # Fetch data in batches of 25 from first_id to last_id while curr_count < total_count: - print("Fetching data from id: ", first_id) + logging.info("Fetching data from id: ", first_id) response = self.sql.getAllConversationsBetweenIds(course_name, first_id, last_id) # Convert to pandas dataframe df = pd.DataFrame(response.data) @@ -145,7 +146,7 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): # Update first_id if len(response.data) > 0: first_id = int(str(response.data[-1].id)) + 1 - print("updated first_id: ", first_id) + logging.info("updated first_id: ", first_id) # Download file try: @@ -159,7 +160,7 @@ def export_convo_history_json(self, course_name: str, from_date='', to_date=''): return {"response": (zip_file_path, zip_filename, os.getcwd())} except Exception as e: - print(e) + logging.info(e) self.sentry.capture_exception(e) return {"response": "Error downloading file!"} else: @@ -169,7 +170,7 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e """ Another function for exporting convos, emails are passed as a string. """ - print("Exporting conversation history to json file...") + logging.info("Exporting conversation history to json file...") response = self.sql.getConversationsBetweenDates(course_name, from_date, to_date) @@ -184,7 +185,7 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e # Fetch data if response.count > 0: - print("id count greater than zero") + logging.info("id count greater than zero") first_id = int(str(response.data[0].id)) last_id = int(str(response.data[-1].id)) total_count = response.count @@ -194,7 +195,7 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e curr_count = 0 # Fetch data in batches of 25 from first_id to last_id while curr_count < total_count: - print("Fetching data from id: ", first_id) + logging.info("Fetching data from id: ", first_id) response = self.sql.getAllConversationsBetweenIds(course_name, first_id, last_id) # Convert to pandas dataframe df = pd.DataFrame(response.data) @@ -209,7 +210,7 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e # Update first_id if len(response.data) > 0: first_id = int(str(response.data[-1].id)) + 1 - print("updated first_id: ", first_id) + logging.info("updated first_id: ", first_id) # Download file try: @@ -223,7 +224,7 @@ def export_conversations(self, course_name: str, from_date: str, to_date: str, e return {"response": (zip_file_path, zip_filename, os.getcwd())} except Exception as e: - print(e) + logging.info(e) self.sentry.capture_exception(e) return {"response": "Error downloading file!"} else: @@ -250,8 +251,8 @@ def export_data_in_bg(response, download_type, course_name, s3_path): total_doc_count = response.count first_id = response.data[0]['id'] - print("total_doc_count: ", total_doc_count) 
- print("pre-defined s3_path: ", s3_path) + logging.info("total_doc_count: ", total_doc_count) + logging.info("pre-defined s3_path: ", s3_path) curr_doc_count = 0 filename = s3_path.split('/')[-1].split('.')[0] + '.jsonl' @@ -259,7 +260,7 @@ def export_data_in_bg(response, download_type, course_name, s3_path): # download data in batches of 100 while curr_doc_count < total_doc_count: - print("Fetching data from id: ", first_id) + logging.info("Fetching data from id: ", first_id) if download_type == "documents": response = sql.getAllDocumentsForDownload(course_name, first_id) else: @@ -283,7 +284,7 @@ def export_data_in_bg(response, download_type, course_name, s3_path): with zipfile.ZipFile(zip_file_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf: zipf.write(file_path, filename) - print("zip file created: ", zip_file_path) + logging.info("zip file created: ", zip_file_path) try: # upload to S3 @@ -296,7 +297,7 @@ def export_data_in_bg(response, download_type, course_name, s3_path): os.remove(file_path) os.remove(zip_file_path) - print("file uploaded to s3: ", s3_file) + logging.info("file uploaded to s3: ", s3_file) # generate presigned URL s3_url = s3.generatePresignedUrl('get_object', os.environ['S3_BUCKET_NAME'], s3_path, 172800) @@ -319,8 +320,8 @@ def export_data_in_bg(response, download_type, course_name, s3_path): # add course owner email to admin_emails admin_emails.append(course_metadata['course_owner']) admin_emails = list(set(admin_emails)) - print("admin_emails: ", admin_emails) - print("bcc_emails: ", bcc_emails) + logging.info("admin_emails: ", admin_emails) + logging.info("bcc_emails: ", bcc_emails) # add a check for emails, don't send email if no admin emails if len(admin_emails) == 0: @@ -335,12 +336,12 @@ def export_data_in_bg(response, download_type, course_name, s3_path): subject = "UIUC.chat Export Complete for " + course_name body_text = "The data export for " + course_name + " is complete.\n\nYou can download the file from the following link: \n\n" + s3_url + "\n\nThis link will expire in 48 hours." email_status = send_email(subject, body_text, os.environ['EMAIL_SENDER'], admin_emails, bcc_emails) - print("email_status: ", email_status) + logging.info("email_status: ", email_status) return "File uploaded to S3. Email sent to admins." 
except Exception as e: - print(e) + logging.info(e) return "Error: " + str(e) @@ -362,8 +363,8 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai total_doc_count = response.count first_id = response.data[0]['id'] - print("total_doc_count: ", total_doc_count) - print("pre-defined s3_path: ", s3_path) + logging.info("total_doc_count: ", total_doc_count) + logging.info("pre-defined s3_path: ", s3_path) curr_doc_count = 0 filename = s3_path.split('/')[-1].split('.')[0] + '.jsonl' @@ -371,7 +372,7 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai # download data in batches of 100 while curr_doc_count < total_doc_count: - print("Fetching data from id: ", first_id) + logging.info("Fetching data from id: ", first_id) if download_type == "documents": response = sql.getAllDocumentsForDownload(course_name, first_id) else: @@ -395,7 +396,7 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai with zipfile.ZipFile(zip_file_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf: zipf.write(file_path, filename) - print("zip file created: ", zip_file_path) + logging.info("zip file created: ", zip_file_path) try: # upload to S3 @@ -408,7 +409,7 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai os.remove(file_path) os.remove(zip_file_path) - print("file uploaded to s3: ", s3_file) + logging.info("file uploaded to s3: ", s3_file) # generate presigned URL s3_url = s3.generatePresignedUrl('get_object', os.environ['S3_BUCKET_NAME'], s3_path, 172800) @@ -416,8 +417,8 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai admin_emails = emails bcc_emails = [] - print("admin_emails: ", admin_emails) - print("bcc_emails: ", bcc_emails) + logging.info("admin_emails: ", admin_emails) + logging.info("bcc_emails: ", bcc_emails) # add a check for emails, don't send email if no admin emails if len(admin_emails) == 0: @@ -432,10 +433,10 @@ def export_data_in_bg_emails(response, download_type, course_name, s3_path, emai subject = "UIUC.chat Export Complete for " + course_name body_text = "The data export for " + course_name + " is complete.\n\nYou can download the file from the following link: \n\n" + s3_url + "\n\nThis link will expire in 48 hours." email_status = send_email(subject, body_text, os.environ['EMAIL_SENDER'], admin_emails, bcc_emails) - print("email_status: ", email_status) + logging.info("email_status: ", email_status) return "File uploaded to S3. Email sent to admins." except Exception as e: - print(e) + logging.info(e) return "Error: " + str(e) diff --git a/ai_ta_backend/service/nomic_service.py b/ai_ta_backend/service/nomic_service.py index c7011c88..b97bfd1d 100644 --- a/ai_ta_backend/service/nomic_service.py +++ b/ai_ta_backend/service/nomic_service.py @@ -1,4 +1,5 @@ import datetime +import logging import os import time @@ -49,14 +50,14 @@ def get_nomic_map(self, course_name: str, type: str): project = atlas.AtlasProject(name=project_name, add_datums_if_exists=True) map = project.get_map(project_name) - print(f"⏰ Nomic Full Map Retrieval: {(time.monotonic() - start_time):.2f} seconds") + logging.info(f"⏰ Nomic Full Map Retrieval: {(time.monotonic() - start_time):.2f} seconds") return {"map_id": f"iframe{map.id}", "map_link": map.map_link} except Exception as e: # Error: ValueError: You must specify a unique_id_field when creating a new project. 
if str(e) == 'You must specify a unique_id_field when creating a new project.': # type: ignore - print("Nomic map does not exist yet, probably because you have less than 20 queries/documents on your project: ", e) + logging.info("Nomic map does not exist yet, probably because you have less than 20 queries/documents on your project: ", e) else: - print("ERROR in get_nomic_map():", e) + logging.info("ERROR in get_nomic_map():", e) self.sentry.capture_exception(e) return {"map_id": None, "map_link": None} @@ -72,16 +73,16 @@ def log_to_conversation_map(self, course_name: str, conversation): try: # check if map exists response = self.sql.getConvoMapFromProjects(course_name) - print("Response from supabase: ", response.data) + logging.info("Response from supabase: ", response.data) # entry not present in projects table if not response.data: - print("Map does not exist for this course. Redirecting to map creation...") + logging.info("Map does not exist for this course. Redirecting to map creation...") return self.create_conversation_map(course_name) # entry present for doc map, but not convo map elif response.data[0].convo_map_id is not None: - print("Map does not exist for this course. Redirecting to map creation...") + logging.info("Map does not exist for this course. Redirecting to map creation...") return self.create_conversation_map(course_name) project_id = response.data[0].convo_map_id @@ -95,7 +96,7 @@ def log_to_conversation_map(self, course_name: str, conversation): # fetch count of conversations since last upload response = self.sql.getCountFromLLMConvoMonitor(course_name, last_id=last_uploaded_convo_id) total_convo_count = response.count - print("Total number of unlogged conversations in Supabase: ", total_convo_count) + logging.info("Total number of unlogged conversations in Supabase: ", total_convo_count) if total_convo_count == 0: # log to an existing conversation @@ -109,14 +110,14 @@ def log_to_conversation_map(self, course_name: str, conversation): while current_convo_count < total_convo_count: response = self.sql.getAllConversationsBetweenIds(course_name, first_id, 0, 100) - print("Response count: ", len(response.data)) + logging.info("Response count: ", len(response.data)) if len(response.data) == 0: break df = pd.DataFrame(response.data) combined_dfs.append(df) current_convo_count += len(response.data) convo_count += len(response.data) - print(current_convo_count) + logging.info(current_convo_count) if convo_count >= 500: # concat all dfs from the combined_dfs list @@ -124,24 +125,24 @@ def log_to_conversation_map(self, course_name: str, conversation): # prep data for nomic upload embeddings, metadata = self.data_prep_for_convo_map(final_df) # append to existing map - print("Appending data to existing map...") + logging.info("Appending data to existing map...") result = self.append_to_map(embeddings, metadata, NOMIC_MAP_NAME_PREFIX + course_name) if result == "success": last_id = int(final_df['id'].iloc[-1]) project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} project_response = self.sql.updateProjects(course_name, project_info) - print("Update response from supabase: ", project_response) + logging.info("Update response from supabase: ", project_response) # reset variables combined_dfs = [] convo_count = 0 - print("Records uploaded: ", current_convo_count) + logging.info("Records uploaded: ", current_convo_count) # set first_id for next iteration first_id = int(str(response.data[-1].id)) + 1 # upload last set of convos if 
convo_count > 0: - print("Uploading last set of conversations...") + logging.info("Uploading last set of conversations...") final_df = pd.concat(combined_dfs, ignore_index=True) embeddings, metadata = self.data_prep_for_convo_map(final_df) result = self.append_to_map(embeddings, metadata, NOMIC_MAP_NAME_PREFIX + course_name) @@ -149,14 +150,14 @@ def log_to_conversation_map(self, course_name: str, conversation): last_id = int(final_df['id'].iloc[-1]) project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} project_response = self.sql.updateProjects(course_name, project_info) - print("Update response from supabase: ", project_response) + logging.info("Update response from supabase: ", project_response) # rebuild the map self.rebuild_map(course_name, "conversation") return "success" except Exception as e: - print(e) + logging.info(e) self.sentry.capture_exception(e) return "Error in logging to conversation map: {e}" @@ -164,7 +165,7 @@ def log_to_existing_conversation(self, course_name: str, conversation): """ This function logs follow-up questions to existing conversations in the map. """ - print(f"in log_to_existing_conversation() for course: {course_name}") + logging.info(f"in log_to_existing_conversation() for course: {course_name}") try: conversation_id = conversation['id'] @@ -211,23 +212,23 @@ def log_to_existing_conversation(self, course_name: str, conversation): metadata = pd.DataFrame(uploaded_data) embeddings = np.array(embeddings) - print("Metadata shape:", metadata.shape) - print("Embeddings shape:", embeddings.shape) + logging.info("Metadata shape:", metadata.shape) + logging.info("Embeddings shape:", embeddings.shape) # deleting existing map - print("Deleting point from nomic:", project.delete_data([prev_id])) + logging.info("Deleting point from nomic:", project.delete_data([prev_id])) # re-build map to reflect deletion project.rebuild_maps() # re-insert updated conversation result = self.append_to_map(embeddings, metadata, project_name) - print("Result of appending to existing map:", result) + logging.info("Result of appending to existing map:", result) return "success" except Exception as e: - print("Error in log_to_existing_conversation():", e) + logging.info("Error in log_to_existing_conversation():", e) self.sentry.capture_exception(e) return "Error in logging to existing conversation: {e}" @@ -240,7 +241,7 @@ def create_conversation_map(self, course_name: str): try: # check if map exists response = self.sql.getConvoMapFromProjects(course_name) - print("Response from supabase: ", response.data) + logging.info("Response from supabase: ", response.data) if response.data: if response.data[0].convo_map_id is not None: return "Map already exists for this course." 
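The conversation-map loops above and below share one control flow: keyset-paginate rows by id in batches of 100, buffer them, push a batch to Nomic once roughly 500 rows have accumulated, and record the last uploaded id so the next run can resume. A minimal sketch of that flow, where fetch_batch and flush are hypothetical stand-ins for the Supabase query and the Nomic upload:

import pandas as pd

BATCH_SIZE = 100
FLUSH_EVERY = 500

def upload_in_batches(fetch_batch, flush, total_count: int, first_id: int) -> int:
    """Keyset-paginate rows by id and flush them in ~FLUSH_EVERY chunks.

    fetch_batch(first_id, limit) -> list[dict] and flush(df) -> None are
    hypothetical helpers standing in for the Supabase query and the Nomic upload.
    Returns the last id that was flushed, i.e. the resume point for the next run.
    """
    buffered, uploaded, last_id = [], 0, first_id
    while uploaded < total_count:
        rows = fetch_batch(first_id, BATCH_SIZE)
        if not rows:
            break
        buffered.append(pd.DataFrame(rows))
        uploaded += len(rows)
        first_id = int(rows[-1]["id"]) + 1            # advance the cursor past this batch
        if sum(len(df) for df in buffered) >= FLUSH_EVERY:
            batch_df = pd.concat(buffered, ignore_index=True)
            flush(batch_df)
            last_id = int(batch_df["id"].iloc[-1])    # remember the resume point
            buffered = []
    if buffered:                                      # flush the final partial chunk
        batch_df = pd.concat(buffered, ignore_index=True)
        flush(batch_df)
        last_id = int(batch_df["id"].iloc[-1])
    return last_id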
@@ -256,7 +257,7 @@ def create_conversation_map(self, course_name: str): # if >20, iteratively fetch records in batches of 100 total_convo_count = response.count - print("Total number of conversations in Supabase: ", total_convo_count) + logging.info("Total number of conversations in Supabase: ", total_convo_count) first_id = int(str(response.data[0].id)) - 1 combined_dfs = [] @@ -268,14 +269,14 @@ def create_conversation_map(self, course_name: str): # iteratively query in batches of 50 while current_convo_count < total_convo_count: response = self.sql.getAllConversationsBetweenIds(course_name, first_id, 0, 100) - print("Response count: ", len(response.data)) + logging.info("Response count: ", len(response.data)) if len(response.data) == 0: break df = pd.DataFrame(response.data) combined_dfs.append(df) current_convo_count += len(response.data) convo_count += len(response.data) - print(current_convo_count) + logging.info(current_convo_count) if convo_count >= 500: # concat all dfs from the combined_dfs list @@ -285,7 +286,7 @@ def create_conversation_map(self, course_name: str): if first_batch: # create a new map - print("Creating new map...") + logging.info("Creating new map...") index_name = course_name + "_convo_index" topic_label_field = "first_query" colorable_fields = ["user_email", "first_query", "conversation_id", "created_at"] @@ -305,35 +306,35 @@ def create_conversation_map(self, course_name: str): project_response = self.sql.updateProjects(course_name, project_info) else: project_response = self.sql.insertProjectInfo(project_info) - print("Update response from supabase: ", project_response) + logging.info("Update response from supabase: ", project_response) else: # append to existing map - print("Appending data to existing map...") + logging.info("Appending data to existing map...") project = AtlasProject(name=project_name, add_datums_if_exists=True) result = self.append_to_map(embeddings, metadata, project_name) if result == "success": - print("map append successful") + logging.info("map append successful") last_id = int(final_df['id'].iloc[-1]) project_info = {'last_uploaded_convo_id': last_id} project_response = self.sql.updateProjects(course_name, project_info) - print("Update response from supabase: ", project_response) + logging.info("Update response from supabase: ", project_response) # reset variables combined_dfs = [] convo_count = 0 - print("Records uploaded: ", current_convo_count) + logging.info("Records uploaded: ", current_convo_count) # set first_id for next iteration try: - print("response: ", response.data[-1].id) + logging.info("response: ", response.data[-1].id) except Exception as e: - print("response: ", response.data) + logging.info("response: ", response.data) first_id = int(str(response.data[-1].id)) + 1 - print("Convo count: ", convo_count) + logging.info("Convo count: ", convo_count) # upload last set of convos if convo_count > 0: - print("Uploading last set of conversations...") + logging.info("Uploading last set of conversations...") final_df = pd.concat(combined_dfs, ignore_index=True) embeddings, metadata = self.data_prep_for_convo_map(final_df) if first_batch: @@ -345,29 +346,29 @@ def create_conversation_map(self, course_name: str): else: # append to map - print("in map append") + logging.info("in map append") result = self.append_to_map(embeddings, metadata, project_name) if result == "success": - print("last map append successful") + logging.info("last map append successful") last_id = int(final_df['id'].iloc[-1]) project = 
AtlasProject(name=project_name, add_datums_if_exists=True) project_id = project.id project_info = {'course_name': course_name, 'convo_map_id': project_id, 'last_uploaded_convo_id': last_id} - print("Project info: ", project_info) + logging.info("Project info: ", project_info) # if entry already exists, update it projects_record = self.sql.getConvoMapFromProjects(course_name) if projects_record.data: project_response = self.sql.updateProjects(course_name, project_info) else: project_response = self.sql.insertProjectInfo(project_info) - print("Response from supabase: ", project_response) + logging.info("Response from supabase: ", project_response) # rebuild the map self.rebuild_map(course_name, "conversation") return "success" except Exception as e: - print(e) + logging.info(e) self.sentry.capture_exception(e) return "Error in creating conversation map:" + str(e) @@ -377,7 +378,7 @@ def rebuild_map(self, course_name: str, map_type: str): """ This function rebuilds a given map in Nomic. """ - print("in rebuild_map()") + logging.info("in rebuild_map()") nomic.login(os.getenv('NOMIC_API_KEY')) if map_type.lower() == 'document': @@ -394,7 +395,7 @@ def rebuild_map(self, course_name: str, map_type: str): project.rebuild_maps() return "success" except Exception as e: - print(e) + logging.info(e) self.sentry.capture_exception(e) return "Error in rebuilding map: {e}" @@ -410,7 +411,7 @@ def create_map(self, embeddings, metadata, map_name, index_name, topic_label_fie colorable_fields: list of str """ nomic.login(os.environ['NOMIC_API_KEY']) - print("in create_map()") + logging.info("in create_map()") try: project = atlas.map_embeddings(embeddings=embeddings, data=metadata, @@ -423,7 +424,7 @@ def create_map(self, embeddings, metadata, map_name, index_name, topic_label_fie project.create_index(index_name, build_topic_model=True) return "success" except Exception as e: - print(e) + logging.info(e) return "Error in creating map: {e}" def append_to_map(self, embeddings, metadata, map_name): @@ -441,7 +442,7 @@ def append_to_map(self, embeddings, metadata, map_name): project.add_embeddings(embeddings=embeddings, data=metadata) return "success" except Exception as e: - print(e) + logging.info(e) return "Error in appending to map: {e}" def data_prep_for_convo_map(self, df: pd.DataFrame): @@ -453,7 +454,7 @@ def data_prep_for_convo_map(self, df: pd.DataFrame): embeddings: np.array of embeddings metadata: pd.DataFrame of metadata """ - print("in data_prep_for_convo_map()") + logging.info("in data_prep_for_convo_map()") try: metadata = [] @@ -478,7 +479,7 @@ def data_prep_for_convo_map(self, df: pd.DataFrame): if isinstance(messages[0]['content'], list): if 'text' in messages[0]['content'][0]: first_message = messages[0]['content'][0]['text'] - #print("First message:", first_message) + #logging.info("First message:", first_message) else: first_message = messages[0]['content'] user_queries.append(first_message) @@ -510,7 +511,7 @@ def data_prep_for_convo_map(self, df: pd.DataFrame): "created_at": created_at, "modified_at": current_time } - #print("Metadata row:", meta_row) + #logging.info("Metadata row:", meta_row) metadata.append(meta_row) embeddings_model = OpenAIEmbeddings(openai_api_type="openai", @@ -521,12 +522,12 @@ def data_prep_for_convo_map(self, df: pd.DataFrame): metadata = pd.DataFrame(metadata) embeddings = np.array(embeddings) - print("Metadata shape:", metadata.shape) - print("Embeddings shape:", embeddings.shape) + logging.info("Metadata shape:", metadata.shape) + logging.info("Embeddings 
shape:", embeddings.shape) return embeddings, metadata except Exception as e: - print("Error in data_prep_for_convo_map():", e) + logging.info("Error in data_prep_for_convo_map():", e) self.sentry.capture_exception(e) return None, None @@ -538,18 +539,18 @@ def delete_from_document_map(self, project_id: str, ids: list): course_name: str ids: list of str """ - print("in delete_from_document_map()") + logging.info("in delete_from_document_map()") try: # fetch project from Nomic project = AtlasProject(project_id=project_id, add_datums_if_exists=True) # delete the ids from Nomic - print("Deleting point from document map:", project.delete_data(ids)) + logging.info("Deleting point from document map:", project.delete_data(ids)) with project.wait_for_project_lock(): project.rebuild_maps() return "Successfully deleted from Nomic map" except Exception as e: - print(e) + logging.info(e) self.sentry.capture_exception(e) return "Error in deleting from document map: {e}" diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index f660b037..17888361 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -1,4 +1,5 @@ import inspect +import logging import os import time import traceback @@ -34,6 +35,8 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLAlchemyDatabase, aws: AWSStora self.posthog = posthog self.nomicService = nomicService + logging.info(f"Vector DB: {self.vdb}") + openai.api_key = os.environ["OPENAI_API_KEY"] self.embeddings = OpenAIEmbeddings( @@ -86,7 +89,7 @@ def getTopContexts(self, doc_string = f"Document: {doc.metadata['readable_filename']}{', page: ' + str(doc.metadata['pagenumber']) if doc.metadata['pagenumber'] else ''}\n{str(doc.page_content)}\n" num_tokens, prompt_cost = count_tokens_and_cost(doc_string) # type: ignore - print( + logging.info( f"tokens used/limit: {token_counter}/{token_limit}, tokens in chunk: {num_tokens}, total prompt cost (of these contexts): {prompt_cost}. 📄 File: {doc.metadata['readable_filename']}" ) if token_counter + num_tokens <= token_limit: @@ -96,9 +99,9 @@ def getTopContexts(self, # filled our token size, time to return break - print(f"Total tokens used: {token_counter}. Docs used: {len(valid_docs)} of {len(found_docs)} docs retrieved") - print(f"Course: {course_name} ||| search_query: {search_query}") - print(f"⏰ ^^ Runtime of getTopContexts: {(time.monotonic() - start_time_overall):.2f} seconds") + logging.info(f"Total tokens used: {token_counter}. Docs used: {len(valid_docs)} of {len(found_docs)} docs retrieved") + logging.info(f"Course: {course_name} ||| search_query: {search_query}") + logging.info(f"⏰ ^^ Runtime of getTopContexts: {(time.monotonic() - start_time_overall):.2f} seconds") if len(valid_docs) == 0: return [] @@ -120,9 +123,9 @@ def getTopContexts(self, except Exception as e: # return full traceback to front end # err: str = f"ERROR: In /getTopContexts. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:\n{e}" # type: ignore - err: str = f"ERROR: In /getTopContexts. Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.print_exc} \n{e}" # type: ignore + err: str = f"ERROR: In /getTopContexts. 
Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.print_exc()} \n{e}" # type: ignore traceback.print_exc() - print(err) + logging.info(err) if self.sentry is not None: self.sentry.capture_exception(e) return err @@ -153,7 +156,7 @@ def getAll( def delete_data(self, course_name: str, s3_path: str, source_url: str): """Delete file from S3, Qdrant, and Supabase.""" - print(f"Deleting data for course {course_name}") + logging.info(f"Deleting data for course {course_name}") # add delete from doc map logic here try: # Delete file from S3 @@ -162,7 +165,7 @@ def delete_data(self, course_name: str, s3_path: str, source_url: str): raise ValueError("S3_BUCKET_NAME environment variable is not set") identifier_key, identifier_value = ("s3_path", s3_path) if s3_path else ("url", source_url) - print(f"Deleting {identifier_value} from S3, Qdrant, and Supabase using {identifier_key}") + logging.info(f"Deleting {identifier_value} from S3, Qdrant, and Supabase using {identifier_key}") # Delete from S3 if identifier_key == "s3_path": @@ -177,32 +180,32 @@ def delete_data(self, course_name: str, s3_path: str, source_url: str): return "Success" except Exception as e: err: str = f"ERROR IN delete_data: Traceback: {traceback.extract_tb(e.__traceback__)}❌❌ Error in {inspect.currentframe().f_code.co_name}:{e}" # type: ignore - print(err) + logging.info(err) if self.sentry is not None: self.sentry.capture_exception(e) return err def delete_from_s3(self, bucket_name: str, s3_path: str): try: - print("Deleting from S3") + logging.info("Deleting from S3") response = self.aws.delete_file(bucket_name, s3_path) - print(f"AWS response: {response}") + logging.info(f"AWS response: {response}") except Exception as e: - print("Error in deleting file from s3:", e) + logging.info("Error in deleting file from s3:", e) if self.sentry is not None: self.sentry.capture_exception(e) def delete_from_qdrant(self, identifier_key: str, identifier_value: str): try: - print("Deleting from Qdrant") + logging.info("Deleting from Qdrant") response = self.vdb.delete_data(os.environ['QDRANT_COLLECTION_NAME'], identifier_key, identifier_value) - print(f"Qdrant response: {response}") + logging.info(f"Qdrant response: {response}") except Exception as e: if "timed out" in str(e): # Timed out is fine. Still deletes. pass else: - print("Error in deleting file from Qdrant:", e) + logging.info("Error in deleting file from Qdrant:", e) if self.sentry is not None: self.sentry.capture_exception(e) @@ -229,7 +232,7 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit # ) # generated_queries = generate_queries.invoke({"original_query": search_query}) - # print("generated_queries", generated_queries) + # logging.info("generated_queries", generated_queries) # # 2. VECTOR SEARCH FOR EACH QUERY # batch_found_docs_nested: list[list[Document]] = self.batch_vector_search(search_queries=generated_queries, @@ -239,10 +242,10 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit # # 3. RANK REMAINING DOCUMENTS -- good for parent doc padding of top 5 at the end. 
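For reference, the reciprocal rank fusion used in step 3 merges the per-query result lists by giving each document a score of 1 / (k + rank) in every list that contains it and summing across lists. A minimal, self-contained sketch over plain document ids, independent of the commented-out implementation here:

from collections import defaultdict

def reciprocal_rank_fusion(result_lists, k: int = 60):
    """Fuse several ranked lists of doc ids into one list scored by sum(1 / (k + rank)).

    result_lists is a list of ranked lists (best first); k=60 is the value
    commonly used in the RRF literature. Returns (doc_id, score) pairs, best first.
    """
    scores = defaultdict(float)
    for ranked in result_lists:
        for rank, doc_id in enumerate(ranked):
            scores[doc_id] += 1.0 / (k + rank)
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)

# e.g. fusing the results of three generated queries:
# reciprocal_rank_fusion([["a", "b", "c"], ["b", "a"], ["c", "b"]])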
# found_docs = self.reciprocal_rank_fusion(batch_found_docs_nested) # found_docs = [doc for doc, score in found_docs] - # print(f"Num docs after re-ranking: {len(found_docs)}") + # logging.info(f"Num docs after re-ranking: {len(found_docs)}") # if len(found_docs) == 0: # return [] - # print(f"⏰ Total multi-query processing runtime: {(time.monotonic() - mq_start_time):.2f} seconds") + # logging.info(f"⏰ Total multi-query processing runtime: {(time.monotonic() - mq_start_time):.2f} seconds") # # 4. FILTER DOCS # filtered_docs = filter_top_contexts(contexts=found_docs, user_query=search_query, timeout=30, max_concurrency=180) @@ -251,7 +254,7 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit # # 5. TOP DOC CONTEXT PADDING // parent document retriever # final_docs = context_parent_doc_padding(filtered_docs, search_query, course_name) - # print(f"Number of final docs after context padding: {len(final_docs)}") + # logging.info(f"Number of final docs after context padding: {len(final_docs)}") # pre_prompt = "Please answer the following question. Use the context below, called your documents, only if it's helpful and don't use parts that are very irrelevant. It's good to quote from your documents directly, when you do always use Markdown footnotes for citations. Use react-markdown superscript to number the sources at the end of sentences (1, 2, 3...) and use react-markdown Footnotes to list the full document names for each number. Use ReactMarkdown aka 'react-markdown' formatting for super script citations, use semi-formal style. Feel free to say you don't know. \nHere's a few passages of the high quality documents:\n" # token_counter, _ = count_tokens_and_cost(pre_prompt + '\n\nNow please respond to my query: ' + @@ -263,7 +266,7 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit # doc_string = f"Document: {doc['readable_filename']}{', page: ' + str(doc['pagenumber']) if doc['pagenumber'] else ''}\n{str(doc['text'])}\n" # num_tokens, prompt_cost = count_tokens_and_cost(doc_string) # type: ignore - # print(f"token_counter: {token_counter}, num_tokens: {num_tokens}, max_tokens: {token_limit}") + # logging.info(f"token_counter: {token_counter}, num_tokens: {num_tokens}, max_tokens: {token_limit}") # if token_counter + num_tokens <= token_limit: # token_counter += num_tokens # valid_docs.append(doc) @@ -271,9 +274,9 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit # # filled our token size, time to return # break - # print(f"Total tokens used: {token_counter} Used {len(valid_docs)} of total unique docs {len(found_docs)}.") - # print(f"Course: {course_name} ||| search_query: {search_query}") - # print(f"⏰ ^^ Runtime of getTopContextsWithMQR: {(time.monotonic() - start_time_overall):.2f} seconds") + # logging.info(f"Total tokens used: {token_counter} Used {len(valid_docs)} of total unique docs {len(found_docs)}.") + # logging.info(f"Course: {course_name} ||| search_query: {search_query}") + # logging.info(f"⏰ ^^ Runtime of getTopContextsWithMQR: {(time.monotonic() - start_time_overall):.2f} seconds") # if len(valid_docs) == 0: # return [] @@ -293,7 +296,7 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit # except Exception as e: # # return full traceback to front end # err: str = f"ERROR: In /getTopContextsWithMQR. 
Course: {course_name} ||| search_query: {search_query}\nTraceback: {traceback.format_exc()}❌❌ Error in {inspect.currentframe().f_code.co_name}:\n{e}" # type: ignore - # print(err) + # logging.info(err) # sentry_sdk.capture_exception(e) # return err @@ -303,7 +306,7 @@ def format_for_json_mqr(self, found_docs) -> List[Dict]: """ for found_doc in found_docs: if "pagenumber" not in found_doc.keys(): - print("found no pagenumber") + logging.info("found no pagenumber") found_doc['pagenumber'] = found_doc['pagenumber_or_timestamp'] contexts = [ @@ -322,7 +325,7 @@ def format_for_json_mqr(self, found_docs) -> List[Dict]: def delete_from_nomic_and_supabase(self, course_name: str, identifier_key: str, identifier_value: str): try: - print(f"Nomic delete. Course: {course_name} using {identifier_key}: {identifier_value}") + logging.info(f"Nomic delete. Course: {course_name} using {identifier_key}: {identifier_value}") response = self.sqlDb.getMaterialsForCourseAndKeyAndValue(course_name, identifier_key, identifier_value) data = response.data if not data: @@ -340,15 +343,15 @@ def delete_from_nomic_and_supabase(self, course_name: str, identifier_key: str, if self.nomicService is not None: self.nomicService.delete_from_document_map(project_id, nomic_ids_to_delete) except Exception as e: - print(f"Nomic Error in deleting. {identifier_key}: {identifier_value}", e) + logging.info(f"Nomic Error in deleting. {identifier_key}: {identifier_value}", e) if self.sentry is not None: self.sentry.capture_exception(e) try: - print(f"Supabase Delete. course: {course_name} using {identifier_key}: {identifier_value}") + logging.info(f"Supabase Delete. course: {course_name} using {identifier_key}: {identifier_value}") response = self.sqlDb.deleteMaterialsForCourseAndKeyAndValue(course_name, identifier_key, identifier_value) except Exception as e: - print(f"Supabase Error in delete. {identifier_key}: {identifier_value}", e) + logging.info(f"Supabase Error in delete. {identifier_key}: {identifier_value}", e) if self.sentry is not None: self.sentry.capture_exception(e) @@ -408,7 +411,7 @@ def _process_search_results(self, search_results, course_name): found_docs.append(Document(page_content=page_content, metadata=metadata)) except Exception as e: - print(f"Error in vector_search(), for course: `{course_name}`. Error: {e}") + logging.info(f"Error in vector_search(), for course: `{course_name}`. 
Error: {e}") if self.sentry is not None: self.sentry.capture_exception(e) return found_docs @@ -457,7 +460,7 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]: """ for found_doc in found_docs: if "pagenumber" not in found_doc.metadata.keys(): - print("found no pagenumber") + logging.info("found no pagenumber") found_doc.metadata["pagenumber"] = found_doc.metadata["pagenumber_or_timestamp"] contexts = [ diff --git a/ai_ta_backend/service/workflow_service.py b/ai_ta_backend/service/workflow_service.py index 42671707..4f63c92d 100644 --- a/ai_ta_backend/service/workflow_service.py +++ b/ai_ta_backend/service/workflow_service.py @@ -1,4 +1,5 @@ import json +import logging import os import time from urllib.parse import quote @@ -144,7 +145,7 @@ def format_data(self, inputted, api_key: str, workflow_name): new_data[data[k]] = v return new_data except Exception as e: - print("Error in format_data: ", e) + logging.info("Error in format_data: ", e) def switch_workflow(self, id, api_key: str = "", activate: 'str' = 'True'): if not api_key: @@ -161,7 +162,7 @@ def switch_workflow(self, id, api_key: str = "", activate: 'str' = 'True'): def main_flow(self, name: str, api_key: str = "", data: str = ""): if not api_key: raise ValueError('api_key is required') - print("Starting") + logging.info("Starting") hookId = self.get_hook(name, api_key) hook = self.url + f"/form/{hookId}" @@ -175,22 +176,22 @@ def main_flow(self, name: str, api_key: str = "", data: str = ""): if len(ids) > 0: id = max(ids) + 1 - print("Execution found in supabase: ", id) + logging.info("Execution found in supabase: ", id) else: execution = self.get_executions(limit=1, api_key=api_key, pagination=False) - print("Got executions") + logging.info("Got executions") if execution: - print(execution) + logging.info(execution) id = int(execution[0]['id']) + 1 - print("Execution found through n8n: ", id) + logging.info("Execution found through n8n: ", id) else: raise Exception('No executions found') id = str(id) try: self.sqlDb.lockWorkflow(id) - print("inserted flow into supabase") + logging.info("inserted flow into supabase") self.execute_flow(hook, new_data) - print("Executed workflow") + logging.info("Executed workflow") except Exception as e: # TODO: Decrease number by one, is locked false # self.supabase_client.table('n8n_workflows').update({"latest_workflow_id": str(int(id) - 1), "is_locked": False}).eq('latest_workflow_id', id).execute() @@ -204,11 +205,11 @@ def main_flow(self, name: str, api_key: str = "", data: str = ""): executions = self.get_executions(20, id, True, api_key) while executions is None: executions = self.get_executions(20, id, True, api_key) - print("Can't find id in executions") + logging.info("Can't find id in executions") time.sleep(1) - print("Found id in executions ") + logging.info("Found id in executions ") self.sqlDb.deleteLatestWorkflowId(id) - print("Deleted id") + logging.info("Deleted id") except Exception as e: self.sqlDb.deleteLatestWorkflowId(id) return {"error": str(e)} diff --git a/ai_ta_backend/utils/context_parent_doc_padding.py b/ai_ta_backend/utils/context_parent_doc_padding.py index 25553248..e42aed5f 100644 --- a/ai_ta_backend/utils/context_parent_doc_padding.py +++ b/ai_ta_backend/utils/context_parent_doc_padding.py @@ -1,5 +1,6 @@ from concurrent.futures import ProcessPoolExecutor from functools import partial +import logging from multiprocessing import Manager import os import time @@ -13,7 +14,7 @@ def context_parent_doc_padding(found_docs, search_query, course_name): 
""" Takes top N contexts acquired from QRANT similarity search and pads them """ - print("inside main context padding") + logging.info("inside main context padding") start_time = time.monotonic() with Manager() as manager: @@ -33,7 +34,7 @@ def context_parent_doc_padding(found_docs, search_query, course_name): result_contexts = supabase_contexts_no_duplicates + list(qdrant_contexts) - print(f"⏰ Context padding runtime: {(time.monotonic() - start_time):.2f} seconds") + logging.info(f"⏰ Context padding runtime: {(time.monotonic() - start_time):.2f} seconds") return result_contexts @@ -80,10 +81,10 @@ def supabase_context_padding(doc, course_name, result_docs): # do the padding filename = data[0]['readable_filename'] contexts = data[0]['contexts'] - #print("no of contexts within the og doc: ", len(contexts)) + #logging.info("no of contexts within the og doc: ", len(contexts)) if 'chunk_index' in doc.metadata and 'chunk_index' in contexts[0].keys(): - #print("inside chunk index") + #logging.info("inside chunk index") # pad contexts by chunk index + 3 and - 3 target_chunk_index = doc.metadata['chunk_index'] for context in contexts: @@ -97,7 +98,7 @@ def supabase_context_padding(doc, course_name, result_docs): result_docs.append(context) elif doc.metadata['pagenumber'] != '': - #print("inside page number") + #logging.info("inside page number") # pad contexts belonging to same page number pagenumber = doc.metadata['pagenumber'] @@ -112,7 +113,7 @@ def supabase_context_padding(doc, course_name, result_docs): result_docs.append(context) else: - #print("inside else") + #logging.info("inside else") # refactor as a Supabase object and append context_dict = { 'text': doc.page_content, diff --git a/ai_ta_backend/utils/filtering_contexts.py b/ai_ta_backend/utils/filtering_contexts.py index 03deede0..83d502c1 100644 --- a/ai_ta_backend/utils/filtering_contexts.py +++ b/ai_ta_backend/utils/filtering_contexts.py @@ -45,7 +45,7 @@ # def filter_context(self, context, user_query, langsmith_prompt_obj): # final_prompt = str(langsmith_prompt_obj.format(context=context, user_query=user_query)) -# # print(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^") +# # logging.info(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^") # try: # # completion = run_caii_hosted_llm(final_prompt) # # completion = run_replicate(final_prompt) @@ -53,7 +53,7 @@ # return {"completion": completion, "context": context} # except Exception as e: # sentry_sdk.capture_exception(e) -# print(f"Error: {e}") +# logging.info(f"Error: {e}") # def run_caii_hosted_llm(prompt, max_tokens=300, temp=0.3, **kwargs): # """ @@ -87,7 +87,7 @@ # # "max_new_tokens": 250, # # "presence_penalty": 1 # # }) -# print(output) +# logging.info(output) # return output # def run_anyscale(prompt, model_name="HuggingFaceH4/zephyr-7b-beta"): @@ -110,12 +110,12 @@ # ) # output = ret["choices"][0]["message"]["content"] # type: ignore -# print("Response from Anyscale:", output[:150]) +# logging.info("Response from Anyscale:", output[:150]) # # input_length = len(tokenizer.encode(prompt)) # # output_length = len(tokenizer.encode(output)) # # Input tokens {input_length}, output tokens: {output_length}" -# print(f"^^^^ one anyscale call Runtime: {(time.monotonic() - start_time):.2f} seconds.") +# logging.info(f"^^^^ one anyscale call Runtime: {(time.monotonic() - start_time):.2f} seconds.") # return output # def parse_result(result: str): @@ -130,7 +130,7 @@ # timeout: Optional[float] = None, # max_concurrency: Optional[int] = 180): -# print("⏰⏰⏰ Starting 
filter_top_contexts() ⏰⏰⏰") +# logging.info("⏰⏰⏰ Starting filter_top_contexts() ⏰⏰⏰") # timeout = timeout or float(os.environ["FILTER_TOP_CONTEXTS_TIMEOUT_SECONDS"]) # # langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr") # TOO UNSTABLE, service offline @@ -138,8 +138,8 @@ # posthog = Posthog(sync_mode=True, project_api_key=os.environ['POSTHOG_API_KEY'], host='https://app.posthog.com') # max_concurrency = min(100, len(contexts)) -# print("max_concurrency is max of 100, or len(contexts), whichever is less ---- Max concurrency:", max_concurrency) -# print("Num contexts to filter:", len(contexts)) +# logging.info("max_concurrency is max of 100, or len(contexts), whichever is less ---- Max concurrency:", max_concurrency) +# logging.info("Num contexts to filter:", len(contexts)) # # START TASKS # actor = AsyncActor.options(max_concurrency=max_concurrency, num_cpus=0.001).remote() # type: ignore @@ -161,10 +161,10 @@ # r['context'] for r in results if r and 'context' in r and 'completion' in r and parse_result(r['completion']) # ] -# print("🧠🧠 TOTAL DOCS PROCESSED BY ANYSCALE FILTERING:", len(results)) -# print("🧠🧠 TOTAL DOCS KEPT, AFTER FILTERING:", len(best_contexts_to_keep)) +# logging.info("🧠🧠 TOTAL DOCS PROCESSED BY ANYSCALE FILTERING:", len(results)) +# logging.info("🧠🧠 TOTAL DOCS KEPT, AFTER FILTERING:", len(best_contexts_to_keep)) # mqr_runtime = round(time.monotonic() - start_time, 2) -# print(f"⏰ Total elapsed time: {mqr_runtime} seconds") +# logging.info(f"⏰ Total elapsed time: {mqr_runtime} seconds") # posthog.capture('distinct_id_of_the_user', # event='filter_top_contexts', @@ -182,9 +182,9 @@ # def run_main(): # start_time = time.monotonic() # # final_passage_list = filter_top_contexts(contexts=CONTEXTS * 2, user_query=USER_QUERY) -# # print("✅✅✅ TOTAL included in results: ", len(final_passage_list)) -# print(f"⏰⏰⏰ Runtime: {(time.monotonic() - start_time):.2f} seconds") -# # print("Total contexts:", len(CONTEXTS) * 2) +# # logging.info("✅✅✅ TOTAL included in results: ", len(final_passage_list)) +# logging.info(f"⏰⏰⏰ Runtime: {(time.monotonic() - start_time):.2f} seconds") +# # logging.info("Total contexts:", len(CONTEXTS) * 2) # # ! 
CONDA ENV: llm-serving # if __name__ == "__main__": diff --git a/ai_ta_backend/utils/utils_tokenization.py b/ai_ta_backend/utils/utils_tokenization.py index db766e50..3736b418 100644 --- a/ai_ta_backend/utils/utils_tokenization.py +++ b/ai_ta_backend/utils/utils_tokenization.py @@ -1,3 +1,4 @@ +import logging import os from typing import Any @@ -55,7 +56,7 @@ def count_tokens_and_cost(prompt: str, completion_token_cost = 0.0001 / 1_000 else: # no idea of cost - print(f"NO IDEA OF COST, pricing not supported for model model: `{openai_model_name}`") + logging.info(f"NO IDEA OF COST, pricing not supported for model model: `{openai_model_name}`") prompt_token_cost = 0 completion_token_cost = 0 @@ -89,7 +90,7 @@ def analyze_conversations(supabase_client: Any = None): supabase_key=os.getenv('SUPABASE_API_KEY')) # type: ignore # Get all conversations response = supabase_client.table('llm-convo-monitor').select('convo').execute() - # print("total entries", response.data.count) + # logging.info("total entries", response.data.count) total_convos = 0 total_messages = 0 @@ -100,10 +101,10 @@ def analyze_conversations(supabase_client: Any = None): # for convo in response['data']: for convo in response.data: total_convos += 1 - # print(convo) + # logging.info(convo) # prase json from convo # parse json into dict - # print(type(convo)) + # logging.info(type(convo)) # convo = json.loads(convo) convo = convo['convo'] messages = convo['messages'] @@ -121,13 +122,13 @@ def analyze_conversations(supabase_client: Any = None): if role == 'user': num_tokens, cost = count_tokens_and_cost(prompt=content, openai_model_name=model_name) total_prompt_cost += cost - print(f'User Prompt: {content}, Tokens: {num_tokens}, cost: {cost}') + logging.info(f'User Prompt: {content}, Tokens: {num_tokens}, cost: {cost}') # If the message is from the assistant, it's a completion elif role == 'assistant': num_tokens_completion, cost_completion = count_tokens_and_cost(prompt='', completion=content, openai_model_name=model_name) total_completion_cost += cost_completion - print(f'Assistant Completion: {content}\nTokens: {num_tokens_completion}, cost: {cost_completion}') + logging.info(f'Assistant Completion: {content}\nTokens: {num_tokens_completion}, cost: {cost_completion}') return total_convos, total_messages, total_prompt_cost, total_completion_cost @@ -135,7 +136,7 @@ def analyze_conversations(supabase_client: Any = None): pass # if __name__ == '__main__': -# print('starting main') +# logging.info('starting main') # total_convos, total_messages, total_prompt_cost, total_completion_cost = analyze_conversations() -# print(f'total_convos: {total_convos}, total_messages: {total_messages}') -# print(f'total_prompt_cost: {total_prompt_cost}, total_completion_cost: {total_completion_cost}') +# logging.info(f'total_convos: {total_convos}, total_messages: {total_messages}') +# logging.info(f'total_prompt_cost: {total_prompt_cost}, total_completion_cost: {total_completion_cost}') diff --git a/docker-compose.yaml b/docker-compose.yaml index 9a5686f0..441f3d67 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,4 +1,6 @@ -version: "3.8" +# version: 3.8 + +# TODO: Update the network connections to each service (Minio, Redis), e.g. 
use `http://qdrant:6333` instead of `http://localhost:6333` services: # sqlite: @@ -11,45 +13,76 @@ services: redis: image: redis:latest ports: - - "6379:6379" + - 6379:6379 + networks: + - my-network volumes: - redis-data:/data qdrant: - image: generall/qdrant:latest - volumes: - - qdrant-data:/qdrant/storage + image: qdrant/qdrant:v1.9.5 + restart: always + container_name: qdrant ports: - - "6333:6333" # HTTP API - - "6334:6334" # gRPC API + - 6333:6333 + - 6334:6334 + expose: + - 6333 + - 6334 + - 6335 + volumes: + - ./qdrant_data:/qdrant/storage + - ./qdrant_config.yaml:/qdrant/config/production.yaml # Mount the config file directly as a volume + networks: + - my-network + healthcheck: + test: + [ + CMD, + curl, + -f, + -H, + { Authorization: Bearer qd-SbvSWrYpa473J33yPjdL }, + http://localhost:6333/health, + ] + interval: 30s + timeout: 10s + retries: 3 minio: - image: minio/minio + image: minio/minio:RELEASE.2024-06-13T22-53-53Z environment: - MINIO_ROOT_USER: "minioadmin" # Customize access key - MINIO_ROOT_PASSWORD: "minioadmin" # Customize secret key + MINIO_ROOT_USER: minioadmin # Customize access key + MINIO_ROOT_PASSWORD: minioadmin # Customize secret key command: server /data ports: - - "9000:9000" # Console access - - "9001:9001" # API access + - 9000:9000 # Console access + - 9001:9001 # API access + networks: + - my-network volumes: - minio-data:/data flask_app: - # build: . # Directory with Dockerfile for Flask app - image: kastanday/ai-ta-backend:gunicorn + build: . # Directory with Dockerfile for Flask app + # image: kastanday/ai-ta-backend:gunicorn ports: - - "8000:8000" + - 8000:8000 volumes: - ./db:/usr/src/app/db # Mount local directory to store SQLite database + networks: + - my-network depends_on: - # - sqlite - - redis - qdrant + - redis - minio +# declare the network resource +# this will allow you to use service discovery and address a container by its name from within the network +networks: + my-network: {} + volumes: - # sqlite-data: - redis-data: - qdrant-data: - minio-data: + redis-data: {} + qdrant-data: {} + minio-data: {} diff --git a/qdrant_config.yaml b/qdrant_config.yaml new file mode 100644 index 00000000..3adc978a --- /dev/null +++ b/qdrant_config.yaml @@ -0,0 +1,204 @@ +debug: false +log_level: INFO + +storage: + # Where to store all the data + # KEY: use the default location, then map that using docker -v nvme/storage:/qdrant/storage + storage_path: /qdrant/storage + + # Where to store snapshots + snapshots_path: /qdrant/storage/snapshots + + # Optional setting. Specify where else to store temp files as default is ./storage. + # Route to another location on your system to reduce network disk use. + temp_path: /qdrant/storage/temp + + # If true - a point's payload will not be stored in memory. + # It will be read from the disk every time it is requested. + # This setting saves RAM by (slightly) increasing the response time. + # Note: those payload values that are involved in filtering and are indexed - remain in RAM. 
+ on_disk_payload: false + + # Write-ahead-log related configuration + wal: + # Size of a single WAL segment + wal_capacity_mb: 32 + + # Number of WAL segments to create ahead of actual data requirement + wal_segments_ahead: 0 + + # Normal node - receives all updates and answers all queries + node_type: Normal + + # Listener node - receives all updates, but does not answer search/read queries + # Useful for setting up a dedicated backup node + # node_type: "Listener" + + performance: + # Number of parallel threads used for search operations. If 0 - auto selection. + max_search_threads: 4 + # Max total number of threads, which can be used for running optimization processes across all collections. + # Note: Each optimization thread will also use `max_indexing_threads` for index building. + # So total number of threads used for optimization will be `max_optimization_threads * max_indexing_threads` + max_optimization_threads: 1 + + optimizers: + # The minimal fraction of deleted vectors in a segment, required to perform segment optimization + deleted_threshold: 0.2 + + # The minimal number of vectors in a segment, required to perform segment optimization + vacuum_min_vector_number: 1000 + + # Target amount of segments optimizer will try to keep. + # Real amount of segments may vary depending on multiple parameters: + # - Amount of stored points + # - Current write RPS + # + # It is recommended to select default number of segments as a factor of the number of search threads, + # so that each segment would be handled evenly by one of the threads. + # If `default_segment_number = 0`, will be automatically selected by the number of available CPUs + default_segment_number: 0 + + # Do not create segments larger this size (in KiloBytes). + # Large segments might require disproportionately long indexation times, + # therefore it makes sense to limit the size of segments. + # + # If indexation speed have more priority for your - make this parameter lower. + # If search speed is more important - make this parameter higher. + # Note: 1Kb = 1 vector of size 256 + # If not set, will be automatically selected considering the number of available CPUs. + max_segment_size_kb: null + + # Maximum size (in KiloBytes) of vectors to store in-memory per segment. + # Segments larger than this threshold will be stored as read-only memmaped file. + # To enable memmap storage, lower the threshold + # Note: 1Kb = 1 vector of size 256 + # To explicitly disable mmap optimization, set to `0`. + # If not set, will be disabled by default. + memmap_threshold_kb: null + + # Maximum size (in KiloBytes) of vectors allowed for plain index. + # Default value based on https://github.com/google-research/google-research/blob/master/scann/docs/algorithms.md + # Note: 1Kb = 1 vector of size 256 + # To explicitly disable vector indexing, set to `0`. + # If not set, the default value will be used. + indexing_threshold_kb: 20000 + + # Interval between forced flushes. + flush_interval_sec: 5 + + # Max number of threads, which can be used for optimization per collection. + # Note: Each optimization thread will also use `max_indexing_threads` for index building. + # So total number of threads used for optimization will be `max_optimization_threads * max_indexing_threads` + # If `max_optimization_threads = 0`, optimization will be disabled. + max_optimization_threads: 1 + + # Default parameters of HNSW Index. Could be overridden for each collection or named vector individually + hnsw_index: + # Number of edges per node in the index graph. 
Larger the value - more accurate the search, more space required. + m: 16 + # Number of neighbours to consider during the index building. Larger the value - more accurate the search, more time required to build index. + ef_construct: 100 + # Minimal size (in KiloBytes) of vectors for additional payload-based indexing. + # If payload chunk is smaller than `full_scan_threshold_kb` additional indexing won't be used - + # in this case full-scan search should be preferred by query planner and additional indexing is not required. + # Note: 1Kb = 1 vector of size 256 + full_scan_threshold_kb: 10000 + # Number of parallel threads used for background index building. If 0 - auto selection. + max_indexing_threads: 0 + # Store HNSW index on disk. If set to false, index will be stored in RAM. Default: false + on_disk: false + # Custom M param for hnsw graph built for payload index. If not set, default M will be used. + payload_m: null + +service: + # Maximum size of POST data in a single request in megabytes + max_request_size_mb: 32 + + # Number of parallel workers used for serving the api. If 0 - equal to the number of available cores. + # If missing - Same as storage.max_search_threads + max_workers: 0 + + # Host to bind the service on + host: 0.0.0.0 + + # HTTP(S) port to bind the service on + http_port: 6333 + + # gRPC port to bind the service on. + # If `null` - gRPC is disabled. Default: null + grpc_port: 6334 + # Uncomment to enable gRPC: + # grpc_port: 6334 + + # Enable CORS headers in REST API. + # If enabled, browsers would be allowed to query REST endpoints regardless of query origin. + # More info: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS + # Default: true + enable_cors: true + + # Use HTTPS for the REST API + enable_tls: false + + # Check user HTTPS client certificate against CA file specified in tls config + verify_https_client_certificate: false + + # Set an api-key. + # If set, all requests must include a header with the api-key. + # example header: `api-key: ` + # + # If you enable this you should also enable TLS. + # (Either above or via an external service like nginx.) + # Sending an api-key over an unencrypted channel is insecure. + # + # Uncomment to enable. + api_key: qd-SbvSWrYpa473J33yPjdL + +cluster: + # Use `enabled: true` to run Qdrant in distributed deployment mode + enabled: false + + # Configuration of the inter-cluster communication + p2p: + # Port for internal communication between peers + port: 6335 + + # Use TLS for communication between peers + enable_tls: false + + # Configuration related to distributed consensus algorithm + consensus: + # How frequently peers should ping each other. + # Setting this parameter to lower value will allow consensus + # to detect disconnected nodes earlier, but too frequent + # tick period may create significant network and CPU overhead. + # We encourage you NOT to change this parameter unless you know what you are doing. + tick_period_ms: 100 + +# Set to true to prevent service from sending usage statistics to the developers. +# Read more: https://qdrant.tech/documentation/telemetry +telemetry_disabled: false + +# TLS configuration. +# Required if either service.enable_tls or cluster.p2p.enable_tls is true. +tls: + # Server certificate chain file + cert: ./tls/cert.pem + + # Server private key file + key: ./tls/key.pem + + # Certificate authority certificate file. + # This certificate will be used to validate the certificates + # presented by other nodes during inter-cluster communication. 
+ # + # If verify_https_client_certificate is true, it will verify + # HTTPS client certificate + # + # Required if cluster.p2p.enable_tls is true. + ca_cert: ./tls/cacert.pem + + # TTL, in seconds, to re-load certificate from disk. Useful for certificate rotations, + # Only works for HTTPS endpoints, gRPC endpoints (including intra-cluster communication) + # doesn't support certificate re-load + cert_ttl: 3600 From 5c751bc43a6d8e335bca600d5f8827ad0befc0a1 Mon Sep 17 00:00:00 2001 From: rohan-uiuc Date: Thu, 20 Jun 2024 17:34:05 -0500 Subject: [PATCH 10/11] Major revamp from a script -> endpoint w/ agent, need to fix the llm and model used in @tool --- ai_ta_backend/main.py | 19 +- ai_ta_backend/service/poi_agent_service_v2.py | 126 ++++++ ai_ta_backend/utils/agent_utils.py | 370 ++++++++++++++++++ requirements.txt | 10 +- 4 files changed, 513 insertions(+), 12 deletions(-) create mode 100644 ai_ta_backend/service/poi_agent_service_v2.py create mode 100644 ai_ta_backend/utils/agent_utils.py diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 38bbc12c..fd7bcbba 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -36,11 +36,10 @@ ) from ai_ta_backend.service.export_service import ExportService from ai_ta_backend.service.nomic_service import NomicService +from ai_ta_backend.service.poi_agent_service_v2 import POIAgentService from ai_ta_backend.service.posthog_service import PosthogService from ai_ta_backend.service.retrieval_service import RetrievalService from ai_ta_backend.service.sentry_service import SentryService -from ai_ta_backend.service.poi_agent_service import generate_response ## need to add langchain-community langchain-core langchain-openai to requirements.txt - # from ai_ta_backend.beam.nomic_logging import create_document_map from ai_ta_backend.service.workflow_service import WorkflowService from ai_ta_backend.extensions import db @@ -55,7 +54,7 @@ #app.config['SERVER_TIMEOUT'] = 1000 # seconds # load API keys from globally-availabe .env file -load_dotenv() +load_dotenv(override=True) @app.route('/') def index() -> Response: @@ -226,9 +225,9 @@ def createConversationMap(service: NomicService): @app.route('/query_sql_agent', methods=['POST']) -def query_sql_agent(): +def query_sql_agent(service: POIAgentService): data = request.get_json() - user_input = data.get('query') + user_input = data["query"] system_message = SystemMessage(content="you are a helpful assistant and need to provide answers in text format about the plants found in India. 
If the Question is not related to plants in India answer 'I do not have any information on this.'") if not user_input: @@ -237,7 +236,7 @@ def query_sql_agent(): try: user_01 = HumanMessage(content=user_input) inputs = {"messages": [system_message,user_01]} - response = generate_response(inputs) + response = service.run_workflow(inputs) return str(response), 200 except Exception as e: return jsonify({"error": str(e)}), 500 @@ -518,11 +517,13 @@ def configure(binder: Binder) -> None: app.config['SQLALCHEMY_DATABASE_URI'] = url db.init_app(app) db.create_all() - binder.bind(SQLAlchemyDatabase, to=db, scope=SingletonScope) + binder.bind(SQLAlchemyDatabase, to=SQLAlchemyDatabase(db), scope=SingletonScope) sql_bound = True break - # if os.getenv(POI_SQL_DB_NAME): # type: ignore - # binder.bind(SQLAlchemyDatabase, to=POISQLDatabase, scope=SingletonScope) + if os.getenv("POI_SQL_DB_NAME"): + logging.info(f"Binding to POI SQL database with URL: {os.getenv('POI_SQL_DB_NAME')}") + binder.bind(POISQLDatabase, to=POISQLDatabase(db), scope=SingletonScope) + binder.bind(POIAgentService, to=POIAgentService, scope=SingletonScope) # Conditionally bind databases based on the availability of their respective secrets if all(os.getenv(key) for key in ["QDRANT_URL", "QDRANT_API_KEY", "QDRANT_COLLECTION_NAME"]) or any(os.getenv(key) for key in ["PINECONE_API_KEY", "PINECONE_PROJECT_NAME"]): logging.info("Binding to Qdrant database") diff --git a/ai_ta_backend/service/poi_agent_service_v2.py b/ai_ta_backend/service/poi_agent_service_v2.py new file mode 100644 index 00000000..1a0e7b9c --- /dev/null +++ b/ai_ta_backend/service/poi_agent_service_v2.py @@ -0,0 +1,126 @@ +import os +from injector import inject +from langchain_openai import ChatOpenAI +from pydantic import BaseModel +from ai_ta_backend.database.poi_sql import POISQLDatabase +from langgraph.graph import StateGraph, END +from langchain_openai import ChatOpenAI +from langchain_community.utilities.sql_database import SQLDatabase + + +from langchain_openai import ChatOpenAI + + +from langchain.tools import tool, StructuredTool +from langgraph.prebuilt.tool_executor import ToolExecutor + + +from typing import TypedDict, Annotated, Sequence +import operator +from langchain_core.messages import BaseMessage + +from langgraph.prebuilt import ToolInvocation +import json +from langchain_core.messages import FunctionMessage +from langchain_core.utils.function_calling import convert_to_openai_function +from langgraph.graph import StateGraph, END + +from ai_ta_backend.utils.agent_utils import generate_response_agent, initalize_sql_agent +import traceback + +##### Setting up the Graph Nodes, Edges and message communication + +class AgentState(TypedDict): + messages: Annotated[Sequence[BaseMessage], operator.add] + +class POIInput(BaseModel): + input: str + +@tool("plants_sql_tool", return_direct=True, args_schema=POIInput) +def generate_sql_query(input:str) -> str: + """Given a query looks for the three most relevant SQL sample queries""" + user_question = input + llm = ChatOpenAI(model="gpt-4o", temperature=0) + ### DATABASE + db = SQLDatabase.from_uri(f"sqlite:///{os.environ['POI_SQL_DB_NAME']}") + sql_agent = initalize_sql_agent(llm, db) + response = generate_response_agent(sql_agent,user_question) + return response['output'] + +class POIAgentService: + @inject + def __init__(self, poi_sql_db: POISQLDatabase): + self.poi_sql_db = poi_sql_db + self.model = ChatOpenAI(model="gpt-4o", temperature=0) + # self.tools = 
[StructuredTool.from_function(self.generate_sql_query, name="Run SQL Query", args_schema=POIInput)] + self.tools = [generate_sql_query] + self.tool_executor = ToolExecutor(self.tools) + self.functions = [convert_to_openai_function(t) for t in self.tools] + self.model = self.model.bind_functions(self.functions) + self.workflow = self.initialize_workflow(self.model) + + + + # Define the function that determines whether to continue or not + def should_continue(self, state): + messages = state['messages'] + last_message = messages[-1] + # If there is no function call, then we finish + if "function_call" not in last_message.additional_kwargs: + return "end" + # Otherwise if there is, we continue + else: + return "continue" + + # Define the function that calls the model + def call_model(self, state): + messages = state['messages'] + response = self.model.invoke(messages) + # We return a list, because this will get added to the existing list + return {"messages": [response]} + + # Define the function to execute tools + def call_tool(self, state): + messages = state['messages'] + # Based on the continue condition + # we know the last message involves a function call + last_message = messages[-1] + # We construct an ToolInvocation from the function_call + action = ToolInvocation( + tool=last_message.additional_kwargs["function_call"]["name"], + tool_input=json.loads(last_message.additional_kwargs["function_call"]["arguments"]), + ) + print(f"The agent action is {action}") + # We call the tool_executor and get back a response + response = self.tool_executor.invoke(action) + print(f"The tool result is: {response}") + # We use the response to create a FunctionMessage + function_message = FunctionMessage(content=str(response), name=action.tool) + # We return a list, because this will get added to the existing list + return {"messages": [function_message]} + + def initialize_workflow(self, agent): + workflow = StateGraph(AgentState) + workflow.add_node("agent", self.call_model) + workflow.add_node("action", self.call_tool) + workflow.set_entry_point("agent") + workflow.add_conditional_edges( + "agent", + self.should_continue, + { + "continue": "action", + "end": END + } + ) + workflow.add_edge('action', 'agent') + return workflow.compile() + + def run_workflow(self, user_input): + #agent = initialize_agent() + try: + + output = self.workflow.invoke(user_input) + return output + except Exception as e: + traceback.print_exc() + return str(e) \ No newline at end of file diff --git a/ai_ta_backend/utils/agent_utils.py b/ai_ta_backend/utils/agent_utils.py new file mode 100644 index 00000000..5181c828 --- /dev/null +++ b/ai_ta_backend/utils/agent_utils.py @@ -0,0 +1,370 @@ +import json +# from ai_ta_backend.model.response import FunctionMessage, ToolInvocation +from dotenv import load_dotenv +from langchain_community.utilities.sql_database import SQLDatabase +from langchain_openai import ChatOpenAI, OpenAI +from langchain_community.agent_toolkits import create_sql_agent +from langchain_core.prompts import ( + ChatPromptTemplate, + FewShotPromptTemplate, + MessagesPlaceholder, + PromptTemplate, + SystemMessagePromptTemplate, +) +import os +import logging +from flask_sqlalchemy import SQLAlchemy + +from langchain_openai import ChatOpenAI + +from operator import itemgetter + +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import RunnablePassthrough +from langchain_community.tools.sql_database.tool import 
QuerySQLDataBaseTool +from langchain_community.agent_toolkits import create_sql_agent +from langchain.tools import BaseTool, StructuredTool, Tool, tool +import random +from langgraph.prebuilt.tool_executor import ToolExecutor +from langchain.tools.render import format_tool_to_openai_function + + +from typing import TypedDict, Annotated, Sequence +import operator +from langchain_core.messages import BaseMessage + +from langchain_core.agents import AgentFinish +from langgraph.prebuilt import ToolInvocation +import json +from langchain_core.messages import FunctionMessage +from langchain_community.utilities import SQLDatabase +from langchain_community.vectorstores import FAISS +from langchain_core.example_selectors import SemanticSimilarityExampleSelector +from langchain_openai import OpenAIEmbeddings + + + +def get_dynamic_prompt_template(): + + examples = [ + { + "input": "How many accepted names are only distributed in Karnataka?", + "query": 'SELECT COUNT(*) as unique_pairs_count FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "Karnataka%"));' + }, + { + "input": "How many names were authored by Roxb?", + "query": 'SELECT COUNT(*) as unique_pairs_count FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Author_Name" LIKE "%Roxb%" AND "Record_Type_Code" IN ("AN", "SN"));' + }, + { + "input": "How many species have distributions in Myanmar, Meghalaya and Andhra Pradesh?", + "query": 'SELECT COUNT(*) FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Meghalaya%" AND "Additional_Details_2" LIKE "%Andhra Pradesh%"));' + }, + { + "input": "List the accepted names common to Myanmar, Meghalaya, Odisha, Andhra Pradesh.", + "query": 'SELECT DISTINCT Scientific_Name FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Meghalaya%" AND "Additional_Details_2" LIKE "%Odisha%" AND "Additional_Details_2" LIKE "%Andhra Pradesh%"));' + }, + { + "input": "List the accepted names that represent 'tree'.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "HB" AND "Additional_Details_2" LIKE "%tree%");' + }, + { + "input": "List the accepted names linked with Endemic tag.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND "Additional_Details_2" LIKE "%Endemic%");' + }, + { + "input": "List the accepted names published in Fl. Brit. India [J. D. Hooker].", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" in ("AN", "SN") AND ("Publication" LIKE "%Fl. Brit. India [J. D. Hooker]%" OR "Publication" LIKE "%[J. D. Hooker]%" OR "Publication" LIKE "%Fl. Brit. 
India%");' + }, + { + "input": "How many accepted names have ‘Silhet’/ ‘Sylhet’ in their Type?", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "TY" AND ("Additional_Details_2" LIKE "%Silhet%" OR "Additional_Details_2" LIKE "%Sylhet%"));' + }, + { + "input": "How many species were distributed in Sikkim and Meghalaya?", + "query": 'SELECT COUNT(*) AS unique_pairs FROM (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Sikkim%" AND Additional_Details_2 LIKE "%Meghalaya%"));' + }, + { + "input": "List the accepted names common to Kerala, Tamil Nadu, Andhra Pradesh, Karnataka, Maharashtra, Odisha, Meghalaya and Myanmar.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Meghalaya%" AND "Additional_Details_2" LIKE "%Odisha%" AND "Additional_Details_2" LIKE "%Andhra Pradesh%" AND "Additional_Details_2" LIKE "%Kerala%" AND "Additional_Details_2" LIKE "%Tamil Nadu%" AND "Additional_Details_2" LIKE "%Karnataka%" AND "Additional_Details_2" LIKE "%Maharashtra%"));' + }, + { + "input": "List the accepted names common to Europe, Afghanistan, Jammu & Kashmir, Himachal, Nepal, Sikkim, Bhutan, Arunachal Pradesh and China.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Europe%" AND Additional_Details_2 LIKE "%Afghanistan%" AND "Additional_Details_2" LIKE "%Jammu & Kashmir%" AND "Additional_Details_2" LIKE "%Himachal%" AND "Additional_Details_2" LIKE "%Nepal%" AND "Additional_Details_2" LIKE "%Sikkim%" AND "Additional_Details_2" LIKE "%Bhutan%" AND "Additional_Details_2" LIKE "%Arunachal Pradesh%" AND "Additional_Details_2" LIKE "%China%"));' + }, + { + "input": "List the accepted names common to Europe, Afghanistan, Austria, Belgium, Czechoslovakia, Denmark, France, Greece, Hungary, Italy, Moldava, Netherlands, Poland, Romania, Spain, Switzerland, Jammu & Kashmir, Himachal, Nepal, and China.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Europe%" AND Additional_Details_2 LIKE "%Afghanistan%" AND "Additional_Details_2" LIKE "%Jammu & Kashmir%" AND "Additional_Details_2" LIKE "%Himachal%" AND "Additional_Details_2" LIKE "%Nepal%" AND "Additional_Details_2" LIKE "%Austria%" AND "Additional_Details_2" LIKE "%Belgium%" AND "Additional_Details_2" LIKE "%Czechoslovakia%" AND "Additional_Details_2" LIKE "%China%" AND "Additional_Details_2" LIKE "%Denmark%" AND "Additional_Details_2" LIKE "%Greece%" AND "Additional_Details_2" LIKE "%France%" AND "Additional_Details_2" LIKE "%Hungary%" AND "Additional_Details_2" LIKE "%Italy%" AND "Additional_Details_2" LIKE "%Moldava%" AND "Additional_Details_2" LIKE 
"%Netherlands%" AND "Additional_Details_2" LIKE "%Poland%" AND "Additional_Details_2" LIKE "%Poland%" AND "Additional_Details_2" LIKE "%Romania%" AND "Additional_Details_2" LIKE "%Spain%" AND "Additional_Details_2" LIKE "%Switzerland%"));' + }, + { + "input": "List the species which are distributed in Sikkim and Meghalaya.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%Sikkim%" AND Additional_Details_2 LIKE "%Meghalaya%"));' + }, + { + "input": "How many species are common to America, Europe, Africa, Asia, and Australia?", + "query": 'SELECT COUNT(*) AS unique_pairs IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%America%" AND Additional_Details_2 LIKE "%Europe%" AND "Additional_Details_2" LIKE "%Africa%" AND "Additional_Details_2" LIKE "%Asia%" AND "Additional_Details_2" LIKE "%Australia%"));' + }, + { + "input": "List the species names common to India and Myanmar, Malaysia, Indonesia, and Australia.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number","Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND ("Additional_Details_2" LIKE "%India%" AND Additional_Details_2 LIKE "%Myanmar%" AND Additional_Details_2 LIKE "%Malaysia%" AND Additional_Details_2 LIKE "%Indonesia%" AND Additional_Details_2 LIKE "%Australia%"));' + }, + { + "input": "List all plants which are tagged as urban.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Urban" = "YES";' + }, + { + "input": "List all plants which are tagged as fruit.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Fruit" = "YES";' + }, + { + "input": "List all plants which are tagged as medicinal.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Medicinal" = "YES";' + }, + { + "input": "List all family names which are gymnosperms.", + "query": 'SELECT DISTINCT "Family_Name" FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Groups" = "Gymnosperms";' + }, + { + "input": "How many accepted names are tagged as angiosperms?", + "query": 'SELECT COUNT(DISTINCT "Scientific_Name") FROM plants WHERE "Record_Type_Code" IN ("AN", "SN") AND "Groups" = "Angiosperms";' + }, + { + "input": "How many accepted names belong to the 'Saxifraga' genus?", + "query": 'SELECT COUNT(DISTINCT "Scientific_Name") FROM plants WHERE "Genus_Name" = "Saxifraga";' + }, + { + "input": "List the accepted names tagged as 'perennial herb' or 'climber'.", + "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "HB" AND ("Additional_Details_2" LIKE "%perennial herb%" OR "Additional_Details_2" LIKE "%climber%"));' + }, + { + "input": "How many accepted names are native to South Africa?", + "query": 'SELECT COUNT(*) FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" 
FROM plants WHERE "Record_Type_Code" = "RE" AND "Additional_Details_2" LIKE "%native%" AND "Additional_Details_2" LIKE "%south%" AND "Additional_Details_2" LIKE "%africa%");'
+
+ },
+ {
+ "input": "List the accepted names which were introduced and naturalized.",
+ "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "RE" AND "Additional_Details_2" LIKE "%introduced%" AND "Additional_Details_2" LIKE "%naturalized%");'
+ },
+ {
+ "input": "List all ornamental plants.",
+ "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "RE" AND "Additional_Details_2" LIKE "%ornamental%");'
+ },
+ {
+ "input": "How many plants from the 'Leguminosae' family have an altitudinal range up to 1000 m?",
+ "query": 'SELECT COUNT(*) FROM plants WHERE "Record_Type_Code" = "AL" AND "Family_Name" = "Leguminosae" AND "Additional_Details_2" LIKE "%1000%";'
+ },
+ {
+ "input": "List the accepted names linked with the 'endemic' tag for Karnataka.",
+ "query": 'SELECT DISTINCT "Scientific_Name" FROM plants WHERE ("Family_Number", "Genus_Number", "Accepted_name_number") IN (SELECT DISTINCT "Family_Number", "Genus_Number", "Accepted_name_number" FROM plants WHERE "Record_Type_Code" = "DB" AND "Additional_Details_2" LIKE "%Endemic%" AND "Additional_Details_2" LIKE "%Karnataka%");'
+ },
+ {"input": "List all the accepted names under the family 'Gnetaceae'.",
+ "query": """
+SELECT DISTINCT "Scientific_Name" FROM plants
+JOIN (
+ SELECT "Family_Number", "Genus_Number", "Accepted_name_number"
+ FROM plants
+ WHERE "Family_Number" IN (
+ SELECT DISTINCT "Family_Number" FROM plants WHERE "Family_Name" = "Gnetaceae"
+ )
+) b
+ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number"
+WHERE plants."Record_Type_Code" = 'AN';
+"""},
+ {
+ "input": "List all the accepted species that are introduced.",
+ "query": """
+SELECT DISTINCT "Scientific_Name" FROM plants
+JOIN (
+ SELECT "Family_Number", "Genus_Number", "Accepted_name_number"
+ FROM plants
+ WHERE "Record_Type_Code" = 'RE' and "Additional_Details_2" LIKE '%cultivated%'
+) b
+ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number"
+WHERE plants."Record_Type_Code" = 'AN';
+""",
+ },
+ {
+ "input": "List all the accepted names with type 'Cycad'",
+ "query": """
+SELECT DISTINCT "Scientific_Name" FROM plants
+JOIN (
+ SELECT "Family_Number", "Genus_Number", "Accepted_name_number"
+ FROM plants
+ WHERE "Record_Type_Code" = 'HB' and "Additional_Details_2" LIKE '%Cycad%'
+
+) b
+ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number"
+WHERE plants."Record_Type_Code" = 'AN';
+""",
+ },
+ {
+ "input": "List all the accepted names under the genus 'Cycas' with more than two synonyms.",
+ "query": """
+SELECT DISTINCT "Scientific_Name" FROM plants
+JOIN (
+ SELECT "Family_Number", "Genus_Number", "Accepted_name_number"
+ FROM plants
+ WHERE "Genus_Number" IN (
+ SELECT DISTINCT "Genus_Number" FROM plants WHERE "Genus_Name" = 'Cycas'
+ )
+ AND 
"Family_Number" IN ( + SELECT DISTINCT "Family_Number" FROM plants WHERE "Genus_Name" = 'Cycas' + ) + AND "Synonym_Number" > 2 +) b +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; +""", + }, + { + "input":'List all the accepted names published in Asian J. Conservation Biol.', + "query": """ + SELECT DISTINCT "Scientific_Name" + FROM plants + WHERE "Record_Type_Code" = 'AN' AND "Publication" LIKE '%Asian J. Conservation Biol%'; + +""", + }, + { + "input": 'List all the accepted names linked with endemic tag.', + "query": """ +SELECT DISTINCT "Scientific_Name" FROM plants +JOIN ( + SELECT "Family_Number", "Genus_Number", "Accepted_name_number" + FROM plants + WHERE "Record_Type_Code" = 'DB'and "Additional_Details_2" LIKE '%Endemic%' + +) b +ON plants."Genus_Number" = b."Genus_Number" AND plants."Accepted_name_number" = b."Accepted_name_number" AND plants."Family_Number" = b."Family_Number" +WHERE plants."Record_Type_Code" = 'AN'; +""", + }, + { + "input": 'List all the accepted names that have no synonyms.' , + "query": """ +SELECT DISTINCT a."Scientific_Name" FROM plants a +group by a."Family_Number",a."Genus_Number",a."Accepted_name_number" +HAVING SUM(a."Synonym_Number") = 0 AND a."Accepted_name_number" > 0; +""", + }, + { + "input": 'List all the accepted names authored by Roxb.', + "query": """ +SELECT "Scientific_Name" +FROM plants +WHERE "Record_Type_Code" = 'AN'AND "Author_Name" LIKE '%Roxb%'; +""", + }, + { + "input": 'List all genera within each family', + "query": """ +SELECT "Family_Name", "Genus_Name" +FROM plants +WHERE "Record_Type_Code" = 'GE'; +""", + }, + { + "input": 'Did Minq. discovered Cycas ryumphii?', + "query": """SELECT + CASE + WHEN EXISTS ( + SELECT 1 + FROM plants as a + WHERE a."Scientific_Name" = 'Cycas rumphii' + AND a."Author_Name" = 'Miq.' + ) THEN 'TRUE' + ELSE 'FALSE' + END AS ExistsCheck; +"""}, + + ] + + + example_selector = SemanticSimilarityExampleSelector.from_examples( + examples, + OpenAIEmbeddings(), + FAISS, + k=5, + input_keys=["input"], + ) + + + prefix_prompt = """ + You are an agent designed to interact with a SQL database. + Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer. + You can order the results by a relevant column to return the most interesting examples in the database. + Never query for all the columns from a specific table, only ask for the relevant columns given the question. + You have access to tools for interacting with the database. + Only use the given tools. Only use the information returned by the tools to construct your final answer. + You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. + + - Restrict your queries to the "plants" table. + - Do not return more than {top_k} rows unless specified otherwise. + - Add a limit of 25 at the end of SQL query. + - If the SQLite query returns zero rows, return a message indicating the same. + - Only refer to the data contained in the {table_info} table. Do not fabricate any data. + - For filtering based on string comparison, always use the LIKE operator and enclose the string in `%`. + - Queries on the `Additional_Details_2` column should use sub-queries involving `Family_Number`, `Genus_Number` and `Accepted_name_number`. 
+ + Refer to the table description below for more details on the columns: + 1. **Record_Type_Code**: Contains text codes indicating the type of information in the row. + - FA: Family Name, Genus Name, Scientific Name + - TY: Type + - GE: Genus Name + - AN: Family Name (Accepted Name), Genus Name, Scientific Name, Author Name, Publication, Volume:Page, Year of Publication + - HB: Habit + - DB: Distribution/location of the plant + - RE: Remarks + - SN: Family Name (Synonym), Genus Name, Scientific Name, Author Name, Publication, Volume:Page, Year of Publication + 2. **Family_Name**: Contains the Family Name of the plant. + 3. **Genus_Name**: Contains the Genus Name of the plant. + 4. **Scientific_Name**: Contains the Scientific Name of the plant species. + 5. **Publication_Name**: Name of the journal or book where the plant discovery information is published. Use LIKE for queries. + 6. **Volume:_Page**: The volume and page number of the publication. + 7. **Year_of_Publication**: The year in which the plant information was published. + 8. **Author_Name**: May contain multiple authors separated by `&`. Use LIKE for queries. + 9. **Additional_Details**: Contains type, habit, distribution, and remarks. Use LIKE for queries. + - Type: General location information. + - Remarks: Location information about cultivation or native area. + - Distribution: Locations where the plant is common. May contain multiple locations, use LIKE for queries. + 10. **Groups**: Contains either "Gymnosperms" or "Angiosperms". + 11. **Urban**: Contains either "YES" or "NO". Specifies whether the plant is urban. + 12. **Fruit**: Contains either "YES" or "NO". Specifies whether the plant is a fruit plant. + 13. **Medicinal**: Contains either "YES" or "NO". Specifies whether the plant is medicinal. + 14. **Genus_Number**: Contains the Genus Number of the plant. + 15. **Accepted_name_number**: Contains the Accepted Name Number of the plant. + + Below are examples of questions and their corresponding SQL queries. 
+ """ + + + + agent_prompt = PromptTemplate.from_template("User input: {input}\nSQL Query: {query}") + agent_prompt_obj = FewShotPromptTemplate( + example_selector=example_selector, + example_prompt=agent_prompt, + prefix=prefix_prompt, + suffix="", + input_variables=["input"], + ) + + full_prompt = ChatPromptTemplate.from_messages( + [ + SystemMessagePromptTemplate(prompt=agent_prompt_obj), + ("human", "{input}"), + MessagesPlaceholder("agent_scratchpad"), + ] + ) + return full_prompt + +def initalize_sql_agent(llm, db): + + dynamic_few_shot_prompt = get_dynamic_prompt_template() + + agent = create_sql_agent(llm, db=db, prompt=dynamic_few_shot_prompt, agent_type="openai-tools", verbose=True) + + return agent + +def generate_response_agent(agent,user_question): + response = agent.invoke({"input": user_question}) + return response \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index eed670aa..dfb67bb0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,16 +19,20 @@ typing-inspect==0.9.0 typing_extensions==4.8.0 # Utils -tiktoken==0.5.1 +tiktoken==0.7.0 python-dotenv==1.0.0 pydantic==1.10.13 # pydantic v1 works better for ray flask-executor==1.0.0 # AI & core services nomic==2.0.14 -openai==0.28.1 -langchain==0.0.331 +openai==1.31.2 +langchain==0.2.2 langchainhub==0.1.14 +langgraph==0.0.69 +faiss-cpu==1.8.0 +langchain-community==0.2.3 +langchain-openai==0.1.8 # Data boto3==1.28.79 From c9d12c016cd2d460c9f893fef88345d598f862cf Mon Sep 17 00:00:00 2001 From: Kastan Day Date: Tue, 25 Jun 2024 11:11:11 -0700 Subject: [PATCH 11/11] Minor cleanup docker-compose --- docker-compose.yaml | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 441f3d67..c4d3c68d 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,15 +1,4 @@ -# version: 3.8 - -# TODO: Update the network connections to each service (Minio, Redis), e.g. use `http://qdrant:6333` instead of `http://localhost:6333` - services: - # sqlite: - # image: nouchka/sqlite3 - # volumes: - # - sqlite-data:/root/db - # - ./init-scripts:/docker-entrypoint-initdb.d # Mount initialization scripts - # command: [ "sqlite3", "/root/db/sqlite.db", "-cmd", ".tables" ] - redis: image: redis:latest ports: