Skip to content

Commit 25600d4

Browse files
committed
streamable
1 parent c2ca8e0 commit 25600d4

File tree

2 files changed

+183
-1
lines changed

2 files changed

+183
-1
lines changed

src/mcp/client/session.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __init__(
9797
list_roots_callback: ListRootsFnT | None = None,
9898
logging_callback: LoggingFnT | None = None,
9999
message_handler: MessageHandlerFnT | None = None,
100+
supported_protocol_versions: tuple[str | int, ...] | None = None,
100101
) -> None:
101102
super().__init__(
102103
read_stream,
@@ -109,6 +110,9 @@ def __init__(
109110
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
110111
self._logging_callback = logging_callback or _default_logging_callback
111112
self._message_handler = message_handler or _default_message_handler
113+
self._supported_protocol_versions = (
114+
supported_protocol_versions or SUPPORTED_PROTOCOL_VERSIONS
115+
)
112116

113117
async def initialize(self) -> types.InitializeResult:
114118
sampling = types.SamplingCapability()
@@ -137,7 +141,7 @@ async def initialize(self) -> types.InitializeResult:
137141
types.InitializeResult,
138142
)
139143

140-
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
144+
if result.protocolVersion not in self._supported_protocol_versions:
141145
raise RuntimeError(
142146
"Unsupported protocol version from the server: "
143147
f"{result.protocolVersion}"

src/mcp/client/streamable.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import logging
2+
from contextlib import asynccontextmanager
3+
4+
import anyio
5+
import httpx
6+
from httpx_sse import EventSource
7+
from pydantic import TypeAdapter
8+
9+
import mcp.types as types
10+
from mcp.client.sse import sse_client
11+
12+
logger = logging.getLogger(__name__)
13+
14+
STREAMABLE_PROTOCOL_VERSION = "2025-03-26"
15+
SUPPORTED_PROTOCOL_VERSIONS: tuple[str, ...] = (
16+
types.LATEST_PROTOCOL_VERSION,
17+
STREAMABLE_PROTOCOL_VERSION,
18+
)
19+
20+
21+
@asynccontextmanager
22+
async def streamable_client(
23+
url: str,
24+
timeout: float = 5,
25+
):
26+
"""
27+
Client transport for streamable HTTP, with fallback to SSE.
28+
"""
29+
if await _is_old_sse_server(url, timeout):
30+
async with sse_client(url) as (read_stream, write_stream):
31+
yield read_stream, write_stream
32+
return
33+
34+
read_stream_writer, read_stream = anyio.create_memory_object_stream[
35+
types.JSONRPCMessage | Exception
36+
](0)
37+
write_stream, write_stream_reader = anyio.create_memory_object_stream[
38+
types.JSONRPCMessage
39+
](0)
40+
41+
async def handle_response(text: str) -> None:
42+
items = _maybe_list_adapter.validate_json(text)
43+
if isinstance(items, types.JSONRPCMessage):
44+
items = [items]
45+
for item in items:
46+
await read_stream_writer.send(item)
47+
48+
headers: tuple[tuple[str, str], ...] = ()
49+
50+
async with anyio.create_task_group() as tg:
51+
try:
52+
async with httpx.AsyncClient(timeout=timeout) as client:
53+
54+
async def sse_reader(event_source: EventSource):
55+
try:
56+
async for sse in event_source.aiter_sse():
57+
logger.debug(f"Received SSE event: {sse.event}")
58+
match sse.event:
59+
case "message":
60+
try:
61+
await handle_response(sse.data)
62+
logger.debug(
63+
f"Received server message: {sse.data}"
64+
)
65+
except Exception as exc:
66+
logger.error(
67+
f"Error parsing server message: {exc}"
68+
)
69+
await read_stream_writer.send(exc)
70+
continue
71+
case _:
72+
logger.warning(f"Unknown SSE event: {sse.event}")
73+
except Exception as exc:
74+
logger.error(f"Error in sse_reader: {exc}")
75+
await read_stream_writer.send(exc)
76+
finally:
77+
await read_stream_writer.aclose()
78+
79+
async def post_writer():
80+
nonlocal headers
81+
try:
82+
async with write_stream_reader:
83+
async for message in write_stream_reader:
84+
logger.debug(f"Sending client message: {message}")
85+
response = await client.post(
86+
url,
87+
json=message.model_dump(
88+
by_alias=True,
89+
mode="json",
90+
exclude_none=True,
91+
),
92+
headers=(
93+
("accept", "application/json"),
94+
("accept", "text/event-stream"),
95+
*headers,
96+
),
97+
)
98+
logger.debug(
99+
f"response {url=} content-type={response.headers.get("content-type")} body={response.text}"
100+
)
101+
102+
response.raise_for_status()
103+
match response.headers.get("mcp-session-id"):
104+
case str() as session_id:
105+
headers = (("mcp-session-id", session_id),)
106+
case _:
107+
pass
108+
109+
match response.headers.get("content-type"):
110+
case "text/event-stream":
111+
await sse_reader(EventSource(response))
112+
case "application/json":
113+
await handle_response(response.text)
114+
case None:
115+
pass
116+
case unknown:
117+
logger.warning(
118+
f"Unknown content type: {unknown}"
119+
)
120+
121+
logger.debug(
122+
"Client message sent successfully: "
123+
f"{response.status_code}"
124+
)
125+
except Exception as exc:
126+
logger.error(f"Error in post_writer: {exc}", exc_info=True)
127+
finally:
128+
await write_stream.aclose()
129+
130+
tg.start_soon(post_writer)
131+
132+
try:
133+
yield read_stream, write_stream
134+
finally:
135+
tg.cancel_scope.cancel()
136+
finally:
137+
await read_stream_writer.aclose()
138+
await write_stream.aclose()
139+
140+
141+
_maybe_list_adapter: TypeAdapter[types.JSONRPCMessage | list[types.JSONRPCMessage]] = (
142+
TypeAdapter(types.JSONRPCMessage | list[types.JSONRPCMessage])
143+
)
144+
145+
146+
async def _is_old_sse_server(url: str, timeout: float) -> bool:
147+
"""
148+
Test whether this is an old SSE MCP server.
149+
150+
See: https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/transports/#backwards-compatibility
151+
"""
152+
async with httpx.AsyncClient(timeout=timeout) as client:
153+
test_initialize_request = types.InitializeRequest(
154+
method="initialize",
155+
params=types.InitializeRequestParams(
156+
protocolVersion=STREAMABLE_PROTOCOL_VERSION,
157+
capabilities=types.ClientCapabilities(),
158+
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
159+
),
160+
)
161+
response = await client.post(
162+
url,
163+
json=types.JSONRPCRequest(
164+
jsonrpc="2.0",
165+
id=1,
166+
method=test_initialize_request.method,
167+
params=test_initialize_request.params.model_dump(
168+
by_alias=True, mode="json", exclude_none=True
169+
),
170+
).model_dump(by_alias=True, mode="json", exclude_none=True),
171+
headers=(
172+
("accept", "application/json"),
173+
("accept", "text/event-stream"),
174+
),
175+
)
176+
if 400 <= response.status_code < 500:
177+
return True
178+
return False

0 commit comments

Comments
 (0)