diff --git a/tests/system/llm/test_openai.py b/tests/system/llm/test_openai.py
index e6c334c8..0933f4c6 100644
--- a/tests/system/llm/test_openai.py
+++ b/tests/system/llm/test_openai.py
@@ -4,7 +4,7 @@
 import jsonschema
 import pytest
 
-from canopy.llm import AzureOpenAILLM, AnyscaleLLM
+from canopy.llm import AzureOpenAILLM, AnyscaleLLM, OctoAILLM
 from canopy.models.data_models import Role, MessageBase, Context, StringContextContent  # noqa
 from canopy.models.api_models import ChatResponse, StreamingChatChunk  # noqa
 from canopy.llm.openai import OpenAILLM  # noqa
@@ -60,7 +60,7 @@ def model_params_low_temperature():
     return {"temperature": 0.2, "top_p": 0.5, "n": 1}
 
 
-@pytest.fixture(params=[OpenAILLM, AzureOpenAILLM, AnyscaleLLM])
+@pytest.fixture(params=[OpenAILLM, AzureOpenAILLM, AnyscaleLLM, OctoAILLM])
 def openai_llm(request):
     llm_class = request.param
     if llm_class == AzureOpenAILLM:
@@ -73,6 +73,10 @@ def openai_llm(request):
         if os.getenv("ANYSCALE_API_KEY") is None:
             pytest.skip("Couldn't find Anyscale API key. Skipping Anyscale tests.")
         model_name = "mistralai/Mistral-7B-Instruct-v0.1"
+    elif llm_class == OctoAILLM:
+        if os.getenv("OCTOAI_API_KEY") is None:
+            pytest.skip("Couldn't find OctoAI API key. Skipping OctoAI tests.")
+        model_name = "mistral-7b-instruct"
     else:
         model_name = "gpt-3.5-turbo-0613"
 
@@ -121,6 +125,8 @@ def test_chat_completion_with_context(openai_llm, messages):
 def test_enforced_function_call(openai_llm,
                                 messages,
                                 function_query_knowledgebase):
+    if isinstance(openai_llm, OctoAILLM):
+        pytest.skip("OctoAI doesn't support function calling at the moment")
     result = openai_llm.enforced_function_call(
         system_prompt=SYSTEM_PROMPT,
         chat_history=messages,
@@ -134,11 +140,15 @@ def test_chat_completion_high_temperature(openai_llm,
     if isinstance(openai_llm, AnyscaleLLM):
         pytest.skip("Anyscale don't support n>1 for the moment.")
 
+    if isinstance(openai_llm, OctoAILLM):
+        pytest.skip("OctoAI doesn't support n>1 for the moment.")
+
     response = openai_llm.chat_completion(
         system_prompt=SYSTEM_PROMPT,
         chat_history=messages,
         model_params=model_params_high_temperature
     )
+
     assert_chat_completion(response,
                            num_choices=model_params_high_temperature["n"])
 
@@ -159,6 +169,9 @@ def test_enforced_function_call_high_temperature(openai_llm,
                                                  model_params_high_temperature):
     if isinstance(openai_llm, AnyscaleLLM):
         pytest.skip("Anyscale don't support n>1 for the moment.")
+
+    if isinstance(openai_llm, OctoAILLM):
+        pytest.skip("OctoAI doesn't support function calling at the moment")
 
     result = openai_llm.enforced_function_call(
         system_prompt=SYSTEM_PROMPT,
@@ -176,6 +189,9 @@ def test_enforced_function_call_low_temperature(openai_llm,
     model_params = model_params_low_temperature.copy()
     if isinstance(openai_llm, AnyscaleLLM):
         model_params["top_p"] = 1.0
+
+    if isinstance(openai_llm, OctoAILLM):
+        pytest.skip("OctoAI doesn't support function calling at the moment")
 
     result = openai_llm.enforced_function_call(
         system_prompt=SYSTEM_PROMPT,
@@ -191,6 +207,8 @@ def test_chat_completion_with_model_name(openai_llm, messages):
         pytest.skip("In Azure the model name has to be a valid deployment")
     elif isinstance(openai_llm, AnyscaleLLM):
         new_model_name = "meta-llama/Llama-2-7b-chat-hf"
+    elif isinstance(openai_llm, OctoAILLM):
+        new_model_name = "codellama-7b-instruct"
     else:
         new_model_name = "gpt-3.5-turbo-1106"
 
@@ -248,6 +266,9 @@ def test_chat_complete_api_failure_populates(openai_llm,
 def test_enforce_function_api_failure_populates(openai_llm,
                                                 messages,
                                                 function_query_knowledgebase):
+    if isinstance(openai_llm, OctoAILLM):
+        pytest.skip("OctoAI doesn't support function calling at the moment")
+
     openai_llm._client = MagicMock()
     openai_llm._client.chat.completions.create.side_effect = Exception(
         "API call failed")
@@ -261,6 +282,9 @@ def test_enforce_function_api_failure_populates(openai_llm,
 def test_enforce_function_wrong_output_schema(openai_llm,
                                               messages,
                                               function_query_knowledgebase):
+    if isinstance(openai_llm, OctoAILLM):
+        pytest.skip("OctoAI doesn't support function calling at the moment")
+
     openai_llm._client = MagicMock()
     openai_llm._client.chat.completions.create.return_value = MagicMock(
         choices=[MagicMock(
@@ -302,6 +326,8 @@ def test_enforce_function_unsupported_model(openai_llm,
 def test_available_models(openai_llm):
     if isinstance(openai_llm, AzureOpenAILLM):
         pytest.skip("Azure does not support listing models")
+    if isinstance(openai_llm, OctoAILLM):
+        pytest.skip("OctoAI does not support listing models")
     models = openai_llm.available_models
     assert isinstance(models, list)
     assert len(models) > 0
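
For local verification, a minimal sketch that runs only the new OctoAI-parametrized cases. It assumes pytest's default parametrize IDs, which derive from the class name (e.g. test_available_models[OctoAILLM]), and a real key in OCTOAI_API_KEY; as the fixture above shows, the OctoAI cases are skipped rather than failed when the key is absent. The helper script name and the placeholder key are hypothetical, not part of this diff:

    # run_octoai_tests.py -- hypothetical helper, not included in this change
    import os
    import pytest

    # The openai_llm fixture skips the OctoAI cases unless this env var is set.
    os.environ.setdefault("OCTOAI_API_KEY", "<your-octoai-key>")

    # -k selects only the test IDs containing the OctoAILLM parametrize ID.
    raise SystemExit(pytest.main(
        ["tests/system/llm/test_openai.py", "-k", "OctoAILLM", "-v"]
    ))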