Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 32 additions & 24 deletions src/conduit/transport/streamable_http/server/session_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Session management for streamable HTTP transport."""

import logging
import secrets
import uuid

logger = logging.getLogger(__name__)
Expand All @@ -10,13 +9,16 @@
class SessionManager:
"""Manages client sessions and their lifecycle.

Always assigns session IDs to simplify the protocol implementation.
Maintains bidirectional mapping between sessions and clients.
Maintains a simple mapping from session IDs to client IDs.
Session IDs are UUIDs that comply with the MCP streamable HTTP spec.
"""

def __init__(self) -> None:
self._sessions: dict[str, str] = {} # session_id -> client_id
self._client_sessions: dict[str, str] = {} # client_id -> session_id

# ================================
# Creation
# ================================

def create_session(self) -> tuple[str, str]:
"""Create a new session and client ID pair.
Expand All @@ -25,14 +27,17 @@ def create_session(self) -> tuple[str, str]:
Tuple of (client_id, session_id)
"""
client_id = str(uuid.uuid4())
session_id = self._generate_session_id()
session_id = str(uuid.uuid4())

self._sessions[session_id] = client_id
self._client_sessions[client_id] = session_id

logger.debug(f"Created session {session_id} for client {client_id}")
return client_id, session_id

# ================================
# Access
# ================================

def get_client_id(self, session_id: str) -> str | None:
"""Get client ID for a session.

Expand All @@ -45,32 +50,35 @@ def get_session_id(self, client_id: str) -> str | None:

Returns None if client doesn't have a session.
"""
return self._client_sessions.get(client_id)
for session_id, stored_client_id in self._sessions.items():
if stored_client_id == client_id:
return session_id
return None

# ================================
# Existence
# ================================

def session_exists(self, session_id: str) -> bool:
    """Return True if the given session ID is currently registered.

    Args:
        session_id: The session identifier to look up.

    Returns:
        True if a session with this ID exists, False otherwise.
    """
    return session_id in self._sessions

def terminate_session(self, session_id: str) -> str | None:
# ================================
# Termination
# ================================

def terminate_session(self, session_id: str) -> bool:
"""Terminate a session.

Returns the client_id that was terminated, or None if session didn't exist.
Returns True if session existed and was terminated, False otherwise.
"""
if session_id not in self._sessions:
return None

client_id = self._sessions[session_id]
del self._sessions[session_id]
del self._client_sessions[client_id]

logger.debug(f"Terminated session {session_id} for client {client_id}")
return client_id

def _generate_session_id(self) -> str:
"""Generate cryptographically secure session ID."""
return secrets.token_urlsafe(32)
client_id = self._sessions.pop(session_id, None)
if client_id is not None:
logger.debug(f"Terminated session {session_id} for client {client_id}")
return True
return False

def terminate_all_sessions(self) -> None:
"""Terminate all sessions."""
for session_id in list(self._sessions):
self.terminate_session(session_id)
self._sessions.clear()
logger.debug("Terminated all sessions")
179 changes: 73 additions & 106 deletions src/conduit/transport/streamable_http/server/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import uuid
from typing import Any, AsyncIterator

from conduit.shared.message_parser import MessageParser

logger = logging.getLogger(__name__)


Expand All @@ -17,28 +15,37 @@ def __init__(self, stream_id: str, client_id: str, request_id: str | int):
self.client_id = client_id
self.request_id = request_id
self._message_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
self._message_parser = MessageParser()

async def send_message(self, message: dict[str, Any]) -> None:
    """Queue a message for delivery on this stream.

    The message is placed on the internal asyncio queue; the stream's
    event generator consumes the queue and emits each entry as an SSE
    event to the connected client.

    Args:
        message: JSON-RPC message payload to deliver on this stream.
    """
    await self._message_queue.put(message)

async def send_response(self, response: dict[str, Any]) -> None:
"""Send final response and mark for auto-close."""
await self.send_message(response)

async def close(self) -> None:
    """Explicitly close the stream.

    Pushes a ``{"__close__": True}`` sentinel onto the message queue.
    The event generator recognizes this marker and stops yielding,
    which ends the SSE response for this stream.
    """
    # Sentinel dict (not a real JSON-RPC message) signals the generator to stop.
    await self._message_queue.put({"__close__": True})
    logger.debug(f"Manually closed stream {self.stream_id}")

def is_response(self, message: dict[str, Any]) -> bool:
"""Check if message is the final response."""
return self._message_parser.is_valid_response(message)
"""Check if message is a JSON-RPC response."""
id_value = message.get("id")
has_valid_id = (
id_value is not None
and isinstance(id_value, (int, str))
and not isinstance(id_value, bool)
)
has_result = "result" in message
has_error = "error" in message
return has_valid_id and (has_result ^ has_error)

async def event_generator(self) -> AsyncIterator[str]:
"""Generate SSE events for this stream."""
"""Generate SSE events for this stream.

Automatically closes after sending a response.

