14 changes: 8 additions & 6 deletions src/any_llm/providers/mistral/mistral.py
@@ -10,6 +10,7 @@
try:
    from mistralai import Mistral
    from mistralai.extra import response_format_from_pydantic_model
    from mistralai.models.responseformat import ResponseFormat

    from .utils import (
        _convert_models_list,
@@ -115,12 +116,13 @@ async def _acompletion(
        if params.reasoning_effort == "auto":
            params.reasoning_effort = None

        if (
            params.response_format is not None
            and isinstance(params.response_format, type)
            and issubclass(params.response_format, BaseModel)
        ):
            kwargs["response_format"] = response_format_from_pydantic_model(params.response_format)
        if params.response_format is not None:
            # Pydantic model
            if isinstance(params.response_format, type) and issubclass(params.response_format, BaseModel):
                kwargs["response_format"] = response_format_from_pydantic_model(params.response_format)
            # Dictionary in OpenAI format
            elif isinstance(params.response_format, dict):
                kwargs["response_format"] = ResponseFormat.model_validate(params.response_format)

        completion_kwargs = self._convert_completion_params(params, **kwargs)

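For context, a minimal sketch (not part of the diff) of what the new elif branch does with an OpenAI-style dict: mistralai's ResponseFormat is a Pydantic model, so model_validate parses the dict into the SDK object passed to the chat API. The field names checked here (type, json_schema.name, json_schema.strict) are the ones exercised by the test below; the example dict itself is illustrative.

# Illustrative sketch, assuming the mistralai SDK is installed.
from mistralai.models.responseformat import ResponseFormat

openai_style = {
    "type": "json_schema",
    "json_schema": {
        "name": "StructuredOutput",
        "schema": {"type": "object", "properties": {"foo": {"type": "string"}}},
        "strict": True,
    },
}

# Equivalent to the new elif branch in _acompletion above.
rf = ResponseFormat.model_validate(openai_style)
assert rf.type == "json_schema"
assert rf.json_schema.name == "StructuredOutput"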
73 changes: 73 additions & 0 deletions tests/unit/providers/test_mistral_provider.py
@@ -1,6 +1,11 @@
from typing import Any
from unittest.mock import AsyncMock, Mock, patch

import pytest
from pydantic import BaseModel

from any_llm.providers.mistral.utils import _patch_messages
from any_llm.types.completion import CompletionParams


def test_patch_messages_noop_when_no_tool_before_user() -> None:
@@ -91,3 +96,71 @@ def test_patch_messages_with_multiple_valid_tool_calls() -> None:
{"role": "assistant", "content": "OK"},
{"role": "user", "content": "u1"},
]


class StructuredOutput(BaseModel):
    foo: str
    bar: int


openai_json_schema = {
    "type": "json_schema",
    "json_schema": {
        "name": "StructuredOutput",
        "schema": {**StructuredOutput.model_json_schema(), "additionalProperties": False},
        "strict": True,
    },
}


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "response_format",
    [
        StructuredOutput,
        openai_json_schema,
    ],
    ids=["pydantic_model", "openai_json_schema"],
)
async def test_response_format(response_format: Any) -> None:
    """Test that response_format is properly converted for both Pydantic and dict formats."""
    mistralai = pytest.importorskip("mistralai")
    from any_llm.providers.mistral.mistral import MistralProvider

    with (
        patch("any_llm.providers.mistral.mistral.Mistral") as mocked_mistral,
        patch("any_llm.providers.mistral.mistral._create_mistral_completion_from_response") as mock_converter,
    ):
        provider = MistralProvider(api_key="test-api-key")

        mocked_mistral.return_value.chat.complete_async = AsyncMock(return_value=Mock())
        mock_converter.return_value = Mock()

        await provider._acompletion(
            CompletionParams(
                model_id="test-model",
                messages=[{"role": "user", "content": "Hello"}],
                response_format=response_format,
            ),
        )

        completion_call_kwargs = mocked_mistral.return_value.chat.complete_async.call_args[1]
        assert "response_format" in completion_call_kwargs

        response_format_arg = completion_call_kwargs["response_format"]
        assert isinstance(response_format_arg, mistralai.models.responseformat.ResponseFormat)
        assert response_format_arg.type == "json_schema"
        assert response_format_arg.json_schema.name == "StructuredOutput"
        assert response_format_arg.json_schema.strict is True

        expected_schema = {
            "properties": {
                "foo": {"title": "Foo", "type": "string"},
                "bar": {"title": "Bar", "type": "integer"},
            },
            "required": ["foo", "bar"],
            "title": "StructuredOutput",
            "type": "object",
            "additionalProperties": False,
        }
        assert response_format_arg.json_schema.schema_definition == expected_schema