From 5951fe43fc477f8e2fce87965f37a57ef88dc9c9 Mon Sep 17 00:00:00 2001
From: rohanmarwaha
Date: Thu, 7 Mar 2024 19:15:00 -0600
Subject: [PATCH] Update OpenAI API type to be fetched from environment variable

---
 ai_ta_backend/service/nomic_service.py     | 4 ++--
 ai_ta_backend/service/retrieval_service.py | 8 +++-----
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/ai_ta_backend/service/nomic_service.py b/ai_ta_backend/service/nomic_service.py
index feae1473..f9d33a59 100644
--- a/ai_ta_backend/service/nomic_service.py
+++ b/ai_ta_backend/service/nomic_service.py
@@ -196,7 +196,7 @@ def log_convo_to_nomic(self, course_name: str, conversation) -> str | None:
     }]

     # create embeddings
-    embeddings_model = OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE)  # type: ignore
+    embeddings_model = OpenAIEmbeddings(openai_api_type=os.environ['OPENAI_API_TYPE'])
     embeddings = embeddings_model.embed_documents(user_queries)

     # add embeddings to the project - create a new function for this
@@ -380,7 +380,7 @@ def create_nomic_map(self, course_name: str, log_data: list):
       metadata.append(metadata_row)

     metadata = pd.DataFrame(metadata)
-    embeddings_model = OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE)  # type: ignore
+    embeddings_model = OpenAIEmbeddings(openai_api_type=os.environ['OPENAI_API_TYPE'])
     embeddings = embeddings_model.embed_documents(user_queries)

     # create Atlas project
diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py
index 9d39146c..e51123f5 100644
--- a/ai_ta_backend/service/retrieval_service.py
+++ b/ai_ta_backend/service/retrieval_service.py
@@ -18,8 +18,6 @@
 from ai_ta_backend.service.sentry_service import SentryService
 from ai_ta_backend.utils.utils_tokenization import count_tokens_and_cost

-OPENAI_API_TYPE = "azure"  # "openai" or "azure"
-

 class RetrievalService:
   """
@@ -41,7 +39,7 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, pos
     self.embeddings = OpenAIEmbeddings(
         model='text-embedding-ada-002',
         openai_api_base=os.getenv("AZURE_OPENAI_ENDPOINT"),  # type:ignore
-        openai_api_type=OPENAI_API_TYPE,
+        openai_api_type=os.environ['OPENAI_API_TYPE'],
         openai_api_key=os.getenv("AZURE_OPENAI_KEY"),  # type:ignore
         openai_api_version=os.getenv("OPENAI_API_VERSION"),  # type:ignore
     )
@@ -52,7 +50,7 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, pos
         openai_api_base=os.getenv("AZURE_OPENAI_ENDPOINT"),  # type:ignore
         openai_api_key=os.getenv("AZURE_OPENAI_KEY"),  # type:ignore
         openai_api_version=os.getenv("OPENAI_API_VERSION"),  # type:ignore
-        openai_api_type=OPENAI_API_TYPE,
+        openai_api_type=os.environ['OPENAI_API_TYPE'],
     )

   def getTopContexts(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]:
@@ -336,7 +334,7 @@ def vector_search(self, search_query, course_name):
     top_n = 80
     # EMBED
     openai_start_time = time.monotonic()
-    print("OPENAI_API_TYPE", OPENAI_API_TYPE)
+    print("OPENAI_API_TYPE", os.environ['OPENAI_API_TYPE'])
     user_query_embedding = self.embeddings.embed_query(search_query)
     openai_embedding_latency = time.monotonic() - openai_start_time
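
Note on the change (not part of the patch itself): the new lookup, os.environ['OPENAI_API_TYPE'], raises KeyError at the call site if the variable is unset, whereas the removed module-level constant silently defaulted to "azure". A minimal sketch of the two lookup styles, assuming the deployment exports OPENAI_API_TYPE:

    import os

    # Strict lookup (the style this patch adopts): fails fast with KeyError if unset.
    openai_api_type = os.environ['OPENAI_API_TYPE']

    # Lenient alternative (hypothetical, not used in the patch): falls back to the
    # previously hard-coded default of "azure" when the variable is missing.
    openai_api_type = os.getenv('OPENAI_API_TYPE', 'azure')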