Skip to content

Commit 9664c8a

Browse files
committed
fmt
1 parent 2437e46 commit 9664c8a

File tree

6 files changed

+91
-76
lines changed

6 files changed

+91
-76
lines changed

src/mcp/server/message_queue/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,4 +113,4 @@ async def session_exists(self, session_id: UUID) -> bool:
113113

114114
async def close(self) -> None:
115115
"""Close the message dispatch."""
116-
pass
116+
pass

src/mcp/server/message_queue/redis.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ class RedisMessageDispatch:
3030
"""
3131

3232
def __init__(
33-
self, redis_url: str = "redis://localhost:6379/0", prefix: str = "mcp:pubsub:",
34-
session_ttl: int = 3600 # 1 hour default TTL for sessions
33+
self,
34+
redis_url: str = "redis://localhost:6379/0",
35+
prefix: str = "mcp:pubsub:",
36+
session_ttl: int = 3600, # 1 hour default TTL for sessions
3537
) -> None:
3638
"""Initialize Redis message dispatch.
3739
@@ -51,8 +53,8 @@ def __init__(
5153
logger.debug(f"Redis message dispatch initialized: {redis_url}")
5254

5355
async def close(self):
54-
await self._pubsub.aclose() # type: ignore
55-
await self._redis.aclose() # type: ignore
56+
await self._pubsub.aclose() # type: ignore
57+
await self._redis.aclose() # type: ignore
5658

5759
def _session_channel(self, session_id: UUID) -> str:
5860
"""Get the Redis channel for a session."""
@@ -67,7 +69,7 @@ async def subscribe(self, session_id: UUID, callback: MessageCallback):
6769
"""Request-scoped context manager that subscribes to messages for a session."""
6870
session_key = self._session_key(session_id)
6971
await self._redis.setex(session_key, self._session_ttl, "1") # type: ignore
70-
72+
7173
channel = self._session_channel(session_id)
7274
await self._pubsub.subscribe(channel) # type: ignore
7375

@@ -193,4 +195,4 @@ async def publish_message(
193195
async def session_exists(self, session_id: UUID) -> bool:
194196
"""Check if a session exists."""
195197
session_key = self._session_key(session_id)
196-
return bool(await self._redis.exists(session_key)) # type: ignore
198+
return bool(await self._redis.exists(session_key)) # type: ignore
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
# Message queue tests module
1+
# Message queue tests module

tests/server/message_queue/test_redis.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ async def redis_dispatch():
2222
# Mock the redis module entirely within RedisMessageDispatch
2323
with patch("mcp.server.message_queue.redis.redis", fake_redis.FakeRedis):
2424
from mcp.server.message_queue.redis import RedisMessageDispatch
25-
25+
2626
dispatch = RedisMessageDispatch(session_ttl=5) # Shorter TTL for testing
2727
try:
2828
yield dispatch
@@ -50,7 +50,7 @@ async def test_session_exists(redis_dispatch):
5050
async def test_session_ttl(redis_dispatch):
5151
"""Test that session has proper TTL set."""
5252
session_id = uuid4()
53-
53+
5454
async with redis_dispatch.subscribe(session_id, AsyncMock()):
5555
session_key = redis_dispatch._session_key(session_id)
5656
ttl = await redis_dispatch._redis.ttl(session_key) # type: ignore
@@ -62,17 +62,17 @@ async def test_session_ttl(redis_dispatch):
6262
async def test_session_heartbeat(redis_dispatch):
6363
"""Test that session heartbeat refreshes TTL."""
6464
session_id = uuid4()
65-
65+
6666
async with redis_dispatch.subscribe(session_id, AsyncMock()):
6767
session_key = redis_dispatch._session_key(session_id)
68-
68+
6969
# Initial TTL
7070
initial_ttl = await redis_dispatch._redis.ttl(session_key) # type: ignore
7171
assert initial_ttl > 0
72-
72+
7373
# Wait for heartbeat to run
7474
await anyio.sleep(redis_dispatch._session_ttl / 2 + 0.5)
75-
75+
7676
# TTL should be refreshed
7777
refreshed_ttl = await redis_dispatch._redis.ttl(session_key) # type: ignore
7878
assert refreshed_ttl > 0
@@ -237,12 +237,12 @@ async def test_session_cancellation_isolation(redis_dispatch):
237237
"""Test that cancelling one session doesn't affect other sessions."""
238238
session1 = uuid4()
239239
session2 = uuid4()
240-
240+
241241
# Create a blocking callback for session1 to ensure it's running when cancelled
242242
session1_event = anyio.Event()
243243
session1_started = anyio.Event()
244244
session1_cancelled = False
245-
245+
246246
async def blocking_callback1(msg):
247247
session1_started.set()
248248
try:
@@ -258,45 +258,46 @@ async def blocking_callback1(msg):
258258
async with redis_dispatch.subscribe(session2, callback2):
259259
# Start session1 with a blocking callback
260260
async with anyio.create_task_group() as tg:
261+
261262
async def session1_runner():
262263
async with redis_dispatch.subscribe(session1, blocking_callback1):
263264
# Publish a message to trigger the blocking callback
264265
message = types.JSONRPCMessage.model_validate(
265266
{"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1}
266267
)
267268
await redis_dispatch.publish_message(session1, message)
268-
269+
269270
# Wait for the callback to start
270271
await session1_started.wait()
271-
272+
272273
# Keep the context alive while we test cancellation
273274
await anyio.sleep_forever()
274-
275+
275276
tg.start_soon(session1_runner)
276-
277+
277278
# Wait for session1's callback to start
278279
await session1_started.wait()
279-
280+
280281
# Cancel session1
281282
tg.cancel_scope.cancel()
282-
283+
283284
# Give some time for cancellation to propagate
284285
await anyio.sleep(0.1)
285-
286+
286287
# Verify session1 was cancelled
287288
assert session1_cancelled
288289
assert session1 not in redis_dispatch._session_state
289-
290+
290291
# Verify session2 is still active and can receive messages
291292
assert await redis_dispatch.session_exists(session2)
292293
message2 = types.JSONRPCMessage.model_validate(
293294
{"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2}
294295
)
295296
await redis_dispatch.publish_message(session2, message2)
296-
297+
297298
# Give some time for the message to be processed
298299
await anyio.sleep(0.1)
299-
300+
300301
# Verify session2 received the message
301302
callback2.assert_called_once()
302303
call_args = callback2.call_args[0][0]
@@ -308,80 +309,80 @@ async def test_listener_task_handoff_on_cancellation(redis_dispatch):
308309
"""Test that the single listening task is properly handed off when a session is cancelled."""
309310
session1 = uuid4()
310311
session2 = uuid4()
311-
312+
312313
session1_messages_received = 0
313314
session2_messages_received = 0
314-
315+
315316
async def callback1(msg):
316317
nonlocal session1_messages_received
317318
session1_messages_received += 1
318-
319+
319320
async def callback2(msg):
320321
nonlocal session2_messages_received
321322
session2_messages_received += 1
322-
323+
323324
# Create a cancel scope for session1
324325
async with anyio.create_task_group() as tg:
325326
session1_cancel_scope: anyio.CancelScope | None = None
326-
327+
327328
async def session1_runner():
328329
nonlocal session1_cancel_scope
329330
with anyio.CancelScope() as cancel_scope:
330331
session1_cancel_scope = cancel_scope
331332
async with redis_dispatch.subscribe(session1, callback1):
332333
# Keep session alive until cancelled
333334
await anyio.sleep_forever()
334-
335+
335336
# Start session1
336337
tg.start_soon(session1_runner)
337-
338+
338339
# Wait for session1 to be established
339340
await anyio.sleep(0.1)
340341
assert session1 in redis_dispatch._session_state
341-
342+
342343
# Send message to session1 to verify it's working
343344
message1 = types.JSONRPCMessage.model_validate(
344345
{"jsonrpc": "2.0", "method": "test1", "params": {}, "id": 1}
345346
)
346347
await redis_dispatch.publish_message(session1, message1)
347348
await anyio.sleep(0.1)
348349
assert session1_messages_received == 1
349-
350+
350351
# Start session2 while session1 is still active
351352
async with redis_dispatch.subscribe(session2, callback2):
352353
# Both sessions should be active
353354
assert session1 in redis_dispatch._session_state
354355
assert session2 in redis_dispatch._session_state
355-
356+
356357
# Cancel session1
357358
assert session1_cancel_scope is not None
358359
session1_cancel_scope.cancel()
359-
360+
360361
# Wait for cancellation to complete
361362
await anyio.sleep(0.1)
362-
363+
363364
# Session1 should be gone, session2 should remain
364365
assert session1 not in redis_dispatch._session_state
365366
assert session2 in redis_dispatch._session_state
366-
367+
367368
# Send message to session2 to verify the listener was handed off
368369
message2 = types.JSONRPCMessage.model_validate(
369370
{"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2}
370371
)
371372
await redis_dispatch.publish_message(session2, message2)
372373
await anyio.sleep(0.1)
373-
374+
374375
# Session2 should have received the message
375376
assert session2_messages_received == 1
376-
377+
377378
# Session1 shouldn't receive any more messages
378379
assert session1_messages_received == 1
379-
380+
380381
# Send another message to verify the listener is still working
381382
message3 = types.JSONRPCMessage.model_validate(
382383
{"jsonrpc": "2.0", "method": "test3", "params": {}, "id": 3}
383384
)
384385
await redis_dispatch.publish_message(session2, message3)
385386
await anyio.sleep(0.1)
386-
387-
assert session2_messages_received == 2
387+
388+
assert session2_messages_received == 2

0 commit comments

Comments
 (0)