From 8700165b420c97ce1b3745f19c088bbf1d89cb10 Mon Sep 17 00:00:00 2001
From: Halan Marques
Date: Mon, 24 Jun 2024 10:55:38 -0700
Subject: [PATCH] Fixed Azure OpenAI Deprecations and Adjusted the Tests (#1437)

---
 embedchain/embedder/openai.py  |  2 +-
 embedchain/llm/azure_openai.py |  8 ++++----
 tests/llm/test_azure_openai.py | 15 ++++++---------
 3 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/embedchain/embedder/openai.py b/embedchain/embedder/openai.py
index 88a875c743..ab361e2714 100644
--- a/embedchain/embedder/openai.py
+++ b/embedchain/embedder/openai.py
@@ -2,7 +2,7 @@
 from typing import Optional
 
 from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
-from langchain_community.embeddings import AzureOpenAIEmbeddings
+from langchain_openai.embeddings import AzureOpenAIEmbeddings
 
 from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
diff --git a/embedchain/llm/azure_openai.py b/embedchain/llm/azure_openai.py
index 6c5a03f1c8..f32b393064 100644
--- a/embedchain/llm/azure_openai.py
+++ b/embedchain/llm/azure_openai.py
@@ -14,18 +14,18 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
         super().__init__(config=config)
 
     def get_llm_model_answer(self, prompt):
-        return AzureOpenAILlm._get_answer(prompt=prompt, config=self.config)
+        return self._get_answer(prompt=prompt, config=self.config)
 
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
-        from langchain_community.chat_models import AzureChatOpenAI
+        from langchain_openai import AzureChatOpenAI
 
         if not config.deployment_name:
             raise ValueError("Deployment name must be provided for Azure OpenAI")
 
         chat = AzureChatOpenAI(
             deployment_name=config.deployment_name,
-            openai_api_version=str(config.api_version) if config.api_version else "2023-05-15",
+            openai_api_version=str(config.api_version) if config.api_version else "2024-02-01",
             model_name=config.model or "gpt-3.5-turbo",
             temperature=config.temperature,
             max_tokens=config.max_tokens,
@@ -37,4 +37,4 @@ def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
 
         messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
 
-        return chat(messages).content
+        return chat.invoke(messages).content
diff --git a/tests/llm/test_azure_openai.py b/tests/llm/test_azure_openai.py
index b936a3025d..853fbf732a 100644
--- a/tests/llm/test_azure_openai.py
+++ b/tests/llm/test_azure_openai.py
@@ -28,9 +28,9 @@ def test_get_llm_model_answer(azure_openai_llm):
 
 
 def test_get_answer(azure_openai_llm):
-    with patch("langchain_community.chat_models.AzureChatOpenAI") as mock_chat:
+    with patch("langchain_openai.AzureChatOpenAI") as mock_chat:
         mock_chat_instance = mock_chat.return_value
-        mock_chat_instance.return_value = MagicMock(content="Test Response")
+        mock_chat_instance.invoke.return_value = MagicMock(content="Test Response")
 
         prompt = "Test Prompt"
         response = azure_openai_llm._get_answer(prompt, azure_openai_llm.config)
@@ -38,15 +38,12 @@ def test_get_answer(azure_openai_llm):
         assert response == "Test Response"
         mock_chat.assert_called_once_with(
             deployment_name=azure_openai_llm.config.deployment_name,
-            openai_api_version="2023-05-15",
+            openai_api_version="2024-02-01",
             model_name=azure_openai_llm.config.model or "gpt-3.5-turbo",
             temperature=azure_openai_llm.config.temperature,
             max_tokens=azure_openai_llm.config.max_tokens,
             streaming=azure_openai_llm.config.stream,
         )
-        mock_chat_instance.assert_called_once_with(
-            azure_openai_llm._get_messages(prompt, system_prompt=azure_openai_llm.config.system_prompt)
-        )
 
 
 def test_get_messages(azure_openai_llm):
@@ -65,6 +62,7 @@ def test_when_no_deployment_name_provided():
         llm = AzureOpenAILlm(config)
         llm.get_llm_model_answer("Test Prompt")
 
+
 def test_with_api_version():
     config = BaseLlmConfig(
         deployment_name="azure_deployment",
@@ -75,8 +73,7 @@ def test_with_api_version():
         api_version="2024-02-01",
     )
 
-    with patch("langchain_community.chat_models.AzureChatOpenAI") as mock_chat:
-
+    with patch("langchain_openai.AzureChatOpenAI") as mock_chat:
         llm = AzureOpenAILlm(config)
         llm.get_llm_model_answer("Test Prompt")
 
@@ -87,4 +84,4 @@ def test_with_api_version():
             temperature=0.7,
             max_tokens=50,
             streaming=False,
-        )
\ No newline at end of file
+        )
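
Reviewer note: below is a minimal standalone sketch of the call pattern this patch migrates to (langchain_openai.AzureChatOpenAI plus .invoke() instead of the deprecated direct call), assuming the langchain-openai package is installed and AZURE_OPENAI_API_KEY / AZURE_OPENAI_ENDPOINT are set in the environment. The deployment name is a hypothetical placeholder, not something defined by this patch.

# Sketch of the post-migration usage, not part of the patch itself.
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import AzureChatOpenAI

chat = AzureChatOpenAI(
    deployment_name="my-azure-deployment",  # hypothetical Azure deployment name
    openai_api_version="2024-02-01",        # new default api_version used by this patch
    model_name="gpt-3.5-turbo",
    temperature=0.7,
    max_tokens=50,
    streaming=False,
)

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Test Prompt"),
]

# langchain >= 0.1 deprecates calling the chat model directly (chat(messages));
# .invoke() returns an AIMessage whose .content holds the reply text.
answer = chat.invoke(messages).content
print(answer)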