Skip to content

Commit 3dede75

Browse files
committed
updated sse with recent changes
1 parent 1d6bf07 commit 3dede75

File tree

1 file changed

+114
-101
lines changed

1 file changed

+114
-101
lines changed

src/mcp/client/sse.py

Lines changed: 114 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import httpx
88
from anyio.abc import TaskStatus
99
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10+
from exceptiongroup import BaseExceptionGroup, catch
1011
from httpx_sse import aconnect_sse
1112

1213
import mcp.types as types
@@ -19,6 +20,12 @@ def remove_request_params(url: str) -> str:
1920
return urljoin(url, urlparse(url).path)
2021

2122

23+
def handle_exception(exc: BaseExceptionGroup[Exception]) -> str:
24+
"""Handle ExceptionGroup and Exceptions for Client transport for SSE"""
25+
messages = "; ".join(str(e) for e in exc.exceptions)
26+
raise Exception(f"TaskGroup failed with: {messages}") from None
27+
28+
2229
@asynccontextmanager
2330
async def sse_client(
2431
url: str,
@@ -41,114 +48,120 @@ async def sse_client(
4148
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
4249
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
4350

44-
async with anyio.create_task_group() as tg:
45-
try:
46-
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
47-
async with httpx.AsyncClient(headers=headers) as client:
48-
async with aconnect_sse(
49-
client,
50-
"GET",
51-
url,
52-
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
53-
) as event_source:
54-
event_source.response.raise_for_status()
55-
logger.debug("SSE connection established")
56-
57-
async def sse_reader(
58-
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
59-
):
60-
try:
61-
async for sse in event_source.aiter_sse():
62-
logger.debug(f"Received SSE event: {sse.event}")
63-
match sse.event:
64-
case "endpoint":
65-
endpoint_url = urljoin(url, sse.data)
66-
logger.info(
67-
f"Received endpoint URL: {endpoint_url}"
68-
)
69-
70-
url_parsed = urlparse(url)
71-
endpoint_parsed = urlparse(endpoint_url)
72-
if (
73-
url_parsed.netloc != endpoint_parsed.netloc
74-
or url_parsed.scheme
75-
!= endpoint_parsed.scheme
76-
):
77-
error_msg = (
78-
"Endpoint origin does not match "
79-
f"connection origin: {endpoint_url}"
51+
with catch({Exception: handle_exception}):
52+
async with anyio.create_task_group() as tg:
53+
try:
54+
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
55+
async with httpx.AsyncClient(headers=headers) as client:
56+
async with aconnect_sse(
57+
client,
58+
"GET",
59+
url,
60+
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
61+
) as event_source:
62+
event_source.response.raise_for_status()
63+
logger.debug("SSE connection established")
64+
65+
async def sse_reader(
66+
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
67+
):
68+
try:
69+
async for sse in event_source.aiter_sse():
70+
logger.debug(f"Received SSE event: {sse.event}")
71+
match sse.event:
72+
case "endpoint":
73+
endpoint_url = urljoin(url, sse.data)
74+
logger.info(
75+
f"Received endpoint URL: {endpoint_url}"
8076
)
81-
logger.error(error_msg)
82-
raise ValueError(error_msg)
8377

84-
task_status.started(endpoint_url)
85-
86-
case "message":
87-
try:
88-
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
89-
sse.data
78+
url_parsed = urlparse(url)
79+
endpoint_parsed = urlparse(endpoint_url)
80+
if (
81+
url_parsed.netloc
82+
!= endpoint_parsed.netloc
83+
or url_parsed.scheme
84+
!= endpoint_parsed.scheme
85+
):
86+
error_msg = (
87+
"Endpoint origin does not match "
88+
f"connection origin: {endpoint_url}"
89+
)
90+
logger.error(error_msg)
91+
raise ValueError(error_msg)
92+
93+
task_status.started(endpoint_url)
94+
95+
case "message":
96+
try:
97+
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
98+
sse.data
99+
)
100+
logger.debug(
101+
"Received server message: "
102+
f"{message}"
103+
)
104+
except Exception as exc:
105+
logger.error(
106+
"Error parsing server message: "
107+
f"{exc}"
108+
)
109+
await read_stream_writer.send(exc)
110+
continue
111+
112+
session_message = SessionMessage(
113+
message=message
90114
)
91-
logger.debug(
92-
f"Received server message: {message}"
115+
await read_stream_writer.send(
116+
session_message
93117
)
94-
except Exception as exc:
95-
logger.error(
96-
f"Error parsing server message: {exc}"
118+
case _:
119+
logger.warning(
120+
f"Unknown SSE event: {sse.event}"
97121
)
98-
await read_stream_writer.send(exc)
99-
continue
100-
101-
session_message = SessionMessage(
102-
message=message
122+
except Exception as exc:
123+
logger.error(f"Error in sse_reader: {exc}")
124+
await read_stream_writer.send(exc)
125+
finally:
126+
await read_stream_writer.aclose()
127+
128+
async def post_writer(endpoint_url: str):
129+
try:
130+
async with write_stream_reader:
131+
async for session_message in write_stream_reader:
132+
logger.debug(
133+
f"Sending client message: {session_message}"
103134
)
104-
await read_stream_writer.send(session_message)
105-
case _:
106-
logger.warning(
107-
f"Unknown SSE event: {sse.event}"
135+
response = await client.post(
136+
endpoint_url,
137+
json=session_message.message.model_dump(
138+
by_alias=True,
139+
mode="json",
140+
exclude_none=True,
141+
),
108142
)
109-
except Exception as exc:
110-
logger.error(f"Error in sse_reader: {exc}")
111-
await read_stream_writer.send(exc)
112-
finally:
113-
await read_stream_writer.aclose()
143+
response.raise_for_status()
144+
logger.debug(
145+
"Client message sent successfully: "
146+
f"{response.status_code}"
147+
)
148+
except Exception as exc:
149+
logger.error(f"Error in post_writer: {exc}")
150+
finally:
151+
await write_stream.aclose()
152+
153+
endpoint_url = await tg.start(sse_reader)
154+
logger.info(
155+
f"Starting post writer with endpoint URL: {endpoint_url}"
156+
)
157+
tg.start_soon(post_writer, endpoint_url)
114158

115-
async def post_writer(endpoint_url: str):
116159
try:
117-
async with write_stream_reader:
118-
async for session_message in write_stream_reader:
119-
logger.debug(
120-
f"Sending client message: {session_message}"
121-
)
122-
response = await client.post(
123-
endpoint_url,
124-
json=session_message.message.model_dump(
125-
by_alias=True,
126-
mode="json",
127-
exclude_none=True,
128-
),
129-
)
130-
response.raise_for_status()
131-
logger.debug(
132-
"Client message sent successfully: "
133-
f"{response.status_code}"
134-
)
135-
except Exception as exc:
136-
logger.error(f"Error in post_writer: {exc}")
160+
yield read_stream, write_stream
137161
finally:
138-
await write_stream.aclose()
139-
140-
endpoint_url = await tg.start(sse_reader)
141-
logger.info(
142-
f"Starting post writer with endpoint URL: {endpoint_url}"
143-
)
144-
tg.start_soon(post_writer, endpoint_url)
145-
146-
try:
147-
yield read_stream, write_stream
148-
finally:
149-
tg.cancel_scope.cancel()
150-
finally:
151-
await read_stream_writer.aclose()
152-
await write_stream.aclose()
153-
await read_stream.aclose()
154-
await write_stream_reader.aclose()
162+
tg.cancel_scope.cancel()
163+
finally:
164+
await read_stream_writer.aclose()
165+
await write_stream.aclose()
166+
await read_stream.aclose()
167+
await write_stream_reader.aclose()

0 commit comments

Comments
 (0)