diff --git a/src/fastmcp/tools/function_parsing.py b/src/fastmcp/tools/function_parsing.py index a056c37c9..c5ac20fff 100644 --- a/src/fastmcp/tools/function_parsing.py +++ b/src/fastmcp/tools/function_parsing.py @@ -63,8 +63,16 @@ class _UnserializableType: pass -def _is_object_schema(schema: dict[str, Any]) -> bool: +def _is_object_schema( + schema: dict[str, Any], + *, + _root_schema: dict[str, Any] | None = None, + _seen_refs: set[str] | None = None, +) -> bool: """Check if a JSON schema represents an object type.""" + root_schema = _root_schema or schema + seen_refs = _seen_refs or set() + # Direct object type if schema.get("type") == "object": return True @@ -73,9 +81,34 @@ def _is_object_schema(schema: dict[str, Any]) -> bool: if "properties" in schema: return True - # Self-referencing types use $ref pointing to $defs - # The referenced type is always an object in our use case - return "$ref" in schema and "$defs" in schema + # Resolve local $ref definitions and recurse into the target schema. + ref = schema.get("$ref") + if not isinstance(ref, str) or not ref.startswith("#/"): + return False + + if ref in seen_refs: + return False + + # Walk the JSON Pointer path from the root schema, unescaping each + # token per RFC 6901 (~1 → /, ~0 → ~). + pointer = ref.removeprefix("#/") + segments = pointer.split("/") + target: Any = root_schema + for segment in segments: + unescaped = segment.replace("~1", "/").replace("~0", "~") + if not isinstance(target, dict) or unescaped not in target: + return False + target = target[unescaped] + + target_schema = target + if not isinstance(target_schema, dict): + return False + + return _is_object_schema( + target_schema, + _root_schema=root_schema, + _seen_refs=seen_refs | {ref}, + ) @dataclass diff --git a/tests/server/providers/local_provider_tools/test_output_schema.py b/tests/server/providers/local_provider_tools/test_output_schema.py index 6f1f38d44..f9136492e 100644 --- a/tests/server/providers/local_provider_tools/test_output_schema.py +++ b/tests/server/providers/local_provider_tools/test_output_schema.py @@ -1,16 +1,17 @@ """Tests for tool output schemas.""" from dataclasses import dataclass -from typing import Any +from typing import Any, Literal import pytest from mcp.types import ( TextContent, ) from pydantic import AnyUrl, BaseModel, TypeAdapter -from typing_extensions import TypedDict +from typing_extensions import TypeAliasType, TypedDict from fastmcp import FastMCP +from fastmcp.tools.function_parsing import _is_object_schema from fastmcp.tools.tool import ToolResult from fastmcp.utilities.json_schema import compress_schema @@ -282,3 +283,99 @@ def edge_case_tool() -> tuple[int, str]: result = await mcp.call_tool("edge_case_tool", {}) assert result.structured_content == {"result": [42, "hello"]} + + async def test_output_schema_wraps_non_object_ref_schema(self): + """Root $ref schemas should only skip wrapping when they resolve to objects.""" + mcp = FastMCP() + AliasType = TypeAliasType("AliasType", Literal["foo", "bar"]) + + @mcp.tool + def alias_tool() -> AliasType: + return "foo" + + tools = await mcp.list_tools() + tool = next(t for t in tools if t.name == "alias_tool") + + expected_inner_schema = compress_schema( + TypeAdapter(AliasType).json_schema(mode="serialization"), + prune_titles=True, + ) + assert tool.output_schema == { + "type": "object", + "properties": {"result": expected_inner_schema}, + "required": ["result"], + "x-fastmcp-wrap-result": True, + } + + result = await mcp.call_tool("alias_tool", {}) + assert result.structured_content == {"result": "foo"} + + +class TestIsObjectSchemaRefResolution: + """Tests for $ref resolution in _is_object_schema, including JSON Pointer + escaping and nested $defs paths.""" + + def test_simple_ref_to_object(self): + schema = { + "$ref": "#/$defs/MyModel", + "$defs": { + "MyModel": {"type": "object", "properties": {"x": {"type": "int"}}} + }, + } + assert _is_object_schema(schema) is True + + def test_simple_ref_to_non_object(self): + schema = { + "$ref": "#/$defs/MyEnum", + "$defs": {"MyEnum": {"enum": ["a", "b"]}}, + } + assert _is_object_schema(schema) is False + + def test_nested_defs_path(self): + """Refs like #/$defs/Outer/$defs/Inner should walk into nested dicts.""" + schema = { + "$ref": "#/$defs/Outer/$defs/Inner", + "$defs": { + "Outer": { + "$defs": { + "Inner": { + "type": "object", + "properties": {"y": {"type": "string"}}, + }, + }, + }, + }, + } + assert _is_object_schema(schema) is True + + def test_nested_defs_non_object(self): + schema = { + "$ref": "#/$defs/Outer/$defs/Inner", + "$defs": { + "Outer": { + "$defs": { + "Inner": {"type": "string"}, + }, + }, + }, + } + assert _is_object_schema(schema) is False + + def test_json_pointer_tilde_escape(self): + """~0 should unescape to ~ and ~1 should unescape to /.""" + schema = { + "$ref": "#/$defs/has~1slash~0tilde", + "$defs": {"has/slash~tilde": {"type": "object", "properties": {}}}, + } + assert _is_object_schema(schema) is True + + def test_missing_nested_segment_returns_false(self): + schema = { + "$ref": "#/$defs/Outer/$defs/Missing", + "$defs": { + "Outer": { + "$defs": {}, + }, + }, + } + assert _is_object_schema(schema) is False