Skip to content

Commit cae32e2

Browse files
committed
Merge branch 'ihrpr/streamablehttp-server' into ihrpr/streamablehttp-stateless
2 parents 1902abb + 2e615f3 commit cae32e2

File tree

10 files changed

+61
-27
lines changed

10 files changed

+61
-27
lines changed

examples/servers/simple-streamablehttp/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ A simple MCP server example demonstrating the StreamableHttp transport, which en
55
## Features
66

77
- Uses the StreamableHTTP transport for server-client communication
8+
- Supports REST API operations (POST, GET, DELETE) for `/mcp` endpoint
89
- Task management with anyio task groups
910
- Ability to send multiple notifications over time to the client
1011
- Proper resource cleanup and lifespan management

examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ async def call_tool(
8282
level="info",
8383
data=f"Notification {i+1}/{count} from caller: {caller}",
8484
logger="notification_stream",
85+
# Associates this notification with the original request
86+
# Ensures notifications are sent to the correct response stream
87+
# Without this, notifications will either go to:
88+
# - a standalone SSE stream (if GET request is supported)
89+
# - nowhere (if GET request isn't supported)
8590
related_request_id=ctx.request_id,
8691
)
8792
if i < count - 1: # Don't wait after the last notification

src/mcp/server/fastmcp/server.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,7 @@ async def run_stdio_async(self) -> None:
466466
async def run_sse_async(self) -> None:
467467
"""Run the server using SSE transport."""
468468
import uvicorn
469+
469470
starlette_app = self.sse_app()
470471

471472
config = uvicorn.Config(
@@ -673,7 +674,10 @@ async def log(
673674
**extra: Additional structured data to include
674675
"""
675676
await self.request_context.session.send_log_message(
676-
level=level, data=message, logger=logger_name
677+
level=level,
678+
data=message,
679+
logger=logger_name,
680+
related_request_id=self.request_id,
677681
)
678682

679683
@property

src/mcp/server/streamableHttp.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -577,11 +577,6 @@ async def connect(
577577
"""
578578

579579
# Create the memory streams for this connection
580-
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
581-
read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception]
582-
583-
write_stream: MemoryObjectSendStream[JSONRPCMessage]
584-
write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage]
585580

