Skip to content

Commit

Permalink
Fixed Azure OpenAI Deprecations and Adjusted the Tests (#1437)
Browse files Browse the repository at this point in the history
  • Loading branch information
halanm authored Jun 24, 2024
1 parent 18fb92f commit 8700165
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 14 deletions.
2 changes: 1 addition & 1 deletion embedchain/embedder/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
from langchain_community.embeddings import AzureOpenAIEmbeddings
from langchain_openai.embeddings import AzureOpenAIEmbeddings

from embedchain.config import BaseEmbedderConfig
from embedchain.embedder.base import BaseEmbedder
Expand Down
8 changes: 4 additions & 4 deletions embedchain/llm/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config=config)

def get_llm_model_answer(self, prompt):
return AzureOpenAILlm._get_answer(prompt=prompt, config=self.config)
return self._get_answer(prompt=prompt, config=self.config)

@staticmethod
def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
from langchain_community.chat_models import AzureChatOpenAI
from langchain_openai import AzureChatOpenAI

if not config.deployment_name:
raise ValueError("Deployment name must be provided for Azure OpenAI")

chat = AzureChatOpenAI(
deployment_name=config.deployment_name,
openai_api_version=str(config.api_version) if config.api_version else "2023-05-15",
openai_api_version=str(config.api_version) if config.api_version else "2024-02-01",
model_name=config.model or "gpt-3.5-turbo",
temperature=config.temperature,
max_tokens=config.max_tokens,
Expand All @@ -37,4 +37,4 @@ def _get_answer(prompt: str, config: BaseLlmConfig) -> str:

messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)

return chat(messages).content
return chat.invoke(messages).content
15 changes: 6 additions & 9 deletions tests/llm/test_azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,22 @@ def test_get_llm_model_answer(azure_openai_llm):


def test_get_answer(azure_openai_llm):
with patch("langchain_community.chat_models.AzureChatOpenAI") as mock_chat:
with patch("langchain_openai.AzureChatOpenAI") as mock_chat:
mock_chat_instance = mock_chat.return_value
mock_chat_instance.return_value = MagicMock(content="Test Response")
mock_chat_instance.invoke.return_value = MagicMock(content="Test Response")

prompt = "Test Prompt"
response = azure_openai_llm._get_answer(prompt, azure_openai_llm.config)

assert response == "Test Response"
mock_chat.assert_called_once_with(
deployment_name=azure_openai_llm.config.deployment_name,
openai_api_version="2023-05-15",
openai_api_version="2024-02-01",
model_name=azure_openai_llm.config.model or "gpt-3.5-turbo",
temperature=azure_openai_llm.config.temperature,
max_tokens=azure_openai_llm.config.max_tokens,
streaming=azure_openai_llm.config.stream,
)
mock_chat_instance.assert_called_once_with(
azure_openai_llm._get_messages(prompt, system_prompt=azure_openai_llm.config.system_prompt)
)


def test_get_messages(azure_openai_llm):
Expand All @@ -65,6 +62,7 @@ def test_when_no_deployment_name_provided():
llm = AzureOpenAILlm(config)
llm.get_llm_model_answer("Test Prompt")


def test_with_api_version():
config = BaseLlmConfig(
deployment_name="azure_deployment",
Expand All @@ -75,8 +73,7 @@ def test_with_api_version():
api_version="2024-02-01",
)

with patch("langchain_community.chat_models.AzureChatOpenAI") as mock_chat:

with patch("langchain_openai.AzureChatOpenAI") as mock_chat:
llm = AzureOpenAILlm(config)
llm.get_llm_model_answer("Test Prompt")

Expand All @@ -87,4 +84,4 @@ def test_with_api_version():
temperature=0.7,
max_tokens=50,
streaming=False,
)
)

0 comments on commit 8700165

Please sign in to comment.