diff --git a/src/mcp_agent/llm/providers/augmented_llm_deepseek.py b/src/mcp_agent/llm/providers/augmented_llm_deepseek.py
index 4e6b5392b..26c3ed664 100644
--- a/src/mcp_agent/llm/providers/augmented_llm_deepseek.py
+++ b/src/mcp_agent/llm/providers/augmented_llm_deepseek.py
@@ -1,6 +1,10 @@
+from typing import List, Tuple, Type
+
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
+from mcp_agent.mcp.interfaces import ModelT
+from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 
 DEEPSEEK_BASE_URL = "https://api.deepseek.com"
 DEFAULT_DEEPSEEK_MODEL = "deepseekchat"  # current Deepseek only has two type models
@@ -20,6 +24,9 @@ def _initialize_default_params(self, kwargs: dict) -> RequestParams:
             parallel_tool_calls=True,
             max_iterations=10,
             use_history=True,
+            response_format={
+                "type": "json_object"
+            }
         )
 
     def _base_url(self) -> str:
@@ -28,3 +35,46 @@ def _base_url(self) -> str:
             base_url = self.context.config.deepseek.base_url
 
         return base_url if base_url else DEEPSEEK_BASE_URL
+
+    async def _apply_prompt_provider_specific_structured(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:  # noqa: F821
+        request_params = self.get_request_params(request_params)
+
+        # Get the full schema and extract just the properties
+        full_schema = model.model_json_schema()
+        properties = full_schema.get("properties", {})
+        required_fields = full_schema.get("required", [])
+
+        # Create a cleaner format description
+        format_description = "{\n"
+        for field_name, field_info in properties.items():
+            field_type = field_info.get("type", "string")
+            description = field_info.get("description", "")
+            format_description += f'  "{field_name}": "{field_type}"'
+            if description:
+                format_description += f' // {description}'
+            if field_name in required_fields:
+                format_description += ' // REQUIRED'
+            format_description += '\n'
+        format_description += "}"
+
+        multipart_messages[-1].add_text(
+            f"""YOU MUST RESPOND WITH A JSON OBJECT IN EXACTLY THIS FORMAT:
+            {format_description}
+
+            IMPORTANT RULES:
+            - Respond ONLY with the JSON object, no other text
+            - Do NOT include "properties" or "schema" wrappers
+            - Do NOT use code fences or markdown
+            - The response must be valid JSON that matches the format above
+            - All required fields must be included"""
+        )
+
+        result: PromptMessageMultipart = await self._apply_prompt_provider_specific(
+            multipart_messages, request_params
+        )
+        return self._structured_from_multipart(result, model)
diff --git a/tests/e2e/structured/test_structured_outputs.py b/tests/e2e/structured/test_structured_outputs.py
index d768f1eec..30fd05f72 100644
--- a/tests/e2e/structured/test_structured_outputs.py
+++ b/tests/e2e/structured/test_structured_outputs.py
@@ -22,8 +22,9 @@ class FormattedResponse(BaseModel):
 @pytest.mark.parametrize(
     "model_name",
     [
-        "generic.qwen2.5:latest",
-        "generic.llama3.2:latest",
+        # "generic.qwen2.5:latest",
+        # "generic.llama3.2:latest",
+        "deepseek-chat",
         "haiku",
         "sonnet",
         "gpt-4.1",
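
Reviewer note: the snippet below replays the schema-flattening loop from the new
_apply_prompt_provider_specific_structured() in isolation, to show the exact
instruction block this patch appends to the final user message. It is a minimal
sketch: the fields and descriptions on FormattedResponse here are hypothetical
stand-ins, since the real test model's fields in test_structured_outputs.py are
not shown in this diff.

    from pydantic import BaseModel, Field


    class FormattedResponse(BaseModel):
        """Hypothetical stand-in for the e2e test model."""

        thinking: str = Field(description="reasoning before answering")
        message: str


    # Same flattening logic as the patch: walk the Pydantic JSON schema and
    # render each property as a pseudo-JSON line with inline // annotations.
    full_schema = FormattedResponse.model_json_schema()
    properties = full_schema.get("properties", {})
    required_fields = full_schema.get("required", [])

    format_description = "{\n"
    for field_name, field_info in properties.items():
        field_type = field_info.get("type", "string")
        description = field_info.get("description", "")
        format_description += f'  "{field_name}": "{field_type}"'
        if description:
            format_description += f' // {description}'
        if field_name in required_fields:
            format_description += ' // REQUIRED'
        format_description += '\n'
    format_description += "}"

    print(format_description)
    # {
    #   "thinking": "string" // reasoning before answering // REQUIRED
    #   "message": "string" // REQUIRED
    # }

Note that the rendered template is deliberately pseudo-JSON (the // comments are
not valid JSON); it is only shown to the model as formatting guidance. The
all-caps "JSON OBJECT" wording also keeps the word "json" in the prompt, which
DeepSeek's json_object response mode appears to require, complementing the
response_format default this patch adds to _initialize_default_params().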