586581
read_stream_writer, read_stream = anyio.create_memory_object_stream[
587582
JSONRPCMessage | Exception
@@ -605,10 +600,13 @@ async def message_router():
605600
if isinstance(
606601
message.root, JSONRPCNotification | JSONRPCRequest
607602
):
608-
# Extract related_request_id from params if it exists
609-
if (params := getattr(message.root, "params", None)) and (
610-
related_id := params.get("related_request_id")
611-
) is not None:
603+
# Extract related_request_id from meta if it exists
604+
if (
605+
(params := getattr(message.root, "params", None))
606+
and (meta := params.get("_meta"))
607+
and (related_id := meta.get("related_request_id"))
608+
is not None
609+
):
612610
target_request_id = str(related_id)
613611
else:
614612
target_request_id = str(message.root.id)

src/mcp/shared/session.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import Any, Generic, TypeVar
77

88
import anyio
9-
import anyio.lowlevel
109
import httpx
1110
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1211
from pydantic import BaseModel
@@ -24,6 +23,7 @@
2423
JSONRPCNotification,
2524
JSONRPCRequest,
2625
JSONRPCResponse,
26+
NotificationParams,
2727
RequestParams,
2828
ServerNotification,
2929
ServerRequest,
@@ -276,8 +276,19 @@ async def send_notification(
276276
Emits a notification, which is a one-way message that does not expect
277277
a response.
278278
"""
279+
# Some transport implementations may need to set the related_request_id
280+
# to attribute to the notifications to the request that triggered them.
279281
if related_request_id is not None and notification.root.params is not None:
280-
notification.root.params.related_request_id = related_request_id
282+
# Create meta if it doesn't exist
283+
if notification.root.params.meta is None:
284+
meta_dict = {"related_request_id": related_request_id}
285+
286+
else:
287+
meta_dict = notification.root.params.meta.model_dump(
288+
by_alias=True, mode="json", exclude_none=True
289+
)
290+
meta_dict["related_request_id"] = related_request_id
291+
notification.root.params.meta = NotificationParams.Meta(**meta_dict)
281292
jsonrpc_notification = JSONRPCNotification(
282293
jsonrpc="2.0",
283294
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),

src/mcp/types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ class Meta(BaseModel):
5858
model_config = ConfigDict(extra="allow")
5959

6060
meta: Meta | None = Field(alias="_meta", default=None)
61-
related_request_id: RequestId | None = None
6261
"""
6362
This parameter name is reserved by MCP to allow clients and servers to attach
6463
additional metadata to their notifications.

tests/client/test_logging_callback.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from mcp.shared.session import RequestResponder
1010
from mcp.types import (
1111
LoggingMessageNotificationParams,
12+
NotificationParams,
1213
TextContent,
1314
)
1415

@@ -78,6 +79,11 @@ async def message_handler(
7879
)
7980
assert log_result.isError is False
8081
assert len(logging_collector.log_messages) == 1
81-
assert logging_collector.log_messages[0] == LoggingMessageNotificationParams(
82-
level="info", logger="test_logger", data="Test log message"
83-
)
82+
# Create meta object with related_request_id added dynamically
83+
meta = NotificationParams.Meta()
84+
setattr(meta, "related_request_id", "2")
85+
log = logging_collector.log_messages[0]
86+
assert log.level == "info"
87+
assert log.logger == "test_logger"
88+
assert log.data == "Test log message"
89+
assert log.meta == meta

tests/issues/test_188_concurrency.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ async def slow_resource():
3535
end_time = anyio.current_time()
3636

3737
duration = end_time - start_time
38-
assert duration < 3 * _sleep_time_seconds
38+
assert duration < 6 * _sleep_time_seconds
3939
print(duration)
4040

4141

tests/server/fastmcp/servers/test_file_server.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,17 +114,13 @@ async def test_read_resource_file(mcp: FastMCP):
114114

115115
@pytest.mark.anyio
116116
async def test_delete_file(mcp: FastMCP, test_dir: Path):
117-
await mcp.call_tool(
118-
"delete_file", arguments={"path": str(test_dir / "example.py")}
119-
)
117+
await mcp.call_tool("delete_file", arguments={"path": str(test_dir / "example.py")})
120118
assert not (test_dir / "example.py").exists()
121119

122120

123121
@pytest.mark.anyio
124122
async def test_delete_file_and_check_resources(mcp: FastMCP, test_dir: Path):
125-
await mcp.call_tool(
126-
"delete_file", arguments={"path": str(test_dir / "example.py")}
127-
)
123+
await mcp.call_tool("delete_file", arguments={"path": str(test_dir / "example.py")})
128124
res_iter = await mcp.read_resource("file://test_dir/example.py")
129125
res_list = list(res_iter)
130126
assert len(res_list) == 1

tests/server/fastmcp/test_server.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -544,14 +544,28 @@ async def logging_tool(msg: str, ctx: Context) -> str:
544544

545545
assert mock_log.call_count == 4
546546
mock_log.assert_any_call(
547-
level="debug", data="Debug message", logger=None
547+
level="debug",
548+
data="Debug message",
549+
logger=None,
550+
related_request_id="1",
548551
)
549-
mock_log.assert_any_call(level="info", data="Info message", logger=None)
550552
mock_log.assert_any_call(
551-
level="warning", data="Warning message", logger=None
553+
level="info",
554+
data="Info message",
555+
logger=None,
556+
related_request_id="1",
552557
)
553558
mock_log.assert_any_call(
554-
level="error", data="Error message", logger=None
559+
level="warning",
560+
data="Warning message",
561+
logger=None,
562+
related_request_id="1",
563+
)
564+
mock_log.assert_any_call(
565+
level="error",
566+
data="Error message",
567+
logger=None,
568+
related_request_id="1",
555569
)
556570

557571
@pytest.mark.anyio

0 commit comments

Comments
 (0)