Skip to content

Commit 4d014cc

Browse files
committed
fix: yield placement issue
1 parent 3dede75 commit 4d014cc

File tree

1 file changed

+105
-107
lines changed

1 file changed

+105
-107
lines changed

src/mcp/client/sse.py

Lines changed: 105 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -49,119 +49,117 @@ async def sse_client(
4949
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
5050

5151
with catch({Exception: handle_exception}):
52-
async with anyio.create_task_group() as tg:
53-
try:
54-
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
55-
async with httpx.AsyncClient(headers=headers) as client:
56-
async with aconnect_sse(
57-
client,
58-
"GET",
59-
url,
60-
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
61-
) as event_source:
62-
event_source.response.raise_for_status()
63-
logger.debug("SSE connection established")
64-
65-
async def sse_reader(
66-
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
67-
):
68-
try:
69-
async for sse in event_source.aiter_sse():
70-
logger.debug(f"Received SSE event: {sse.event}")
71-
match sse.event:
72-
case "endpoint":
73-
endpoint_url = urljoin(url, sse.data)
74-
logger.info(
75-
f"Received endpoint URL: {endpoint_url}"
76-
)
77-
78-
url_parsed = urlparse(url)
79-
endpoint_parsed = urlparse(endpoint_url)
80-
if (
81-
url_parsed.netloc
82-
!= endpoint_parsed.netloc
83-
or url_parsed.scheme
84-
!= endpoint_parsed.scheme
85-
):
86-
error_msg = (
87-
"Endpoint origin does not match "
88-
f"connection origin: {endpoint_url}"
89-
)
90-
logger.error(error_msg)
91-
raise ValueError(error_msg)
92-
93-
task_status.started(endpoint_url)
94-
95-
case "message":
96-
try:
97-
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
98-
sse.data
99-
)
100-
logger.debug(
101-
"Received server message: "
102-
f"{message}"
103-
)
104-
except Exception as exc:
105-
logger.error(
106-
"Error parsing server message: "
107-
f"{exc}"
108-
)
109-
await read_stream_writer.send(exc)
110-
continue
111-
112-
session_message = SessionMessage(
113-
message=message
114-
)
115-
await read_stream_writer.send(
116-
session_message
117-
)
118-
case _:
119-
logger.warning(
120-
f"Unknown SSE event: {sse.event}"
121-
)
122-
except Exception as exc:
123-
logger.error(f"Error in sse_reader: {exc}")
124-
await read_stream_writer.send(exc)
125-
finally:
126-
await read_stream_writer.aclose()
127-
128-
async def post_writer(endpoint_url: str):
129-
try:
130-
async with write_stream_reader:
131-
async for session_message in write_stream_reader:
132-
logger.debug(
133-
f"Sending client message: {session_message}"
52+
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
53+
async with httpx.AsyncClient(headers=headers) as client:
54+
async with aconnect_sse(
55+
client,
56+
"GET",
57+
url,
58+
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
59+
) as event_source:
60+
event_source.response.raise_for_status()
61+
logger.debug("SSE connection established")
62+
63+
async def sse_reader(
64+
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
65+
):
66+
try:
67+
async for sse in event_source.aiter_sse():
68+
logger.debug(f"Received SSE event: {sse.event}")
69+
match sse.event:
70+
case "endpoint":
71+
endpoint_url = urljoin(url, sse.data)
72+
logger.info(
73+
f"Received endpoint URL: {endpoint_url}"
74+
)
75+
76+
url_parsed = urlparse(url)
77+
endpoint_parsed = urlparse(endpoint_url)
78+
if (
79+
url_parsed.netloc
80+
!= endpoint_parsed.netloc
81+
or url_parsed.scheme
82+
!= endpoint_parsed.scheme
83+
):
84+
error_msg = (
85+
"Endpoint origin does not match "
86+
f"connection origin: {endpoint_url}"
13487
)
135-
response = await client.post(
136-
endpoint_url,
137-
json=session_message.message.model_dump(
138-
by_alias=True,
139-
mode="json",
140-
exclude_none=True,
141-
),
88+
logger.error(error_msg)
89+
raise ValueError(error_msg)
90+
91+
task_status.started(endpoint_url)
92+
93+
case "message":
94+
try:
95+
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
96+
sse.data
14297
)
143-
response.raise_for_status()
14498
logger.debug(
145-
"Client message sent successfully: "
146-
f"{response.status_code}"
99+
"Received server message: "
100+
f"{message}"
147101
)
148-
except Exception as exc:
149-
logger.error(f"Error in post_writer: {exc}")
150-
finally:
151-
await write_stream.aclose()
152-
102+
except Exception as exc:
103+
logger.error(
104+
"Error parsing server message: "
105+
f"{exc}"
106+
)
107+
await read_stream_writer.send(exc)
108+
continue
109+
110+
session_message = SessionMessage(
111+
message=message
112+
)
113+
await read_stream_writer.send(
114+
session_message
115+
)
116+
case _:
117+
logger.warning(
118+
f"Unknown SSE event: {sse.event}"
119+
)
120+
except Exception as exc:
121+
logger.error(f"Error in sse_reader: {exc}")
122+
await read_stream_writer.send(exc)
123+
finally:
124+
await read_stream_writer.aclose()
125+
126+
async def post_writer(endpoint_url: str):
127+
try:
128+
async with write_stream_reader:
129+
async for session_message in write_stream_reader:
130+
logger.debug(
131+
f"Sending client message: {session_message}"
132+
)
133+
response = await client.post(
134+
endpoint_url,
135+
json=session_message.message.model_dump(
136+
by_alias=True,
137+
mode="json",
138+
exclude_none=True,
139+
),
140+
)
141+
response.raise_for_status()
142+
logger.debug(
143+
"Client message sent successfully: "
144+
f"{response.status_code}"
145+
)
146+
except Exception as exc:
147+
logger.error(f"Error in post_writer: {exc}")
148+
finally:
149+
await write_stream.aclose()
150+
151+
try:
152+
async with anyio.create_task_group() as tg:
153153
endpoint_url = await tg.start(sse_reader)
154154
logger.info(
155155
f"Starting post writer with endpoint URL: {endpoint_url}"
156156
)
157157
tg.start_soon(post_writer, endpoint_url)
158-
159-
try:
160-
yield read_stream, write_stream
161-
finally:
162-
tg.cancel_scope.cancel()
163-
finally:
164-
await read_stream_writer.aclose()
165-
await write_stream.aclose()
166-
await read_stream.aclose()
167-
await write_stream_reader.aclose()
158+
159+
# Move streams outside
160+
yield read_stream, write_stream
161+
finally:
162+
await read_stream_writer.aclose()
163+
await write_stream.aclose()
164+
await read_stream.aclose()
165+
await write_stream_reader.aclose()

0 commit comments

Comments
 (0)