|
4 | 4 |
|
5 | 5 | import inspect |
6 | 6 | import re |
7 | | -import types |
8 | 7 | from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence |
9 | 8 | from contextlib import ( |
10 | 9 | AbstractAsyncContextManager, |
11 | 10 | asynccontextmanager, |
12 | 11 | ) |
13 | 12 | from itertools import chain |
14 | | -from typing import Any, Generic, Literal, TypeVar, Union, get_args, get_origin |
| 13 | +from typing import Any, Generic, Literal |
15 | 14 |
|
16 | 15 | import anyio |
17 | 16 | import pydantic_core |
18 | | -from pydantic import BaseModel, Field, ValidationError |
19 | | -from pydantic.fields import FieldInfo |
| 17 | +from pydantic import BaseModel, Field |
20 | 18 | from pydantic.networks import AnyUrl |
21 | 19 | from pydantic_settings import BaseSettings, SettingsConfigDict |
22 | 20 | from starlette.applications import Starlette |
|
36 | 34 | from mcp.server.auth.settings import ( |
37 | 35 | AuthSettings, |
38 | 36 | ) |
| 37 | +from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation |
39 | 38 | from mcp.server.fastmcp.exceptions import ResourceError |
40 | 39 | from mcp.server.fastmcp.prompts import Prompt, PromptManager |
41 | 40 | from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager |
|
67 | 66 |
|
68 | 67 | logger = get_logger(__name__) |
69 | 68 |
|
70 | | -ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) |
71 | | - |
72 | | - |
73 | | -class ElicitationResult(BaseModel, Generic[ElicitSchemaModelT]): |
74 | | - """Result of an elicitation request.""" |
75 | | - |
76 | | - action: Literal["accept", "decline", "cancel"] |
77 | | - """The user's action in response to the elicitation.""" |
78 | | - |
79 | | - data: ElicitSchemaModelT | None = None |
80 | | - """The validated data if action is 'accept', None otherwise.""" |
81 | | - |
82 | | - validation_error: str | None = None |
83 | | - """Validation error message if data failed to validate.""" |
84 | | - |
85 | 69 |
|
86 | 70 | class Settings(BaseSettings, Generic[LifespanResultT]): |
87 | 71 | """FastMCP server settings. |
@@ -875,43 +859,6 @@ def _convert_to_content( |
875 | 859 | return [TextContent(type="text", text=result)] |
876 | 860 |
|
877 | 861 |
|
878 | | -# Primitive types allowed in elicitation schemas |
879 | | -_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) |
880 | | - |
881 | | - |
882 | | -def _validate_elicitation_schema(schema: type[BaseModel]) -> None: |
883 | | - """Validate that a Pydantic model only contains primitive field types.""" |
884 | | - for field_name, field_info in schema.model_fields.items(): |
885 | | - if not _is_primitive_field(field_info): |
886 | | - raise TypeError( |
887 | | - f"Elicitation schema field '{field_name}' must be a primitive type " |
888 | | - f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. " |
889 | | - f"Complex types like lists, dicts, or nested models are not allowed." |
890 | | - ) |
891 | | - |
892 | | - |
893 | | -def _is_primitive_field(field_info: FieldInfo) -> bool: |
894 | | - """Check if a field is a primitive type allowed in elicitation schemas.""" |
895 | | - annotation = field_info.annotation |
896 | | - |
897 | | - # Handle None type |
898 | | - if annotation is types.NoneType: |
899 | | - return True |
900 | | - |
901 | | - # Handle basic primitive types |
902 | | - if annotation in _ELICITATION_PRIMITIVE_TYPES: |
903 | | - return True |
904 | | - |
905 | | - # Handle Union types |
906 | | - origin = get_origin(annotation) |
907 | | - if origin is Union or origin is types.UnionType: |
908 | | - args = get_args(annotation) |
909 | | - # All args must be primitive types or None |
910 | | - return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) |
911 | | - |
912 | | - return False |
913 | | - |
914 | | - |
915 | 862 | class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]): |
916 | 863 | """Context object providing access to MCP capabilities. |
917 | 864 |
|
@@ -1035,27 +982,10 @@ async def elicit( |
1035 | 982 | The result.data will only be populated if action is "accept" and validation succeeded. |
1036 | 983 | """ |
1037 | 984 |
|
1038 | | - # Validate that schema only contains primitive types and fail loudly if not |
1039 | | - _validate_elicitation_schema(schema) |
1040 | | - |
1041 | | - json_schema = schema.model_json_schema() |
1042 | | - |
1043 | | - result = await self.request_context.session.elicit( |
1044 | | - message=message, |
1045 | | - requestedSchema=json_schema, |
1046 | | - related_request_id=self.request_id, |
| 985 | + return await elicit_with_validation( |
| 986 | + session=self.request_context.session, message=message, schema=schema, related_request_id=self.request_id |
1047 | 987 | ) |
1048 | 988 |
|
1049 | | - if result.action == "accept" and result.content: |
1050 | | - # Validate and parse the content using the schema |
1051 | | - try: |
1052 | | - validated_data = schema.model_validate(result.content) |
1053 | | - return ElicitationResult(action="accept", data=validated_data) |
1054 | | - except ValidationError as e: |
1055 | | - return ElicitationResult(action="accept", validation_error=str(e)) |
1056 | | - else: |
1057 | | - return ElicitationResult(action=result.action) |
1058 | | - |
1059 | 989 | async def log( |
1060 | 990 | self, |
1061 | 991 | level: Literal["debug", "info", "warning", "error"], |
|
0 commit comments