From 1398ed4e6ddfda8e80056654c97852374d9a8333 Mon Sep 17 00:00:00 2001 From: Paul Yu Date: Thu, 16 May 2024 21:28:20 -0700 Subject: [PATCH] feat: adding additional env vars for oai endpoint in ai-service --- src/ai-service/main.py | 2 +- src/ai-service/routers/image_generator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ai-service/main.py b/src/ai-service/main.py index 4021d13..302baae 100644 --- a/src/ai-service/main.py +++ b/src/ai-service/main.py @@ -19,7 +19,7 @@ async def get_health(): capabilities = ["description"] # Check if the environment variable is set - if os.environ.get("AZURE_OPENAI_DALLE_ENDPOINT") and os.environ.get("AZURE_OPENAI_DALLE_DEPLOYMENT_NAME"): + if (os.environ.get("AZURE_OPENAI_DALLE_ENDPOINT") or os.environ.get("AZURE_OPENAI_ENDPOINT")) and os.environ.get("AZURE_OPENAI_DALLE_DEPLOYMENT_NAME"): # If it is, add "image" to the array capabilities.append("image") diff --git a/src/ai-service/routers/image_generator.py b/src/ai-service/routers/image_generator.py index faed4b3..7dddd14 100644 --- a/src/ai-service/routers/image_generator.py +++ b/src/ai-service/routers/image_generator.py @@ -28,7 +28,7 @@ async def post_image(request: Request) -> JSONResponse: print("Calling OpenAI") api_version = os.environ.get("AZURE_OPENAI_API_VERSION") - endpoint = os.environ.get("AZURE_OPENAI_DALLE_ENDPOINT") + endpoint = os.environ.get("AZURE_OPENAI_DALLE_ENDPOINT") or os.environ.get("AZURE_OPENAI_ENDPOINT") model_deployment_name = os.environ.get("AZURE_OPENAI_DALLE_DEPLOYMENT_NAME") token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")