Yields:
str: SSE event data
"""
try:
while True:
message = await self._message_queue.get()
Expand Down Expand Up @@ -69,68 +76,57 @@ class StreamManager:
"""Manages multiple SSE streams with routing and cleanup."""

def __init__(self):
self._streams: dict[str, SSEStream] = {} # stream_id -> stream
self._client_streams: dict[str, set[str]] = {} # client_id -> set of stream_ids

async def create_request_stream(self, client_id: str, request_id: str) -> SSEStream:
"""Create and register a new stream for a specific request."""
stream_id = f"{client_id}:request:{request_id}"
return await self._create_and_register_stream(
stream_id, client_id, str(request_id)
)
self._client_streams: dict[
str, set[SSEStream]
] = {} # client_id -> set of streams

async def create_server_stream(self, client_id: str) -> SSEStream:
"""Create and register a new stream for server-initiated messages."""
async def create_stream(
self, client_id: str, request_id: str | None = None
) -> SSEStream:
"""Create and register a new stream."""
stream_id = str(uuid.uuid4())
stream = SSEStream(stream_id, client_id, request_id or "GET")

stream_uuid = str(uuid.uuid4())[:8]
stream_id = f"{client_id}:server:{stream_uuid}"
# Track by client
self._client_streams.setdefault(client_id, set()).add(stream)

return await self._create_and_register_stream(stream_id, client_id, None)
logger.debug(f"Created stream {stream_id} for client {client_id}")
return stream

async def send_to_existing_stream(
self,
client_id: str,
message: dict[str, Any],
originating_request_id: str | None = None,
originating_request_id: str | int | None = None,
) -> bool:
"""Send message to existing stream if available. Returns True if sent."""
if originating_request_id:
stream = self.get_request_stream(client_id, originating_request_id)
if stream:
return await self._send_to_stream(stream, message, auto_cleanup=True)
else:
return False
else:
server_streams = [
sid
for sid in self._client_streams.get(client_id, set())
if sid.startswith(f"{client_id}:server:")
]

if server_streams:
# For now, just use the first available server stream
# Don't auto-cleanup even though responses should not be sent on
# server streams
return await self._send_to_stream(
server_streams[0], message, auto_cleanup=False
)

return False

async def _create_and_register_stream(
self, stream_id: str, client_id: str, request_id: str | int | None
) -> SSEStream:
"""Create and register a stream."""
stream = SSEStream(stream_id, client_id, request_id or "GET")
"""Send message to existing stream if available.

# Register it
self._streams[stream_id] = stream
Args:
client_id: The client ID
message: The message to send
originating_request_id: The ID of the originating request. If provided,
the message will be sent to the stream with the matching request ID.
If not provided, the message will be sent to the first available
stream (e.g. a stream created by a GET request).

# Track by client
self._client_streams.setdefault(client_id, set()).add(stream_id)
Returns:
True if message was sent, False otherwise
"""
streams = self._client_streams.get(client_id, set())

logger.debug(f"Created stream {stream_id} for client {client_id}")
return stream
if originating_request_id:
for stream in streams:
if stream.request_id == originating_request_id:
return await self._send_to_stream(
stream, message, auto_cleanup=True
)
else:
# Use any available stream (first one)
if streams:
stream = next(iter(streams))
return await self._send_to_stream(stream, message, auto_cleanup=False)

return False

async def _send_to_stream(
self, stream: SSEStream, message: dict[str, Any], auto_cleanup: bool
Expand All @@ -139,70 +135,41 @@ async def _send_to_stream(
await stream.send_message(message)

if auto_cleanup and stream.is_response(message):
await self._cleanup_stream(stream.stream_id)
await self._cleanup_stream(stream)

return True

async def cleanup_client_streams(self, client_id: str) -> None:
"""Clean up all streams for a client."""
if client_id not in self._client_streams:
return
streams = self._client_streams.get(client_id, set()).copy()
for stream in streams:
await self._cleanup_stream(stream)

stream_ids = self._client_streams[client_id].copy()
for stream_id in stream_ids:
await self._cleanup_stream(stream_id)

logger.debug(f"Cleaned up {len(stream_ids)} streams for client {client_id}")
logger.debug(f"Cleaned up {len(streams)} streams for client {client_id}")

def get_stream_by_id(self, stream_id: str) -> SSEStream | None:
"""Get stream by exact stream ID."""
return self._streams.get(stream_id)

def get_request_stream(self, client_id: str, request_id: str) -> SSEStream | None:
"""Get specific request stream."""
stream_id = f"{client_id}:request:{request_id}"
return self._streams.get(stream_id)

def get_server_streams(self, client_id: str) -> list[SSEStream]:
"""Get all server streams for a client."""
server_streams = []
for stream_id in self._client_streams.get(client_id, set()):
if stream_id.startswith(f"{client_id}:server:"):
stream = self._streams.get(stream_id)
if stream:
server_streams.append(stream)
return server_streams

async def _cleanup_stream(self, stream_id: str) -> None:
"""Clean up a single stream."""
if stream_id not in self._streams:
return

stream = self._streams[stream_id]
for streams in self._client_streams.values():
for stream in streams:
if stream.stream_id == stream_id:
return stream
return None

async def _cleanup_stream(self, stream: SSEStream) -> None:
"""Clean up a single stream."""
# Close the stream (sends sentinel)
await stream.close()

# Remove from tracking
del self._streams[stream_id]

# Remove from client tracking
if stream.client_id in self._client_streams:
self._client_streams[stream.client_id].discard(stream_id)
self._client_streams[stream.client_id].discard(stream)
if not self._client_streams[stream.client_id]:
del self._client_streams[stream.client_id]

logger.debug(f"Cleaned up stream {stream_id}")

def get_active_stream_count(self) -> int:
"""Get number of active streams (for debugging/metrics)."""
return len(self._streams)

def get_client_stream_count(self, client_id: str) -> int:
"""Get number of active streams for a client."""
return len(self._client_streams.get(client_id, set()))
logger.debug(f"Cleaned up stream {stream.stream_id}")

async def close_all_streams(self) -> None:
"""Close all streams."""
for stream_id in list(self._streams):
await self._cleanup_stream(stream_id)
for streams in list(self._client_streams.values()):
for stream in streams.copy():
await self._cleanup_stream(stream)
Loading