added OctoAI to llm unit tests
ptorru committed Mar 1, 2024
1 parent a36ac11 commit e7cdb53
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions tests/system/llm/test_openai.py
@@ -4,7 +4,7 @@
import jsonschema
import pytest

from canopy.llm import AzureOpenAILLM, AnyscaleLLM
from canopy.llm import AzureOpenAILLM, AnyscaleLLM, OctoAILLM
from canopy.models.data_models import Role, MessageBase, Context, StringContextContent # noqa
from canopy.models.api_models import ChatResponse, StreamingChatChunk # noqa
from canopy.llm.openai import OpenAILLM # noqa
@@ -60,7 +60,7 @@ def model_params_low_temperature():
    return {"temperature": 0.2, "top_p": 0.5, "n": 1}


@pytest.fixture(params=[OpenAILLM, AzureOpenAILLM, AnyscaleLLM])
@pytest.fixture(params=[OpenAILLM, AzureOpenAILLM, AnyscaleLLM, OctoAILLM])
def openai_llm(request):
    llm_class = request.param
    if llm_class == AzureOpenAILLM:
@@ -73,6 +73,10 @@
        if os.getenv("ANYSCALE_API_KEY") is None:
            pytest.skip("Couldn't find Anyscale API key. Skipping Anyscale tests.")
        model_name = "mistralai/Mistral-7B-Instruct-v0.1"
    elif llm_class == OctoAILLM:
        if os.getenv("OCTOAI_API_KEY") is None:
            pytest.skip("Couldn't find OctoAI API key. Skipping OctoAI tests.")
        model_name = "mistral-7b-instruct"
    else:
        model_name = "gpt-3.5-turbo-0613"

@@ -121,6 +125,8 @@ def test_chat_completion_with_context(openai_llm, messages):
def test_enforced_function_call(openai_llm,
                                messages,
                                function_query_knowledgebase):
    if isinstance(openai_llm, OctoAILLM):
        pytest.skip("OctoAI doesn't support function calling at the moment")
    result = openai_llm.enforced_function_call(
        system_prompt=SYSTEM_PROMPT,
        chat_history=messages,
@@ -134,11 +140,15 @@ def test_chat_completion_high_temperature(openai_llm,
    if isinstance(openai_llm, AnyscaleLLM):
        pytest.skip("Anyscale don't support n>1 for the moment.")

    if isinstance(openai_llm, OctoAILLM):
        pytest.skip("OctoAI doesn't support n>1 for the moment.")

    response = openai_llm.chat_completion(
        system_prompt=SYSTEM_PROMPT,
        chat_history=messages,
        model_params=model_params_high_temperature
    )

    assert_chat_completion(response,
                           num_choices=model_params_high_temperature["n"])

@@ -159,6 +169,9 @@ def test_enforced_function_call_high_temperature(openai_llm,
                                                 model_params_high_temperature):
    if isinstance(openai_llm, AnyscaleLLM):
        pytest.skip("Anyscale don't support n>1 for the moment.")

    if isinstance(openai_llm, OctoAILLM):
        pytest.skip("OctoAI doesn't support function calling at the moment")

    result = openai_llm.enforced_function_call(
        system_prompt=SYSTEM_PROMPT,
@@ -176,6 +189,9 @@ def test_enforced_function_call_low_temperature(openai_llm,
    model_params = model_params_low_temperature.copy()
    if isinstance(openai_llm, AnyscaleLLM):
        model_params["top_p"] = 1.0

    if isinstance(openai_llm, OctoAILLM):
        pytest.skip("OctoAI doesn't support function calling at the moment")

    result = openai_llm.enforced_function_call(
        system_prompt=SYSTEM_PROMPT,
@@ -191,6 +207,8 @@ def test_chat_completion_with_model_name(openai_llm, messages):
        pytest.skip("In Azure the model name has to be a valid deployment")
    elif isinstance(openai_llm, AnyscaleLLM):
        new_model_name = "meta-llama/Llama-2-7b-chat-hf"
    elif isinstance(openai_llm, OctoAILLM):
        new_model_name = "codellama-7b-instruct"
    else:
        new_model_name = "gpt-3.5-turbo-1106"

@@ -248,6 +266,9 @@ def test_chat_complete_api_failure_populates(openai_llm,
def test_enforce_function_api_failure_populates(openai_llm,
                                                messages,
                                                function_query_knowledgebase):
    if isinstance(openai_llm, OctoAILLM):
        pytest.skip("OctoAI doesn't support function calling at the moment")

    openai_llm._client = MagicMock()
    openai_llm._client.chat.completions.create.side_effect = Exception(
        "API call failed")
@@ -261,6 +282,9 @@ def test_enforce_function_wrong_output_schema(openai_llm,
def test_enforce_function_wrong_output_schema(openai_llm,
                                              messages,
                                              function_query_knowledgebase):
    if isinstance(openai_llm, OctoAILLM):
        pytest.skip("OctoAI doesn't support function calling at the moment")

    openai_llm._client = MagicMock()
    openai_llm._client.chat.completions.create.return_value = MagicMock(
        choices=[MagicMock(
@@ -302,6 +326,8 @@ def test_enforce_function_unsupported_model(openai_llm,
def test_available_models(openai_llm):
    if isinstance(openai_llm, AzureOpenAILLM):
        pytest.skip("Azure does not support listing models")
    if isinstance(openai_llm, OctoAILLM):
        pytest.skip("OctoAI does not support listing models")
    models = openai_llm.available_models
    assert isinstance(models, list)
    assert len(models) > 0
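
For reference, the diff above relies on pytest's parametrized-fixture pattern: one fixture yields an LLM instance per provider class, and each provider branch either picks a model name or skips its tests when the corresponding API key is missing. Below is a minimal, self-contained sketch of that pattern; the provider classes and the environment variable name are illustrative stand-ins, not the canopy classes or the real OCTOAI_API_KEY handling.

import os

import pytest


class FakeOpenAI:
    # Stand-in for a provider LLM class; only stores the chosen model name.
    def __init__(self, model_name):
        self.model_name = model_name


class FakeOctoAI(FakeOpenAI):
    # Stand-in for a provider added via parametrization, like OctoAILLM above.
    pass


@pytest.fixture(params=[FakeOpenAI, FakeOctoAI])
def llm(request):
    llm_class = request.param
    if llm_class is FakeOctoAI:
        # Mirror the key check in the diff: skip this parametrization
        # entirely when the provider's credentials are absent.
        if os.getenv("EXAMPLE_OCTOAI_API_KEY") is None:
            pytest.skip("Couldn't find example OctoAI API key.")
        model_name = "example-octoai-model"
    else:
        model_name = "example-openai-model"
    return llm_class(model_name)


def test_model_name_is_set(llm):
    # Runs once per provider class listed in the fixture's params.
    assert llm.model_name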