Skip to content

Commit e21d514

Browse files
committed
make int tests better
1 parent 46b78f2 commit e21d514

File tree

5 files changed

+260
-188
lines changed

5 files changed

+260
-188
lines changed

src/mcp/client/sse.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,7 @@ async def sse_reader(
9898
await read_stream_writer.send(exc)
9999
continue
100100

101-
session_message = SessionMessage(
102-
message=message
103-
)
101+
session_message = SessionMessage(message=message)
104102
await read_stream_writer.send(session_message)
105103
case _:
106104
logger.warning(
@@ -150,3 +148,6 @@ async def post_writer(endpoint_url: str):
150148
finally:
151149
await read_stream_writer.aclose()
152150
await write_stream.aclose()
151+
await read_stream.aclose()
152+
await write_stream_reader.aclose()
153+

src/mcp/server/message_queue/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ async def publish_message(
2828
2929
Args:
3030
session_id: The UUID of the session this message is for
31-
message: The message to publish (JSONRPCMessage or str for invalid JSON)
31+
message: The message to publish (SessionMessage or str for invalid JSON)
3232
3333
Returns:
3434
bool: True if message was published, False if session not found

src/mcp/server/sse.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send):
152152
)(scope, receive, send)
153153
await read_stream_writer.aclose()
154154
await write_stream_reader.aclose()
155+
await sse_stream_writer.aclose()
156+
await sse_stream_reader.aclose()
155157
logging.debug(f"Client session disconnected {session_id}")
156158

157159
logger.debug("Starting SSE response task")

tests/server/message_dispatch/test_redis.py

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,33 +10,6 @@
1010
from mcp.shared.message import SessionMessage
1111

1212

13-
@pytest.mark.anyio
14-
async def test_session_exists(message_dispatch):
15-
"""Test session existence check."""
16-
session_id = uuid4()
17-
18-
# Initially session should not exist
19-
assert not await message_dispatch.session_exists(session_id)
20-
21-
# After subscribing, session should exist
22-
async with message_dispatch.subscribe(session_id, AsyncMock()):
23-
assert await message_dispatch.session_exists(session_id)
24-
25-
# After unsubscribing, session should not exist
26-
assert not await message_dispatch.session_exists(session_id)
27-
28-
29-
@pytest.mark.anyio
30-
async def test_session_ttl(message_dispatch):
31-
"""Test that session has proper TTL set."""
32-
session_id = uuid4()
33-
34-
async with message_dispatch.subscribe(session_id, AsyncMock()):
35-
session_key = message_dispatch._session_key(session_id)
36-
ttl = await message_dispatch._redis.ttl(session_key) # type: ignore
37-
assert ttl > 0
38-
assert ttl <= message_dispatch._session_ttl
39-
4013

4114
@pytest.mark.anyio
4215
async def test_session_heartbeat(message_dispatch):
@@ -129,12 +102,12 @@ async def test_publish_message_invalid_json(message_dispatch):
129102

130103

131104
@pytest.mark.anyio
132-
async def test_publish_to_nonexistent_session(message_dispatch):
105+
async def test_publish_to_nonexistent_session(message_dispatch: RedisMessageDispatch):
133106
"""Test publishing to a session that doesn't exist."""
134107
session_id = uuid4()
135-
message = types.JSONRPCMessage.model_validate(
108+
message = SessionMessage(message=types.JSONRPCMessage.model_validate(
136109
{"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
137-
)
110+
))
138111

139112
published = await message_dispatch.publish_message(session_id, message)
140113
assert not published

0 commit comments

Comments
 (0)