Skip to content

Commit fa57eda

Browse files
committed
Merge branch 'main' into ihrpr/elicitation-prototype
2 parents 16440a9 + 0bcecff commit fa57eda

File tree

11 files changed

+44
-81
lines changed

11 files changed

+44
-81
lines changed

examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,7 @@ def main(
4141
app = Server("mcp-streamable-http-stateless-demo")
4242

4343
@app.call_tool()
44-
async def call_tool(
45-
name: str, arguments: dict
46-
) -> list[
47-
types.TextContent
48-
| types.ImageContent
49-
| types.AudioContent
50-
| types.EmbeddedResource
51-
]:
44+
async def call_tool(name: str, arguments: dict) -> list[types.Content]:
5245
ctx = app.request_context
5346
interval = arguments.get("interval", 1.0)
5447
count = arguments.get("count", 5)

examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,7 @@ def main(
4545
app = Server("mcp-streamable-http-demo")
4646

4747
@app.call_tool()
48-
async def call_tool(
49-
name: str, arguments: dict
50-
) -> list[
51-
types.TextContent
52-
| types.ImageContent
53-
| types.AudioContent
54-
| types.EmbeddedResource
55-
]:
48+
async def call_tool(name: str, arguments: dict) -> list[types.Content]:
5649
ctx = app.request_context
5750
interval = arguments.get("interval", 1.0)
5851
count = arguments.get("count", 5)

examples/servers/simple-tool/mcp_simple_tool/server.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77

88
async def fetch_website(
99
url: str,
10-
) -> list[
11-
types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource
12-
]:
10+
) -> list[types.Content]:
1311
headers = {
1412
"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"
1513
}
@@ -31,14 +29,7 @@ def main(port: int, transport: str) -> int:
3129
app = Server("mcp-website-fetcher")
3230

3331
@app.call_tool()
34-
async def fetch_tool(
35-
name: str, arguments: dict
36-
) -> list[
37-
types.TextContent
38-
| types.ImageContent
39-
| types.AudioContent
40-
| types.EmbeddedResource
41-
]:
32+
async def fetch_tool(name: str, arguments: dict) -> list[types.Content]:
4233
if name != "fetch":
4334
raise ValueError(f"Unknown tool: {name}")
4435
if "url" not in arguments:

src/mcp/client/sse.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,13 @@ async def sse_client(
5454
async with anyio.create_task_group() as tg:
5555
try:
5656
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
57-
async with httpx_client_factory(headers=headers, auth=auth) as client:
57+
async with httpx_client_factory(
58+
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
59+
) as client:
5860
async with aconnect_sse(
5961
client,
6062
"GET",
6163
url,
62-
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
6364
) as event_source:
6465
event_source.response.raise_for_status()
6566
logger.debug("SSE connection established")

src/mcp/client/streamable_http.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from contextlib import asynccontextmanager
1212
from dataclasses import dataclass
1313
from datetime import timedelta
14-
from typing import Any
1514

1615
import anyio
1716
import httpx
@@ -52,14 +51,10 @@
5251
class StreamableHTTPError(Exception):
5352
"""Base exception for StreamableHTTP transport errors."""
5453

55-
pass
56-
5754

5855
class ResumptionError(StreamableHTTPError):
5956
"""Raised when resumption request is invalid."""
6057

61-
pass
62-
6358

6459
@dataclass
6560
class RequestContext:
@@ -71,7 +66,7 @@ class RequestContext:
7166
session_message: SessionMessage
7267
metadata: ClientMessageMetadata | None
7368
read_stream_writer: StreamWriter
74-
sse_read_timeout: timedelta
69+
sse_read_timeout: float
7570

7671

7772
class StreamableHTTPTransport:
@@ -80,9 +75,9 @@ class StreamableHTTPTransport:
8075
def __init__(
8176
self,
8277
url: str,
83-
headers: dict[str, Any] | None = None,
84-
timeout: timedelta = timedelta(seconds=30),
85-
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
78+
headers: dict[str, str] | None = None,
79+
timeout: float | timedelta = 30,
80+
sse_read_timeout: float | timedelta = 60 * 5,
8681
auth: httpx.Auth | None = None,
8782
) -> None:
8883
"""Initialize the StreamableHTTP transport.
@@ -96,10 +91,12 @@ def __init__(
9691
"""
9792
self.url = url
9893
self.headers = headers or {}
99-
self.timeout = timeout
100-
self.sse_read_timeout = sse_read_timeout
94+
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
95+
self.sse_read_timeout = (
96+
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
97+
)
10198
self.auth = auth
102-
self.session_id: str | None = None
99+
self.session_id = None
103100
self.request_headers = {
104101
ACCEPT: f"{JSON}, {SSE}",
105102
CONTENT_TYPE: JSON,
@@ -160,7 +157,7 @@ async def _handle_sse_event(
160157
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
161158

162159
except Exception as exc:
163-
logger.error(f"Error parsing SSE message: {exc}")
160+
logger.exception("Error parsing SSE message")
164161
await read_stream_writer.send(exc)
165162
return False
166163
else:
@@ -184,10 +181,7 @@ async def handle_get_stream(
184181
"GET",
185182
self.url,
186183
headers=headers,
187-
timeout=httpx.Timeout(
188-
self.timeout.total_seconds(),
189-
read=self.sse_read_timeout.total_seconds(),
190-
),
184+
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
191185
) as event_source:
192186
event_source.response.raise_for_status()
193187
logger.debug("GET SSE connection established")
@@ -216,10 +210,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
216210
"GET",
217211
self.url,
218212
headers=headers,
219-
timeout=httpx.Timeout(
220-
self.timeout.total_seconds(),
221-
read=ctx.sse_read_timeout.total_seconds(),
222-
),
213+
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
223214
) as event_source:
224215
event_source.response.raise_for_status()
225216
logger.debug("Resumption GET SSE connection established")
@@ -412,9 +403,9 @@ def get_session_id(self) -> str | None:
412403
@asynccontextmanager
413404
async def streamablehttp_client(
414405
url: str,
415-
headers: dict[str, Any] | None = None,
416-
timeout: timedelta = timedelta(seconds=30),
417-
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
406+
headers: dict[str, str] | None = None,
407+
timeout: float | timedelta = 30,
408+
sse_read_timeout: float | timedelta = 60 * 5,
418409
terminate_on_close: bool = True,
419410
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
420411
auth: httpx.Auth | None = None,
@@ -449,10 +440,7 @@ async def streamablehttp_client(
449440

450441
async with httpx_client_factory(
451442
headers=transport.request_headers,
452-
timeout=httpx.Timeout(
453-
transport.timeout.total_seconds(),
454-
read=transport.sse_read_timeout.total_seconds(),
455-
),
443+
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
456444
auth=transport.auth,
457445
) as client:
458446
# Define callbacks that need access to tg

src/mcp/server/fastmcp/prompts/base.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,16 @@
77
import pydantic_core
88
from pydantic import BaseModel, Field, TypeAdapter, validate_call
99

10-
from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
11-
12-
CONTENT_TYPES = TextContent | ImageContent | AudioContent | EmbeddedResource
10+
from mcp.types import Content, TextContent
1311

1412

1513
class Message(BaseModel):
1614
"""Base class for all prompt messages."""
1715

1816
role: Literal["user", "assistant"]
19-
content: CONTENT_TYPES
17+
content: Content
2018

21-
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
19+
def __init__(self, content: str | Content, **kwargs: Any):
2220
if isinstance(content, str):
2321
content = TextContent(type="text", text=content)
2422
super().__init__(content=content, **kwargs)
@@ -29,7 +27,7 @@ class UserMessage(Message):
2927

3028
role: Literal["user", "assistant"] = "user"
3129

32-
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
30+
def __init__(self, content: str | Content, **kwargs: Any):
3331
super().__init__(content=content, **kwargs)
3432

3533

@@ -38,7 +36,7 @@ class AssistantMessage(Message):
3836

3937
role: Literal["user", "assistant"] = "assistant"
4038

41-
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
39+
def __init__(self, content: str | Content, **kwargs: Any):
4240
super().__init__(content=content, **kwargs)
4341

4442

src/mcp/server/fastmcp/server.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,8 @@
5252
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
5353
from mcp.types import (
5454
AnyFunction,
55-
AudioContent,
56-
EmbeddedResource,
55+
Content,
5756
GetPromptResult,
58-
ImageContent,
5957
TextContent,
6058
ToolAnnotations,
6159
)
@@ -258,9 +256,7 @@ def get_context(self) -> Context[ServerSession, object, Request]:
258256
request_context = None
259257
return Context(request_context=request_context, fastmcp=self)
260258

261-
async def call_tool(
262-
self, name: str, arguments: dict[str, Any]
263-
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
259+
async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[Content]:
264260
"""Call a tool by name with arguments."""
265261
context = self.get_context()
266262
result = await self._tool_manager.call_tool(name, arguments, context=context)
@@ -844,12 +840,12 @@ async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -
844840

845841
def _convert_to_content(
846842
result: Any,
847-
) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
843+
) -> Sequence[Content]:
848844
"""Convert a result to a sequence of content objects."""
849845
if result is None:
850846
return []
851847

852-
if isinstance(result, TextContent | ImageContent | AudioContent | EmbeddedResource):
848+
if isinstance(result, Content):
853849
return [result]
854850

855851
if isinstance(result, Image):

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -384,9 +384,7 @@ def call_tool(self):
384384
def decorator(
385385
func: Callable[
386386
...,
387-
Awaitable[
388-
Iterable[types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource]
389-
],
387+
Awaitable[Iterable[types.Content]],
390388
],
391389
):
392390
logger.debug("Registering handler for CallToolRequest")

src/mcp/types.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -675,11 +675,14 @@ class EmbeddedResource(BaseModel):
675675
model_config = ConfigDict(extra="allow")
676676

677677

678+
Content = TextContent | ImageContent | AudioContent | EmbeddedResource
679+
680+
678681
class PromptMessage(BaseModel):
679682
"""Describes a message returned as part of a prompt."""
680683

681684
role: Role
682-
content: TextContent | ImageContent | AudioContent | EmbeddedResource
685+
content: Content
683686
model_config = ConfigDict(extra="allow")
684687

685688

@@ -795,7 +798,7 @@ class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
795798
class CallToolResult(Result):
796799
"""The server's response to a tool call."""
797800

798-
content: list[TextContent | ImageContent | AudioContent | EmbeddedResource]
801+
content: list[Content]
799802
isError: bool = False
800803

801804

tests/issues/test_88_random_error.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from mcp.client.session import ClientSession
1212
from mcp.server.lowlevel import Server
1313
from mcp.shared.exceptions import McpError
14-
from mcp.types import AudioContent, EmbeddedResource, ImageContent, TextContent
14+
from mcp.types import Content, TextContent
1515

1616

1717
@pytest.mark.anyio
@@ -31,7 +31,7 @@ async def test_notification_validation_error(tmp_path: Path):
3131
slow_request_complete = anyio.Event()
3232

3333
@server.call_tool()
34-
async def slow_tool(name: str, arg) -> Sequence[TextContent | ImageContent | AudioContent | EmbeddedResource]:
34+
async def slow_tool(name: str, arg) -> Sequence[Content]:
3535
nonlocal request_count
3636
request_count += 1
3737

0 commit comments

Comments
 (0)