diff --git a/src/any_llm/providers/mistral/mistral.py b/src/any_llm/providers/mistral/mistral.py index 4ab7fa51..b81bde90 100644 --- a/src/any_llm/providers/mistral/mistral.py +++ b/src/any_llm/providers/mistral/mistral.py @@ -10,6 +10,7 @@ try: from mistralai import Mistral from mistralai.extra import response_format_from_pydantic_model + from mistralai.models.responseformat import ResponseFormat from .utils import ( _convert_models_list, @@ -115,12 +116,13 @@ async def _acompletion( if params.reasoning_effort == "auto": params.reasoning_effort = None - if ( - params.response_format is not None - and isinstance(params.response_format, type) - and issubclass(params.response_format, BaseModel) - ): - kwargs["response_format"] = response_format_from_pydantic_model(params.response_format) + if params.response_format is not None: + # Pydantic model + if isinstance(params.response_format, type) and issubclass(params.response_format, BaseModel): + kwargs["response_format"] = response_format_from_pydantic_model(params.response_format) + # Dictionary in OpenAI format + elif isinstance(params.response_format, dict): + kwargs["response_format"] = ResponseFormat.model_validate(params.response_format) completion_kwargs = self._convert_completion_params(params, **kwargs) diff --git a/tests/unit/providers/test_mistral_provider.py b/tests/unit/providers/test_mistral_provider.py index 268fea59..4acc8ca8 100644 --- a/tests/unit/providers/test_mistral_provider.py +++ b/tests/unit/providers/test_mistral_provider.py @@ -1,6 +1,11 @@ from typing import Any +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from pydantic import BaseModel from any_llm.providers.mistral.utils import _patch_messages +from any_llm.types.completion import CompletionParams def test_patch_messages_noop_when_no_tool_before_user() -> None: @@ -91,3 +96,71 @@ def test_patch_messages_with_multiple_valid_tool_calls() -> None: {"role": "assistant", "content": "OK"}, {"role": "user", "content": "u1"}, ] + + +class StructuredOutput(BaseModel): + foo: str + bar: int + + +openai_json_schema = { + "type": "json_schema", + "json_schema": { + "name": "StructuredOutput", + "schema": {**StructuredOutput.model_json_schema(), "additionalProperties": False}, + "strict": True, + }, +} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "response_format", + [ + StructuredOutput, + openai_json_schema, + ], + ids=["pydantic_model", "openai_json_schema"], +) +async def test_response_format(response_format: Any) -> None: + """Test that response_format is properly converted for both Pydantic and dict formats.""" + mistralai = pytest.importorskip("mistralai") + from any_llm.providers.mistral.mistral import MistralProvider + + with ( + patch("any_llm.providers.mistral.mistral.Mistral") as mocked_mistral, + patch("any_llm.providers.mistral.mistral._create_mistral_completion_from_response") as mock_converter, + ): + provider = MistralProvider(api_key="test-api-key") + + mocked_mistral.return_value.chat.complete_async = AsyncMock(return_value=Mock()) + mock_converter.return_value = Mock() + + await provider._acompletion( + CompletionParams( + model_id="test-model", + messages=[{"role": "user", "content": "Hello"}], + response_format=response_format, + ), + ) + + completion_call_kwargs = mocked_mistral.return_value.chat.complete_async.call_args[1] + assert "response_format" in completion_call_kwargs + + response_format_arg = completion_call_kwargs["response_format"] + assert isinstance(response_format_arg, mistralai.models.responseformat.ResponseFormat) + assert response_format_arg.type == "json_schema" + assert response_format_arg.json_schema.name == "StructuredOutput" + assert response_format_arg.json_schema.strict is True + + expected_schema = { + "properties": { + "foo": {"title": "Foo", "type": "string"}, + "bar": {"title": "Bar", "type": "integer"}, + }, + "required": ["foo", "bar"], + "title": "StructuredOutput", + "type": "object", + "additionalProperties": False, + } + assert response_format_arg.json_schema.schema_definition == expected_schema