
[PY] feat: o1 model support #2123


Merged
merged 10 commits on Oct 23, 2024
Changes from all commits
2,273 changes: 1,293 additions & 980 deletions python/packages/ai/poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion python/packages/ai/pyproject.toml
@@ -26,7 +26,7 @@ dataclasses-json = "^0.6.4"
azure-ai-contentsafety = "^1.0.0"
msal = "^1.28.0"
botbuilder-dialogs = "^4.14.8"
openai = "^1.27.0"
openai = "^v1.52.0"

[tool.poetry.group.dev.dependencies]
pytest = "^8.1.1"
47 changes: 37 additions & 10 deletions python/packages/ai/teams/ai/models/openai_model.py
@@ -80,6 +80,24 @@ class AzureOpenAIModelOptions:
class OpenAIModel(PromptCompletionModel):
"""
A `PromptCompletionModel` for calling OpenAI and Azure OpenAI hosted models.

The model has been updated to support calling OpenAI's new o1 family of models. That
support currently comes with a few constraints. These constraints are mostly handled
for you, but they are worth noting:

- The o1 models introduce a new `max_completion_tokens` parameter and deprecate the
`max_tokens` parameter. The model automatically converts the incoming `max_tokens`
parameter to `max_completion_tokens` for you, but be aware that the o1 models have
hidden token usage and costs that are not constrained by `max_completion_tokens`,
so you may see an increase in token usage and costs when using them.

- The o1 models do not currently support system messages, so the model maps them to
user messages.

- The o1 models do not currently support the `temperature`, `top_p`, and
`presence_penalty` parameters, so those settings are ignored.

- The o1 models do not currently support tools, so you will need to use the
"monologue" augmentation to call actions.
"""

_options: Union[OpenAIModelOptions, AzureOpenAIModelOptions]
@@ -162,6 +180,7 @@ async def complete_prompt(
if template.config.completion.model is not None
else self._options.default_model
)
is_o1_model = model.startswith("o1-")

res = await template.prompt.render_as_messages(
context=context,
@@ -232,10 +251,17 @@ async def complete_prompt(
content=msg.content if msg.content else "",
)
elif msg.role == "system":
param = chat.ChatCompletionSystemMessageParam(
role="system",
content=msg.content if msg.content is not None else "",
)
# o1 models do not support system messages
if is_o1_model:
param = chat.ChatCompletionUserMessageParam(
role="user",
content=msg.content if msg.content is not None else "",
)
else:
param = chat.ChatCompletionSystemMessageParam(
role="system",
content=msg.content if msg.content is not None else "",
)

if msg.name:
param["name"] = msg.name
@@ -250,11 +276,13 @@ async def complete_prompt(
completion = await self._client.chat.completions.create(
messages=messages,
model=model,
presence_penalty=template.config.completion.presence_penalty,
presence_penalty=(
template.config.completion.presence_penalty if not is_o1_model else 0
),
frequency_penalty=template.config.completion.frequency_penalty,
top_p=template.config.completion.top_p,
temperature=template.config.completion.temperature,
max_tokens=template.config.completion.max_tokens,
top_p=template.config.completion.top_p if not is_o1_model else 1,
temperature=template.config.completion.temperature if not is_o1_model else 1,
max_completion_tokens=template.config.completion.max_tokens,
tools=tools if len(tools) > 0 else NOT_GIVEN,
tool_choice=tool_choice if len(tools) > 0 else NOT_GIVEN,
parallel_tool_calls=parallel_tool_calls if len(tools) > 0 else NOT_GIVEN,
@@ -266,8 +294,7 @@ async def complete_prompt(

# Handle tools flow
action_calls = []
response_message = completion.choices[0].message
tool_calls = response_message.tool_calls
tool_calls = completion.choices[0].message.tool_calls

if is_tools_aug and tool_calls:
for curr_tool_call in tool_calls:
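For reference, the constraint handling described in the docstring above reduces to a small mapping over the request parameters. A minimal sketch of that logic, outside this PR (the function name `map_o1_params` is illustrative, not part of the library):

# Illustrative sketch only -- not part of this PR.
def map_o1_params(model: str, params: dict) -> dict:
    """Adjust chat-completion kwargs for o1-family models."""
    if not model.startswith("o1-"):
        return params
    mapped = dict(params)
    # o1 deprecates `max_tokens` in favor of `max_completion_tokens`.
    if "max_tokens" in mapped:
        mapped["max_completion_tokens"] = mapped.pop("max_tokens")
    # Sampling parameters are unsupported; pin them to the API defaults,
    # mirroring what the PR does in the completions.create call above.
    mapped["temperature"] = 1
    mapped["top_p"] = 1
    mapped["presence_penalty"] = 0
    # o1 rejects system messages; downgrade them to user messages.
    mapped["messages"] = [
        {**m, "role": "user"} if m.get("role") == "system" else m
        for m in mapped.get("messages", [])
    ]
    return mapped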
43 changes: 42 additions & 1 deletion python/packages/ai/tests/ai/models/test_openai_model.py
@@ -30,6 +30,7 @@
from teams.ai.prompts.sections.conversation_history_section import (
ConversationHistorySection,
)
from teams.ai.prompts.sections.template_section import TemplateSection
from teams.ai.tokenizers import GPTTokenizer
from teams.state import TurnState

Expand Down Expand Up @@ -99,11 +100,17 @@ class MockAsyncCompletions:
should_error = False
has_tool_call = False
has_tool_calls = False
is_o1_model = False
messages = []

def __init__(self, should_error=False, has_tool_call=False, has_tool_calls=False) -> None:
def __init__(
self, should_error=False, has_tool_call=False, has_tool_calls=False, is_o1_model=False
) -> None:
self.should_error = should_error
self.has_tool_call = has_tool_call
self.has_tool_calls = has_tool_calls
self.is_o1_model = is_o1_model
self.messages = []

async def create(self, **kwargs) -> chat.ChatCompletion:
if self.should_error:
@@ -119,6 +126,9 @@ async def create(self, **kwargs) -> chat.ChatCompletion:
if self.has_tool_calls:
return await self.handle_tool_calls(**kwargs)

if self.is_o1_model:
self.messages = kwargs["messages"]

return chat.ChatCompletion(
id="",
choices=[
@@ -282,6 +292,37 @@ async def test_should_be_success(self, mock_async_openai):
self.assertTrue(mock_async_openai.called)
self.assertEqual(res.status, "success")

@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
async def test_o1_model_should_use_user_message_over_system_message(self, mock_async_openai):
mock_async_openai.return_value.chat.completions.is_o1_model = True
context = self.create_mock_context()
state = TurnState()
state.temp = {}
state.conversation = {}
model = OpenAIModel(OpenAIModelOptions(api_key="", default_model="o1-"))
res = await model.complete_prompt(
context=context,
memory=state,
functions=cast(PromptFunctions, {}),
tokenizer=GPTTokenizer(),
template=PromptTemplate(
name="default",
prompt=Prompt(sections=[TemplateSection("prompt text", "system")]),
config=PromptTemplateConfig(
schema=1.0,
type="completion",
description="test",
completion=CompletionConfig(completion_type="chat"),
),
),
)

self.assertTrue(mock_async_openai.called)
self.assertEqual(res.status, "success")
self.assertEqual(
mock_async_openai.return_value.chat.completions.messages[0]["role"], "user"
)

@mock.patch("openai.AsyncOpenAI", return_value=MockAsyncOpenAI)
async def test_should_succeed_on_prev_tool_calls(self, mock_async_openai):
context = self.create_mock_context()
22 changes: 22 additions & 0 deletions python/packages/ai/tests/ai/moderators/test_openai_moderator.py
@@ -57,13 +57,33 @@ async def create(
model=model,
results=[
openai.types.Moderation(
category_applied_input_types=cast(
openai.types.moderation.CategoryAppliedInputTypes,
{
"harassment": ["text"],
"harassment/threatening": ["text"],
"hate": ["text"],
"hate/threatening": ["text"],
"illicit": ["text"],
"illicit/violent": ["text"],
"self-harm": ["text"],
"self-harm/instructions": ["text"],
"self-harm/intent": ["text"],
"sexual": ["text"],
"sexual/minors": ["text"],
"violence": ["text"],
"violence/graphic": ["text"],
},
),
categories=cast(
openai.types.moderation.Categories,
{
"harassment": True,
"harassment/threatening": False,
"hate": False,
"hate/threatening": False,
"illicit": False,
"illicit/violent": False,
"self-harm": False,
"self-harm/instructions": False,
"self-harm/intent": False,
@@ -80,6 +100,8 @@ async def create(
"harassment/threatening": 0,
"hate": 0,
"hate/threatening": 0,
"illicit": 0,
"illicit/violent": 0,
"self-harm": 0,
"self-harm/instructions": 0,
"self-harm/intent": 0,
4 changes: 2 additions & 2 deletions python/samples/06.assistants.b.orderBot/src/bot.py
@@ -133,10 +133,10 @@ async def turn_state_factory(context: TurnContext):

@app.ai.action("place_order")
async def on_place_order(
context: ActionTurnContext[Order],
context: ActionTurnContext,
state: AppTurnState,
):
card = generate_card_for_order(context.data)
card = generate_card_for_order(Order.from_dict(context.data))
await context.send_activity(MessageFactory.attachment(card))
return "order placed"

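For context on this change: `ActionTurnContext.data` arrives as a plain dict, and `DataClassJsonMixin.from_dict` (from dataclasses-json) converts it into a typed instance. A minimal, self-contained sketch of the pattern (the `Item` class is hypothetical, standing in for the sample's Order):

from dataclasses import dataclass
from dataclasses_json import DataClassJsonMixin

# Hypothetical stand-in for the sample's Order class.
@dataclass
class Item(DataClassJsonMixin):
    name: str
    quantity: int = 1

# Same pattern as `Order.from_dict(context.data)` above.
item = Item.from_dict({"name": "pepperoni pizza"})
assert item.name == "pepperoni pizza"
assert item.quantity == 1  # missing keys fall back to field defaults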
@@ -19,13 +19,13 @@
class Pizza(DataClassJsonMixin):
item_type = "pizza"

added_toppings: Optional[List[str]]
added_toppings: Optional[List[str]] = None
"Toppings requested (examples: pepperoni, arugula)"

removed_toppings: Optional[List[str]]
removed_toppings: Optional[List[str]] = None
"Toppings requested to be removed (examples: fresh garlic, anchovies)"

name: Optional[PizzaName]
name: Optional[PizzaName] = None
"Used if the requester references a pizza by name"

size: Optional[PizzaSize] = "large"
Expand Down Expand Up @@ -69,10 +69,10 @@ class Beer(DataClassJsonMixin):
class Salad(DataClassJsonMixin):
item_type = "salad"

added_ingredients: Optional[List[str]]
added_ingredients: Optional[List[str]] = None
"Ingredients requested (examples: parmesan, croutons)"

removed_ingredients: Optional[List[str]]
removed_ingredients: Optional[List[str]] = None
"Ingredients requested to be removed (example: red onions)"

portion: Optional[SaladSize] = "half"
Expand All @@ -89,4 +89,4 @@ class Order(DataClassJsonMixin):
An order from a restaurant that serves pizza, beer, and salad
"""

items: List[Union[Pizza, Beer, Salad, NamedPizza, UnknownText]]
items: List[Union[Pizza, Beer, Salad, NamedPizza, UnknownText]]
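A note on why the `= None` defaults were added throughout this file: with bot.py now calling `Order.from_dict(context.data)`, payloads that omit optional keys must still parse, and dataclasses-json only falls back to a field's default when one is declared. A minimal sketch, assuming dataclasses-json's `from_dict` semantics (the `Drink` class is hypothetical):

from dataclasses import dataclass
from typing import List, Optional
from dataclasses_json import DataClassJsonMixin

@dataclass
class Drink(DataClassJsonMixin):
    # Without `= None`, deserializing a payload missing this key would error.
    added_flavors: Optional[List[str]] = None
    size: Optional[str] = "large"

d = Drink.from_dict({"size": "small"})
assert d.added_flavors is None  # omitted key falls back to its default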