Skip to content

Commit

Permalink
refactor and add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
singhk97 committed Oct 30, 2024
1 parent b7d38f2 commit 3a74afe
Showing 1 changed file with 78 additions and 9 deletions.
87 changes: 78 additions & 9 deletions python/packages/ai/tests/ai/models/test_openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,35 +100,32 @@ class MockAsyncCompletions:
should_error = False
has_tool_call = False
has_tool_calls = False
is_o1_model = False
messages = []
create_params = None

def __init__(
    self, should_error=False, has_tool_call=False, has_tool_calls=False, is_o1_model=False
) -> None:
    """Configure which canned behaviour this mock completions endpoint exhibits.

    Exactly one flag is normally set per test: raise a BadRequestError,
    return a single tool call, return multiple tool calls, or record the
    incoming messages as an o1-style model does.
    """
    # Shadow the class-level defaults with per-instance state so that
    # configuration from one test cannot leak into another.
    self.messages = []
    self.is_o1_model = is_o1_model
    self.has_tool_calls = has_tool_calls
    self.has_tool_call = has_tool_call
    self.should_error = should_error

async def create(self, **kwargs) -> chat.ChatCompletion:
self.create_params = kwargs

if self.should_error:
raise openai.BadRequestError(
"bad request",
response=httpx.Response(400, request=httpx.Request(method="method", url="url")),
body=None,
)

if self.has_tool_call:
return await self.handle_tool_call(**kwargs)

if self.has_tool_calls:
return await self.handle_tool_calls(**kwargs)

if self.is_o1_model:
self.messages = kwargs["messages"]

return chat.ChatCompletion(
id="",
choices=[
Expand Down Expand Up @@ -294,7 +291,6 @@ async def test_should_be_success(self, mock_async_openai):

@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
async def test_o1_model_should_use_user_message_over_system_message(self, mock_async_openai):
mock_async_openai.return_value.chat.completions.is_o1_model = True
context = self.create_mock_context()
state = TurnState()
state.temp = {}
Expand All @@ -319,8 +315,81 @@ async def test_o1_model_should_use_user_message_over_system_message(self, mock_a

self.assertTrue(mock_async_openai.called)
self.assertEqual(res.status, "success")
create_params = mock_async_openai.return_value.chat.completions.create_params
self.assertEqual(
create_params["messages"][0]["role"], "user"
)

@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
async def test_o1_model_should_use_max_completion_tokens_param(self, mock_async_openai):
    """An o1 model should send `max_completion_tokens` and leave `max_tokens` unset."""
    turn_context = self.create_mock_context()
    memory = TurnState()
    memory.temp = {}
    memory.conversation = {}

    # Build the prompt template up front so the completion call below stays flat.
    completion_config = CompletionConfig(completion_type="chat")
    completion_config.max_tokens = 1000
    template = PromptTemplate(
        name="default",
        prompt=Prompt(sections=[TemplateSection("prompt text", "system")]),
        config=PromptTemplateConfig(
            schema=1.0,
            type="completion",
            description="test",
            completion=completion_config,
        ),
    )

    # default_model "o1-" drives the o1 code path under test (see test name).
    openai_model = OpenAIModel(OpenAIModelOptions(api_key="", default_model="o1-"))
    result = await openai_model.complete_prompt(
        context=turn_context,
        memory=memory,
        functions=cast(PromptFunctions, {}),
        tokenizer=GPTTokenizer(),
        template=template,
    )

    self.assertTrue(mock_async_openai.called)
    self.assertEqual(result.status, "success")
    sent_params = mock_async_openai.return_value.chat.completions.create_params
    self.assertEqual(sent_params["max_completion_tokens"], 1000)
    self.assertEqual(sent_params["max_tokens"], openai.NOT_GIVEN)

@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
async def test_non_o1_model_should_use_max_tokens_param(self, mock_async_openai):
    """A non-o1 model should send `max_tokens` and leave `max_completion_tokens` unset.

    Fix: the final assertEqual contained a stale leftover line
    (`mock_async_openai.return_value.chat.completions.messages[0]["role"], "user"`)
    fused into the call — a diff artifact that made the method a syntax error.
    It has been removed; the assertion now checks only `max_completion_tokens`.
    """
    context = self.create_mock_context()
    state = TurnState()
    state.temp = {}
    state.conversation = {}
    # default_model "non-o1" keeps this on the regular (non-o1) code path.
    model = OpenAIModel(OpenAIModelOptions(api_key="", default_model="non-o1"))
    completion = CompletionConfig(completion_type="chat")
    completion.max_tokens = 1000
    res = await model.complete_prompt(
        context=context,
        memory=state,
        functions=cast(PromptFunctions, {}),
        tokenizer=GPTTokenizer(),
        template=PromptTemplate(
            name="default",
            prompt=Prompt(sections=[TemplateSection("prompt text", "system")]),
            config=PromptTemplateConfig(
                schema=1.0,
                type="completion",
                description="test",
                completion=completion,
            ),
        ),
    )

    self.assertTrue(mock_async_openai.called)
    self.assertEqual(res.status, "success")
    create_params = mock_async_openai.return_value.chat.completions.create_params
    self.assertEqual(create_params["max_tokens"], 1000)
    self.assertEqual(create_params["max_completion_tokens"], openai.NOT_GIVEN)

@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
Expand Down

0 comments on commit 3a74afe

Please sign in to comment.