Skip to content

Commit 1d6bf07

Browse files
committed
Merge branch 'main' into fix-sse-client-blocks-indefinitely-when-server-has-incorrect-base-url
2 parents 11c7ced + 3b1b213 commit 1d6bf07

File tree

26 files changed

+1344
-160
lines changed

26 files changed

+1344
-160
lines changed

README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,30 @@ app.router.routes.append(Host('mcp.acme.corp', app=mcp.sse_app()))
412412

413413
For more information on mounting applications in Starlette, see the [Starlette documentation](https://www.starlette.io/routing/#submounting-routes).
414414

415+
#### Message Dispatch Options
416+
417+
By default, the SSE server uses an in-memory message dispatch system for incoming POST messages. For production deployments or distributed scenarios, you can use Redis or implement your own message dispatch system that conforms to the `MessageDispatch` protocol:
418+
419+
```python
420+
# Using the built-in Redis message dispatch
421+
from mcp.server.fastmcp import FastMCP
422+
from mcp.server.message_queue import RedisMessageDispatch
423+
424+
# Create a Redis message dispatch
425+
redis_dispatch = RedisMessageDispatch(
426+
redis_url="redis://localhost:6379/0", prefix="mcp:pubsub:"
427+
)
428+
429+
# Pass the message dispatch instance to the server
430+
mcp = FastMCP("My App", message_queue=redis_dispatch)
431+
```
432+
433+
To use Redis, add the Redis dependency:
434+
435+
```bash
436+
uv add "mcp[redis]"
437+
```
438+
415439
## Examples
416440

417441
### Echo Server

examples/servers/simple-prompt/mcp_simple_prompt/server.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,15 @@ async def get_prompt(
8888
)
8989

9090
if transport == "sse":
91+
from mcp.server.message_queue.redis import RedisMessageDispatch
9192
from mcp.server.sse import SseServerTransport
9293
from starlette.applications import Starlette
9394
from starlette.responses import Response
9495
from starlette.routing import Mount, Route
9596

96-
sse = SseServerTransport("/messages/")
97+
message_dispatch = RedisMessageDispatch("redis://localhost:6379/0")
98+
99+
sse = SseServerTransport("/messages/", message_dispatch=message_dispatch)
97100

98101
async def handle_sse(request):
99102
async with sse.connect_sse(

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ dependencies = [
3838
rich = ["rich>=13.9.4"]
3939
cli = ["typer>=0.12.4", "python-dotenv>=1.0.0"]
4040
ws = ["websockets>=15.0.1"]
41+
redis = ["redis>=5.2.1", "types-redis>=4.6.0.20241004"]
4142

4243
[project.scripts]
4344
mcp = "mcp.cli:app [cli]"
@@ -56,6 +57,7 @@ dev = [
5657
"pytest-xdist>=3.6.1",
5758
"pytest-examples>=0.0.14",
5859
"pytest-pretty>=1.2.0",
60+
"fakeredis==2.28.1",
5961
]
6062
docs = [
6163
"mkdocs>=1.6.1",

src/mcp/client/sse.py

Lines changed: 102 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import httpx
88
from anyio.abc import TaskStatus
99
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10-
from exceptiongroup import BaseExceptionGroup, catch
1110
from httpx_sse import aconnect_sse
1211

1312
import mcp.types as types
@@ -20,11 +19,6 @@ def remove_request_params(url: str) -> str:
2019
return urljoin(url, urlparse(url).path)
2120

2221

23-
def handle_exception(exc: BaseExceptionGroup[Exception]) -> str:
24-
"""Handle ExceptionGroup and Exceptions for Client transport for SSE"""
25-
messages = "; ".join(str(e) for e in exc.exceptions)
26-
raise Exception(f"TaskGroup failed with: {messages}") from None
27-
2822
@asynccontextmanager
2923
async def sse_client(
3024
url: str,
@@ -47,117 +41,114 @@ async def sse_client(
4741
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
4842
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
4943

50-
with catch({Exception: handle_exception}):
51-
async with anyio.create_task_group() as tg:
52-
try:
53-
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
54-
async with httpx.AsyncClient(headers=headers) as client:
55-
async with aconnect_sse(
56-
client,
57-
"GET",
58-
url,
59-
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
60-
) as event_source:
61-
event_source.response.raise_for_status()
62-
logger.debug("SSE connection established")
63-
64-
async def sse_reader(
65-
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
66-
):
67-
try:
68-
async for sse in event_source.aiter_sse():
69-
logger.debug(f"Received SSE event: {sse.event}")
70-
match sse.event:
71-
case "endpoint":
72-
endpoint_url = urljoin(url, sse.data)
73-
logger.info(
74-
f"Received endpoint URL: {endpoint_url}"
44+
async with anyio.create_task_group() as tg:
45+
try:
46+
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
47+
async with httpx.AsyncClient(headers=headers) as client:
48+
async with aconnect_sse(
49+
client,
50+
"GET",
51+
url,
52+
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
53+
) as event_source:
54+
event_source.response.raise_for_status()
55+
logger.debug("SSE connection established")
56+
57+
async def sse_reader(
58+
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
59+
):
60+
try:
61+
async for sse in event_source.aiter_sse():
62+
logger.debug(f"Received SSE event: {sse.event}")
63+
match sse.event:
64+
case "endpoint":
65+
endpoint_url = urljoin(url, sse.data)
66+
logger.info(
67+
f"Received endpoint URL: {endpoint_url}"
68+
)
69+
70+
url_parsed = urlparse(url)
71+
endpoint_parsed = urlparse(endpoint_url)
72+
if (
73+
url_parsed.netloc != endpoint_parsed.netloc
74+
or url_parsed.scheme
75+
!= endpoint_parsed.scheme
76+
):
77+
error_msg = (
78+
"Endpoint origin does not match "
79+
f"connection origin: {endpoint_url}"
7580
)
81+
logger.error(error_msg)
82+
raise ValueError(error_msg)
83+
84+
task_status.started(endpoint_url)
7685

77-
url_parsed = urlparse(url)
78-
endpoint_parsed = urlparse(endpoint_url)
79-
if (
80-
url_parsed.netloc
81-
!= endpoint_parsed.netloc
82-
or url_parsed.scheme
83-
!= endpoint_parsed.scheme
84-
):
85-
error_msg = (
86-
"Endpoint origin does not match "
87-
f"connection origin: {endpoint_url}"
88-
)
89-
logger.error(error_msg)
90-
raise ValueError(error_msg)
91-
92-
task_status.started(endpoint_url)
93-
94-
case "message":
95-
try:
96-
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
97-
sse.data
98-
)
99-
logger.debug(
100-
"Received server message: "
101-
f"{message}"
102-
103-
)
104-
except Exception as exc:
105-
logger.error(
106-
"Error parsing server message: "
107-
f"{exc}"
108-
)
109-
await read_stream_writer.send(exc)
110-
continue
111-
112-
session_message = SessionMessage(message)
113-
await read_stream_writer.send(
114-
session_message
86+
case "message":
87+
try:
88+
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
89+
sse.data
11590
)
116-
case _:
117-
logger.warning(
118-
f"Unknown SSE event: {sse.event}"
91+
logger.debug(
92+
f"Received server message: {message}"
11993
)
120-
except Exception as exc:
121-
logger.error(f"Error in sse_reader: {exc}")
122-
await read_stream_writer.send(exc)
123-
finally:
124-
await read_stream_writer.aclose()
125-
126-
async def post_writer(endpoint_url: str):
127-
try:
128-
async with write_stream_reader:
129-
async for session_message in write_stream_reader:
130-
logger.debug(
131-
f"Sending client message: {session_message}"
132-
)
133-
response = await client.post(
134-
endpoint_url,
135-
json=session_message.message.model_dump(
136-
by_alias=True,
137-
mode="json",
138-
exclude_none=True,
139-
),
94+
except Exception as exc:
95+
logger.error(
96+
f"Error parsing server message: {exc}"
97+
)
98+
await read_stream_writer.send(exc)
99+
continue
100+
101+
session_message = SessionMessage(
102+
message=message
140103
)
141-
response.raise_for_status()
142-
logger.debug(
143-
"Client message sent successfully: "
144-
f"{response.status_code}"
104+
await read_stream_writer.send(session_message)
105+
case _:
106+
logger.warning(
107+
f"Unknown SSE event: {sse.event}"
145108
)
146-
except Exception as exc:
147-
logger.error(f"Error in post_writer: {exc}")
148-
finally:
149-
await write_stream.aclose()
150-
151-
endpoint_url = await tg.start(sse_reader)
152-
logger.info(
153-
f"Starting post writer with endpoint URL: {endpoint_url}"
154-
)
155-
tg.start_soon(post_writer, endpoint_url)
109+
except Exception as exc:
110+
logger.error(f"Error in sse_reader: {exc}")
111+
await read_stream_writer.send(exc)
112+
finally:
113+
await read_stream_writer.aclose()
156114

115+
async def post_writer(endpoint_url: str):
157116
try:
158-
yield read_stream, write_stream
117+
async with write_stream_reader:
118+
async for session_message in write_stream_reader:
119+
logger.debug(
120+
f"Sending client message: {session_message}"
121+
)
122+
response = await client.post(
123+
endpoint_url,
124+
json=session_message.message.model_dump(
125+
by_alias=True,
126+
mode="json",
127+
exclude_none=True,
128+
),
129+
)
130+
response.raise_for_status()
131+
logger.debug(
132+
"Client message sent successfully: "
133+
f"{response.status_code}"
134+
)
135+
except Exception as exc:
136+
logger.error(f"Error in post_writer: {exc}")
159137
finally:
160-
tg.cancel_scope.cancel()
161-
finally:
162-
await read_stream_writer.aclose()
163-
await write_stream.aclose()
138+
await write_stream.aclose()
139+
140+
endpoint_url = await tg.start(sse_reader)
141+
logger.info(
142+
f"Starting post writer with endpoint URL: {endpoint_url}"
143+
)
144+
tg.start_soon(post_writer, endpoint_url)
145+
146+
try:
147+
yield read_stream, write_stream
148+
finally:
149+
tg.cancel_scope.cancel()
150+
finally:
151+
await read_stream_writer.aclose()
152+
await write_stream.aclose()
153+
await read_stream.aclose()
154+
await write_stream_reader.aclose()

src/mcp/client/stdio/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ async def stdout_reader():
144144
await read_stream_writer.send(exc)
145145
continue
146146

147-
session_message = SessionMessage(message)
147+
session_message = SessionMessage(message=message)
148148
await read_stream_writer.send(session_message)
149149
except anyio.ClosedResourceError:
150150
await anyio.lowlevel.checkpoint()

src/mcp/client/streamable_http.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ async def _handle_sse_event(
153153
):
154154
message.root.id = original_request_id
155155

156-
session_message = SessionMessage(message)
156+
session_message = SessionMessage(message=message)
157157
await read_stream_writer.send(session_message)
158158

159159
# Call resumption token callback if we have an ID
@@ -286,7 +286,7 @@ async def _handle_json_response(
286286
try:
287287
content = await response.aread()
288288
message = JSONRPCMessage.model_validate_json(content)
289-
session_message = SessionMessage(message)
289+
session_message = SessionMessage(message=message)
290290
await read_stream_writer.send(session_message)
291291
except Exception as exc:
292292
logger.error(f"Error parsing JSON response: {exc}")
@@ -333,7 +333,7 @@ async def _send_session_terminated_error(
333333
id=request_id,
334334
error=ErrorData(code=32600, message="Session terminated"),
335335
)
336-
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
336+
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
337337
await read_stream_writer.send(session_message)
338338

339339
async def post_writer(

src/mcp/client/websocket.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ async def ws_reader():
6060
async for raw_text in ws:
6161
try:
6262
message = types.JSONRPCMessage.model_validate_json(raw_text)
63-
session_message = SessionMessage(message)
63+
session_message = SessionMessage(message=message)
6464
await read_stream_writer.send(session_message)
6565
except ValidationError as exc:
6666
# If JSON parse or model validation fails, send the exception

0 commit comments

Comments
 (0)