diff --git a/src/conduit/transport/streamable_http/server/session_manager.py b/src/conduit/transport/streamable_http/server/session_manager.py index b50f6dd..4187d54 100644 --- a/src/conduit/transport/streamable_http/server/session_manager.py +++ b/src/conduit/transport/streamable_http/server/session_manager.py @@ -1,7 +1,6 @@ """Session management for streamable HTTP transport.""" import logging -import secrets import uuid logger = logging.getLogger(__name__) @@ -10,13 +9,16 @@ class SessionManager: """Manages client sessions and their lifecycle. - Always assigns session IDs to simplify the protocol implementation. - Maintains bidirectional mapping between sessions and clients. + Maintains a simple mapping from session IDs to client IDs. + Session IDs are UUIDs that comply with the MCP streamable HTTP spec. """ def __init__(self) -> None: self._sessions: dict[str, str] = {} # session_id -> client_id - self._client_sessions: dict[str, str] = {} # client_id -> session_id + + # ================================ + # Creation + # ================================ def create_session(self) -> tuple[str, str]: """Create a new session and client ID pair. @@ -25,14 +27,17 @@ def create_session(self) -> tuple[str, str]: Tuple of (client_id, session_id) """ client_id = str(uuid.uuid4()) - session_id = self._generate_session_id() + session_id = str(uuid.uuid4()) self._sessions[session_id] = client_id - self._client_sessions[client_id] = session_id logger.debug(f"Created session {session_id} for client {client_id}") return client_id, session_id + # ================================ + # Access + # ================================ + def get_client_id(self, session_id: str) -> str | None: """Get client ID for a session. @@ -45,32 +50,35 @@ def get_session_id(self, client_id: str) -> str | None: Returns None if client doesn't have a session. """ - return self._client_sessions.get(client_id) + for session_id, stored_client_id in self._sessions.items(): + if stored_client_id == client_id: + return session_id + return None + + # ================================ + # Existence + # ================================ def session_exists(self, session_id: str) -> bool: """Check if session exists.""" return session_id in self._sessions - def terminate_session(self, session_id: str) -> str | None: + # ================================ + # Termination + # ================================ + + def terminate_session(self, session_id: str) -> bool: """Terminate a session. - Returns the client_id that was terminated, or None if session didn't exist. + Returns True if session existed and was terminated, False otherwise. """ - if session_id not in self._sessions: - return None - - client_id = self._sessions[session_id] - del self._sessions[session_id] - del self._client_sessions[client_id] - - logger.debug(f"Terminated session {session_id} for client {client_id}") - return client_id - - def _generate_session_id(self) -> str: - """Generate cryptographically secure session ID.""" - return secrets.token_urlsafe(32) + client_id = self._sessions.pop(session_id, None) + if client_id is not None: + logger.debug(f"Terminated session {session_id} for client {client_id}") + return True + return False def terminate_all_sessions(self) -> None: """Terminate all sessions.""" - for session_id in list(self._sessions): - self.terminate_session(session_id) + self._sessions.clear() + logger.debug("Terminated all sessions") diff --git a/src/conduit/transport/streamable_http/server/streams.py b/src/conduit/transport/streamable_http/server/streams.py index 0ec348f..a756661 100644 --- a/src/conduit/transport/streamable_http/server/streams.py +++ b/src/conduit/transport/streamable_http/server/streams.py @@ -4,8 +4,6 @@ import uuid from typing import Any, AsyncIterator -from conduit.shared.message_parser import MessageParser - logger = logging.getLogger(__name__) @@ -17,16 +15,11 @@ def __init__(self, stream_id: str, client_id: str, request_id: str | int): self.client_id = client_id self.request_id = request_id self._message_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() - self._message_parser = MessageParser() async def send_message(self, message: dict[str, Any]) -> None: """Send message on this stream.""" await self._message_queue.put(message) - async def send_response(self, response: dict[str, Any]) -> None: - """Send final response and mark for auto-close.""" - await self.send_message(response) - async def close(self) -> None: """Explicitly close the stream.""" # Send sentinel to stop the generator @@ -34,11 +27,25 @@ async def close(self) -> None: logger.debug(f"Manually closed stream {self.stream_id}") def is_response(self, message: dict[str, Any]) -> bool: - """Check if message is the final response.""" - return self._message_parser.is_valid_response(message) + """Check if message is a JSON-RPC response.""" + id_value = message.get("id") + has_valid_id = ( + id_value is not None + and isinstance(id_value, (int, str)) + and not isinstance(id_value, bool) + ) + has_result = "result" in message + has_error = "error" in message + return has_valid_id and (has_result ^ has_error) async def event_generator(self) -> AsyncIterator[str]: - """Generate SSE events for this stream.""" + """Generate SSE events for this stream. + + Automatically closes after sending a response. + + Yields: + str: SSE event data + """ try: while True: message = await self._message_queue.get() @@ -69,68 +76,57 @@ class StreamManager: """Manages multiple SSE streams with routing and cleanup.""" def __init__(self): - self._streams: dict[str, SSEStream] = {} # stream_id -> stream - self._client_streams: dict[str, set[str]] = {} # client_id -> set of stream_ids - - async def create_request_stream(self, client_id: str, request_id: str) -> SSEStream: - """Create and register a new stream for a specific request.""" - stream_id = f"{client_id}:request:{request_id}" - return await self._create_and_register_stream( - stream_id, client_id, str(request_id) - ) + self._client_streams: dict[ + str, set[SSEStream] + ] = {} # client_id -> set of streams - async def create_server_stream(self, client_id: str) -> SSEStream: - """Create and register a new stream for server-initiated messages.""" + async def create_stream( + self, client_id: str, request_id: str | None = None + ) -> SSEStream: + """Create and register a new stream.""" + stream_id = str(uuid.uuid4()) + stream = SSEStream(stream_id, client_id, request_id or "GET") - stream_uuid = str(uuid.uuid4())[:8] - stream_id = f"{client_id}:server:{stream_uuid}" + # Track by client + self._client_streams.setdefault(client_id, set()).add(stream) - return await self._create_and_register_stream(stream_id, client_id, None) + logger.debug(f"Created stream {stream_id} for client {client_id}") + return stream async def send_to_existing_stream( self, client_id: str, message: dict[str, Any], - originating_request_id: str | None = None, + originating_request_id: str | int | None = None, ) -> bool: - """Send message to existing stream if available. Returns True if sent.""" - if originating_request_id: - stream = self.get_request_stream(client_id, originating_request_id) - if stream: - return await self._send_to_stream(stream, message, auto_cleanup=True) - else: - return False - else: - server_streams = [ - sid - for sid in self._client_streams.get(client_id, set()) - if sid.startswith(f"{client_id}:server:") - ] - - if server_streams: - # For now, just use the first available server stream - # Don't auto-cleanup even though responses should not be sent on - # server streams - return await self._send_to_stream( - server_streams[0], message, auto_cleanup=False - ) - - return False - - async def _create_and_register_stream( - self, stream_id: str, client_id: str, request_id: str | int | None - ) -> SSEStream: - """Create and register a stream.""" - stream = SSEStream(stream_id, client_id, request_id or "GET") + """Send message to existing stream if available. - # Register it - self._streams[stream_id] = stream + Args: + client_id: The client ID + message: The message to send + originating_request_id: The ID of the originating request. If provided, + the message will be sent to the stream with the matching request ID. + If not provided, the message will be sent to the first available + stream (e.g. a stream created by a GET request). - # Track by client - self._client_streams.setdefault(client_id, set()).add(stream_id) + Returns: + True if message was sent, False otherwise + """ + streams = self._client_streams.get(client_id, set()) - logger.debug(f"Created stream {stream_id} for client {client_id}") - return stream + if originating_request_id: + for stream in streams: + if stream.request_id == originating_request_id: + return await self._send_to_stream( + stream, message, auto_cleanup=True + ) + else: + # Use any available stream (first one) + if streams: + stream = next(iter(streams)) + return await self._send_to_stream(stream, message, auto_cleanup=False) + + return False async def _send_to_stream( self, stream: SSEStream, message: dict[str, Any], auto_cleanup: bool @@ -139,70 +135,41 @@ async def _send_to_stream( await stream.send_message(message) if auto_cleanup and stream.is_response(message): - await self._cleanup_stream(stream.stream_id) + await self._cleanup_stream(stream) return True async def cleanup_client_streams(self, client_id: str) -> None: """Clean up all streams for a client.""" - if client_id not in self._client_streams: - return + streams = self._client_streams.get(client_id, set()).copy() + for stream in streams: + await self._cleanup_stream(stream) - stream_ids = self._client_streams[client_id].copy() - for stream_id in stream_ids: - await self._cleanup_stream(stream_id) - - logger.debug(f"Cleaned up {len(stream_ids)} streams for client {client_id}") + logger.debug(f"Cleaned up {len(streams)} streams for client {client_id}") def get_stream_by_id(self, stream_id: str) -> SSEStream | None: """Get stream by exact stream ID.""" - return self._streams.get(stream_id) - - def get_request_stream(self, client_id: str, request_id: str) -> SSEStream | None: - """Get specific request stream.""" - stream_id = f"{client_id}:request:{request_id}" - return self._streams.get(stream_id) - - def get_server_streams(self, client_id: str) -> list[SSEStream]: - """Get all server streams for a client.""" - server_streams = [] - for stream_id in self._client_streams.get(client_id, set()): - if stream_id.startswith(f"{client_id}:server:"): - stream = self._streams.get(stream_id) - if stream: - server_streams.append(stream) - return server_streams - - async def _cleanup_stream(self, stream_id: str) -> None: - """Clean up a single stream.""" - if stream_id not in self._streams: - return - - stream = self._streams[stream_id] + for streams in self._client_streams.values(): + for stream in streams: + if stream.stream_id == stream_id: + return stream + return None + async def _cleanup_stream(self, stream: SSEStream) -> None: + """Clean up a single stream.""" # Close the stream (sends sentinel) await stream.close() - # Remove from tracking - del self._streams[stream_id] - # Remove from client tracking if stream.client_id in self._client_streams: - self._client_streams[stream.client_id].discard(stream_id) + self._client_streams[stream.client_id].discard(stream) if not self._client_streams[stream.client_id]: del self._client_streams[stream.client_id] - logger.debug(f"Cleaned up stream {stream_id}") - - def get_active_stream_count(self) -> int: - """Get number of active streams (for debugging/metrics).""" - return len(self._streams) - - def get_client_stream_count(self, client_id: str) -> int: - """Get number of active streams for a client.""" - return len(self._client_streams.get(client_id, set())) + logger.debug(f"Cleaned up stream {stream.stream_id}") async def close_all_streams(self) -> None: """Close all streams.""" - for stream_id in list(self._streams): - await self._cleanup_stream(stream_id) + for streams in list(self._client_streams.values()): + for stream in streams.copy(): + await self._cleanup_stream(stream) diff --git a/src/conduit/transport/streamable_http/server/transport.py b/src/conduit/transport/streamable_http/server/transport.py index 8daf2d0..1006f60 100644 --- a/src/conduit/transport/streamable_http/server/transport.py +++ b/src/conduit/transport/streamable_http/server/transport.py @@ -107,18 +107,6 @@ async def stop(self) -> None: self._server.should_exit = True await self._server.shutdown() - async def close(self) -> None: - """Close the transport and clean up all resources. - - For HTTP transport, this stops the HTTP server and cleans up all sessions. - Safe to call multiple times. - """ - await self.stop() - - # Clean up sessions and streams - self._session_manager.terminate_all_sessions() - await self._stream_manager.close_all_streams() - # ================================ # Server Transport Interface # ================================ @@ -129,7 +117,21 @@ async def send( message: dict[str, Any], transport_context: TransportContext | None = None, ) -> None: - """Send message to specific client.""" + """Send message to specific client. + + Args: + client_id: Target client connection ID + message: JSON-RPC message to send + transport_context: Context for the transport. For example, this helps route + messages along specific streams. + + Raises: + ValueError: If client_id is not connected + ConnectionError: If connection failed during send + """ + if not self._session_manager.get_session_id(client_id): + raise ValueError(f"Client {client_id} is not connected") + originating_request_id = ( transport_context.originating_request_id if transport_context else None ) @@ -139,9 +141,7 @@ async def send( ): return - logger.warning( - f"No server streams available for client {client_id}, dropping message." - ) + raise ConnectionError(f"No active streams available for client {client_id}") def client_messages(self) -> AsyncIterator[ClientMessage]: """Stream of messages from all clients.""" @@ -153,6 +153,18 @@ async def disconnect_client(self, client_id: str) -> None: if session_id: self._session_manager.terminate_session(session_id) + async def close(self) -> None: + """Close the transport and clean up all resources. + + For HTTP transport, this stops the HTTP server and cleans up all sessions. + Safe to call multiple times. + """ + await self.stop() + + # Clean up sessions and streams + self._session_manager.terminate_all_sessions() + await self._stream_manager.close_all_streams() + # ================================ # HTTP Request Handlers # ================================ @@ -182,9 +194,14 @@ async def _handle_post_request(self, request: Request) -> Response: try: message_data = await request.json() if not isinstance(message_data, dict): - return Response("Invalid JSON", status_code=400) - except json.JSONDecodeError: - return Response("Invalid JSON", status_code=400) + return Response( + f"Invalid JSON: expected object, got {type(message_data).__name__}", + status_code=400, + ) + except json.JSONDecodeError as e: + return Response( + f"Invalid JSON: {e.msg} at position {e.pos}", status_code=400 + ) jsonrpc_error = self._validate_jsonrpc_message(message_data) if jsonrpc_error: @@ -225,19 +242,19 @@ async def _handle_get_request(self, request: Request) -> Response: if headers_error: return headers_error - # Get client from existing session (GET doesn't create sessions) + # First check: Is session ID present? session_id = request.headers.get("Mcp-Session-Id") - if not session_id or not self._session_manager.session_exists(session_id): - return Response("Missing or invalid session", status_code=400) + if not session_id: + return Response("Missing session ID", status_code=400) client_id = self._session_manager.get_client_id(session_id) if not client_id: - return Response("Client not found", status_code=404) + return Response("Invalid or expired session", status_code=404) headers = self._build_server_stream_headers(request, session_id) try: - stream = await self._stream_manager.create_server_stream(client_id) + stream = await self._stream_manager.create_stream(client_id) logger.debug( f"Created server stream {stream.stream_id} for client {client_id}" ) @@ -257,8 +274,7 @@ async def _handle_delete_request(self, request: Request) -> Response: if not session_id: return Response("Missing session ID", status_code=400) - client_id = self._session_manager.terminate_session(session_id) - if client_id is None: + if not self._session_manager.terminate_session(session_id): return Response("Session not found", status_code=404) return Response(status_code=200) @@ -270,14 +286,30 @@ async def _handle_delete_request(self, request: Request) -> Response: def _validate_protocol_headers(self, request: Request) -> Response | None: """Validate required MCP protocol headers (Protocol-Version, Accept, Origin).""" protocol_version = request.headers.get("MCP-Protocol-Version") - if protocol_version != PROTOCOL_VERSION: - logger.warning("Invalid MCP-Protocol-Version header") - return Response("Invalid MCP-Protocol-Version header", status_code=400) + if not protocol_version: + return Response( + f"Missing MCP-Protocol-Version header, expected: {PROTOCOL_VERSION}", + status_code=400, + ) + elif protocol_version != PROTOCOL_VERSION: + return Response( + f"Invalid MCP-Protocol-Version: {protocol_version}, expected:" + f" {PROTOCOL_VERSION}", + status_code=400, + ) accept = request.headers.get("Accept") + if not accept: + return Response( + "Missing Accept header, expected: text/event-stream, application/json", + status_code=400, + ) if "text/event-stream" not in accept or "application/json" not in accept: - logger.warning("Invalid Accept header") - return Response("Invalid Accept header", status_code=400) + return Response( + "Invalid Accept header, expected: text/event-stream, application/json" + f" (got: {accept})", + status_code=400, + ) # TODO: Implement proper Origin validation to prevent DNS rebinding attacks # For now, we accept all origins (development mode) @@ -312,11 +344,9 @@ def _validate_session( else: # All other messages MUST have a valid session ID if not session_id: - return Response( - "Missing required Mcp-Session-Id header", status_code=400 - ) + return Response("Missing Mcp-Session-Id header", status_code=400) if not self._session_manager.session_exists(session_id): - return Response("Invalid or expired session", status_code=404) + return Response("Invalid or expired Mcp-Session-Id", status_code=404) return None def _validate_jsonrpc_message(self, message_data: dict) -> Response | None: @@ -329,7 +359,9 @@ def _validate_jsonrpc_message(self, message_data: dict) -> Response | None: is_notification = self._message_parser.is_valid_notification(message_data) if not (is_request or is_response or is_notification): - return Response("Invalid JSON-RPC message", status_code=400) + return Response( + f"Invalid JSON-RPC message: {message_data}", status_code=400 + ) return None def _is_valid_origin(self, origin: str | None) -> bool: @@ -358,7 +390,7 @@ async def _get_or_create_client( Tuple of (client_id, session_id) Raises: - ValueError: If client ID is not found for the session + ValueError: If missing session ID or client not found """ session_id = request.headers.get("Mcp-Session-Id") is_initialize = ( @@ -370,9 +402,12 @@ async def _get_or_create_client( client_id, session_id = self._session_manager.create_session() return client_id, session_id else: + if session_id is None: + raise ValueError("Session ID is required for non-initialize requests") + client_id = self._session_manager.get_client_id(session_id) if not client_id: - raise ValueError(f"Client ID not found for session {session_id}") + raise ValueError("Client not found for session") return client_id, session_id # ================================ @@ -381,7 +416,7 @@ async def _get_or_create_client( async def _create_request_stream( self, client_id: str, request_id: str | int, headers: dict[str, str] - ) -> StreamingResponse: + ) -> StreamingResponse | Response: """Create SSE stream for a request. The stream will: @@ -389,13 +424,17 @@ async def _create_request_stream( 2. Send the final response to the original request 3. Auto-close after sending the response """ - stream = await self._stream_manager.create_request_stream( - client_id, str(request_id) - ) + try: + stream = await self._stream_manager.create_stream(client_id, request_id) - return StreamingResponse( - stream.event_generator(), media_type="text/event-stream", headers=headers - ) + return StreamingResponse( + stream.event_generator(), + media_type="text/event-stream", + headers=headers, + ) + except Exception as e: + logger.error(f"Failed to create request stream for client {client_id}: {e}") + return Response("Internal server error", status_code=500) # ================================ # Response Builders diff --git a/tests/transport/streamable_http/server/test_client_messages.py b/tests/transport/streamable_http/server/test_client_messages.py new file mode 100644 index 0000000..dee4180 --- /dev/null +++ b/tests/transport/streamable_http/server/test_client_messages.py @@ -0,0 +1,58 @@ +"""Tests for HttpServerTransport.client_messages method.""" + +import asyncio + +import pytest + +from conduit.transport.server import ClientMessage +from conduit.transport.streamable_http.server.transport import HttpServerTransport + + +class TestHttpServerTransportClientMessages: + """Orchestration tests for HttpServerTransport.client_messages.""" + + @pytest.fixture + def transport(self): + """Create transport instance.""" + return HttpServerTransport() + + async def test_client_messages_yields_queued_messages(self, transport): + """Test that client_messages yields messages from the queue.""" + # Arrange + test_message = ClientMessage( + client_id="test-client", + payload={"jsonrpc": "2.0", "method": "test"}, + timestamp=123.456, + ) + + # Put message in queue + await transport._message_queue.put(test_message) + + # Act + iterator = transport.client_messages() + message = await iterator.__anext__() + + # Assert + assert message == test_message + assert message.client_id == "test-client" + assert message.payload == {"jsonrpc": "2.0", "method": "test"} + assert message.timestamp == 123.456 + + async def test_client_messages_handles_multiple_messages(self, transport): + """Test that client_messages handles multiple queued messages.""" + # Arrange + messages = [ + ClientMessage("client-1", {"method": "test1"}, 1.0), + ClientMessage("client-2", {"method": "test2"}, 2.0), + ClientMessage("client-3", {"method": "test3"}, 3.0), + ] + + # Queue all messages + for msg in messages: + await transport._message_queue.put(msg) + + # Act & Assert + iterator = transport.client_messages() + for expected_message in messages: + received_message = await asyncio.wait_for(iterator.__anext__(), timeout=1.0) + assert received_message == expected_message diff --git a/tests/transport/streamable_http/server/test_close.py b/tests/transport/streamable_http/server/test_close.py new file mode 100644 index 0000000..d337cc0 --- /dev/null +++ b/tests/transport/streamable_http/server/test_close.py @@ -0,0 +1,44 @@ +"""Tests for HttpServerTransport.close method.""" + +from unittest.mock import AsyncMock, Mock + +import pytest + +from conduit.transport.streamable_http.server.transport import HttpServerTransport + + +class TestHttpServerTransportClose: + """Orchestration tests for HttpServerTransport.close.""" + + @pytest.fixture + def transport(self): + """Create transport instance with mocked dependencies.""" + # Arrange + transport = HttpServerTransport() + transport._session_manager = Mock() + transport._stream_manager = AsyncMock() + # Mock the stop method to avoid actual server operations + transport.stop = AsyncMock() + return transport + + async def test_close_stops_server_and_cleans_up(self, transport): + """Test that close properly stops server and cleans up resources.""" + # Arrange & Act + await transport.close() + + # Assert + transport.stop.assert_called_once() + transport._session_manager.terminate_all_sessions.assert_called_once() + transport._stream_manager.close_all_streams.assert_called_once() + + async def test_close_is_idempotent(self, transport): + """Test that close can be called multiple times safely.""" + # Arrange & Act + await transport.close() + await transport.close() + await transport.close() + + # Assert + assert transport.stop.call_count == 3 + assert transport._session_manager.terminate_all_sessions.call_count == 3 + assert transport._stream_manager.close_all_streams.call_count == 3 diff --git a/tests/transport/streamable_http/server/test_delete_handler.py b/tests/transport/streamable_http/server/test_delete_handler.py new file mode 100644 index 0000000..1e19dab --- /dev/null +++ b/tests/transport/streamable_http/server/test_delete_handler.py @@ -0,0 +1,62 @@ +"""Tests for HttpServerTransport DELETE request handler.""" + +from unittest.mock import Mock + +import pytest +from starlette.requests import Request + +from conduit.transport.streamable_http.server.transport import HttpServerTransport + + +class TestDeleteRequestProcessing: + """Test DELETE request processing: headers, session validation.""" + + @pytest.fixture + def transport(self): + """Create transport instance.""" + # Arrange + return HttpServerTransport() + + async def test_valid_session_returns_200(self, transport): + # Arrange + client_id, session_id = transport._session_manager.create_session() + + # Create request with valid session ID + request = Mock(spec=Request) + request.headers = {"Mcp-Session-Id": session_id} + + # Act + response = await transport._handle_delete_request(request) + + # Assert + assert response.status_code == 200 + assert response.body == b"" # No body content + + # Verify session was actually terminated + assert not transport._session_manager.session_exists(session_id) + + async def test_missing_session_id_returns_400(self, transport): + # Arrange + # Create request without session ID + request = Mock(spec=Request) + request.headers = {} + + # Act + response = await transport._handle_delete_request(request) + + # Assert + assert response.status_code == 400 + assert response.body.decode() == "Missing session ID" + + async def test_invalid_session_id_returns_404(self, transport): + # Arrange + # Create request with non-existent session ID + request = Mock(spec=Request) + request.headers = {"Mcp-Session-Id": "non-existent-session-123"} + + # Act + response = await transport._handle_delete_request(request) + + # Assert + assert response.status_code == 404 + assert response.body.decode() == "Session not found" diff --git a/tests/transport/streamable_http/server/test_disconnect_client.py b/tests/transport/streamable_http/server/test_disconnect_client.py new file mode 100644 index 0000000..0c0ff09 --- /dev/null +++ b/tests/transport/streamable_http/server/test_disconnect_client.py @@ -0,0 +1,59 @@ +"""Tests for HttpServerTransport.disconnect_client method.""" + +from unittest.mock import Mock + +import pytest + +from conduit.transport.streamable_http.server.transport import HttpServerTransport + + +class TestHttpServerTransportDisconnectClient: + """Orchestration tests for HttpServerTransport.disconnect_client.""" + + @pytest.fixture + def transport(self): + """Create transport instance with mocked session manager.""" + # Arrange + transport = HttpServerTransport() + transport._session_manager = Mock() + return transport + + async def test_disconnect_client_terminates_existing_session(self, transport): + """Test disconnecting a client with an existing session.""" + # Arrange + client_id = "test-client-123" + session_id = "session-456" + transport._session_manager.get_session_id.return_value = session_id + + # Act + await transport.disconnect_client(client_id) + + # Assert + transport._session_manager.get_session_id.assert_called_once_with(client_id) + transport._session_manager.terminate_session.assert_called_once_with(session_id) + + async def test_disconnect_client_handles_nonexistent_client(self, transport): + """Test disconnecting a client that doesn't exist.""" + # Arrange + client_id = "nonexistent-client" + transport._session_manager.get_session_id.return_value = None + + # Act + await transport.disconnect_client(client_id) + + # Assert + transport._session_manager.get_session_id.assert_called_once_with(client_id) + transport._session_manager.terminate_session.assert_not_called() + + async def test_disconnect_client_handles_empty_session_id(self, transport): + """Test disconnecting when session ID is empty string.""" + # Arrange + client_id = "test-client" + transport._session_manager.get_session_id.return_value = "" + + # Act + await transport.disconnect_client(client_id) + + # Assert + transport._session_manager.get_session_id.assert_called_once_with(client_id) + transport._session_manager.terminate_session.assert_not_called() diff --git a/tests/transport/streamable_http/server/test_get_handler.py b/tests/transport/streamable_http/server/test_get_handler.py new file mode 100644 index 0000000..6d7a859 --- /dev/null +++ b/tests/transport/streamable_http/server/test_get_handler.py @@ -0,0 +1,160 @@ +"""Tests for HttpServerTransport GET request handler.""" + +from unittest.mock import AsyncMock, Mock + +import pytest +from starlette.requests import Request +from starlette.responses import StreamingResponse + +from conduit.protocol.base import PROTOCOL_VERSION +from conduit.transport.streamable_http.server.transport import HttpServerTransport + + +class TestGetHandler: + @pytest.fixture + def transport(self): + # Arrange + transport = HttpServerTransport() + # Keep real session manager, just mock stream manager + transport._stream_manager = AsyncMock() + return transport + + @pytest.fixture + def mock_stream(self): + """Create a mock stream object.""" + stream = Mock() + stream.stream_id = "test-stream-123" + stream.event_generator = Mock(return_value=iter(["data: test\n\n"])) + return stream + + async def test_get_request_success_creates_server_stream( + self, transport, mock_stream + ): + # Arrange + client_id, session_id = transport._session_manager.create_session() + + # Mock successful stream creation + transport._stream_manager.create_stream.return_value = mock_stream + + # Create request with proper headers + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + "Mcp-Session-Id": session_id, + } + request = Mock(spec=Request) + request.headers = headers + + # Act + response = await transport._handle_get_request(request) + + # Assert + # Verify response type and media type + assert isinstance(response, StreamingResponse) + assert response.media_type == "text/event-stream" + + # Verify key headers + assert response.headers["Content-Type"] == "text/event-stream" + assert response.headers["Mcp-Session-Id"] == session_id + assert response.headers["MCP-Protocol-Version"] == PROTOCOL_VERSION + + # Verify stream creation + transport._stream_manager.create_stream.assert_awaited_once_with(client_id) + + async def test_get_request_invalid_accept_header_returns_400(self, transport): + # Arrange + client_id, session_id = transport._session_manager.create_session() + + # Create request with invalid Accept header (missing text/event-stream) + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "application/json", # Missing text/event-stream + "Origin": "https://example.com", + "Mcp-Session-Id": session_id, + } + request = Mock(spec=Request) + request.headers = headers + + # Act + response = await transport._handle_get_request(request) + + # Assert + assert response.status_code == 400 + assert "Invalid Accept header" in response.body.decode() + + # Verify stream creation was never attempted + transport._stream_manager.create_stream.assert_not_awaited() + + async def test_get_request_missing_session_id_returns_400(self, transport): + # Arrange + # Create request with valid headers but no session ID + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + # Note: No Mcp-Session-Id header + } + request = Mock(spec=Request) + request.headers = headers + + # Act + response = await transport._handle_get_request(request) + + # Assert + assert response.status_code == 400 + assert "Missing session ID" in response.body.decode() + + # Verify stream creation was never attempted + transport._stream_manager.create_stream.assert_not_awaited() + + async def test_get_request_invalid_session_id_returns_404(self, transport): + # Arrange + # Create request with valid headers but non-existent session ID + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + "Mcp-Session-Id": "non-existent-session-123", + } + request = Mock(spec=Request) + request.headers = headers + + # Act + response = await transport._handle_get_request(request) + + # Assert + assert response.status_code == 404 + assert "Invalid or expired session" in response.body.decode() + + # Verify stream creation was never attempted + transport._stream_manager.create_stream.assert_not_awaited() + + async def test_get_request_stream_creation_failure_returns_500(self, transport): + # Arrange + client_id, session_id = transport._session_manager.create_session() + + # Mock stream creation to raise an exception + transport._stream_manager.create_stream.side_effect = Exception( + "Stream creation failed" + ) + + # Create request with valid headers and session + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + "Mcp-Session-Id": session_id, + } + request = Mock(spec=Request) + request.headers = headers + + # Act + response = await transport._handle_get_request(request) + + # Assert + assert response.status_code == 500 + assert "Internal server error" in response.body.decode() + + # Verify stream creation was attempted + transport._stream_manager.create_stream.assert_awaited_once_with(client_id) diff --git a/tests/transport/streamable_http/server/test_post_handler.py b/tests/transport/streamable_http/server/test_post_handler.py new file mode 100644 index 0000000..0aae123 --- /dev/null +++ b/tests/transport/streamable_http/server/test_post_handler.py @@ -0,0 +1,498 @@ +"""Tests for HttpServerTransport POST request handler.""" + +import json +from unittest.mock import AsyncMock, Mock + +import pytest +from starlette.requests import Request +from starlette.responses import Response, StreamingResponse + +from conduit.protocol.base import PROTOCOL_VERSION +from conduit.transport.streamable_http.server.transport import HttpServerTransport + + +class TestPostRequestProcessing: + """Test POST request processing: headers, JSON parsing, routing.""" + + @pytest.fixture + def transport(self): + """Create transport instance with mocked stream manager.""" + # Arrange + transport = HttpServerTransport() + transport._stream_manager = AsyncMock() + return transport + + @pytest.fixture + def mock_stream(self): + """Create a mock stream object.""" + stream = Mock() + stream.stream_id = "test-stream-123" + + async def mock_event_generator(): + yield "data: test\n\n" + + stream.event_generator = Mock(return_value=mock_event_generator()) + return stream + + async def test_mcp_request_creates_stream(self, transport, mock_stream): + """Test POST with MCP request creates request stream.""" + # Arrange + client_id, session_id = transport._session_manager.create_session() + + # Mock successful stream creation + transport._stream_manager.create_stream = AsyncMock(return_value=mock_stream) + + # Create MCP request message (has method + id) + message_data = {"jsonrpc": "2.0", "method": "tools/list", "id": "req-123"} + + # Create request with proper headers and body + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + "Mcp-Session-Id": session_id, + } + request = Mock(spec=Request) + request.headers = headers + request.json = AsyncMock(return_value=message_data) + + # Act + response = await transport._handle_post_request(request) + + # Assert + # Verify it's a streaming response + assert isinstance(response, StreamingResponse) + assert response.media_type == "text/event-stream" + + # Verify key headers + assert response.headers["Mcp-Session-Id"] == session_id + assert response.headers["MCP-Protocol-Version"] == PROTOCOL_VERSION + + # Verify request stream creation + transport._stream_manager.create_stream.assert_awaited_once_with( + client_id, "req-123" + ) + + # Verify message was queued + assert not transport._message_queue.empty() + queued_message = await transport._message_queue.get() + assert queued_message.client_id == client_id + assert queued_message.payload == message_data + + async def test_mcp_notification_returns_202(self, transport): + """Test POST with notification returns 202 response.""" + # Arrange + client_id, session_id = transport._session_manager.create_session() + + # Create notification message (has method but no id) + message_data = { + "jsonrpc": "2.0", + "method": "notifications/progress", + } + + # Create request with proper headers and body + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + "Mcp-Session-Id": session_id, + } + request = Mock(spec=Request) + request.headers = headers + request.json = AsyncMock(return_value=message_data) + + # Act + response = await transport._handle_post_request(request) + + # Assert + # Verify it's a 202 response (not 200, streaming) + assert isinstance(response, Response) + assert response.status_code == 202 + + # Verify key headers + assert response.headers["Mcp-Session-Id"] == session_id + assert response.headers["MCP-Protocol-Version"] == PROTOCOL_VERSION + + # Verify no stream creation attempted + transport._stream_manager.create_stream.assert_not_awaited() + + # Verify message was still queued + assert not transport._message_queue.empty() + queued_message = await transport._message_queue.get() + assert queued_message.client_id == client_id + assert queued_message.payload == message_data + + async def test_missing_protocol_version_returns_400(self, transport): + # Arrange + client_id, session_id = transport._session_manager.create_session() + + message_data = {"jsonrpc": "2.0", "method": "tools/list", "id": "req-123"} + + # Create request with missing MCP-Protocol-Version header + headers = { + # Missing: "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + "Mcp-Session-Id": session_id, + } + request = Mock(spec=Request) + request.headers = headers + request.json = AsyncMock(return_value=message_data) + + # Act + response = await transport._handle_post_request(request) + + # Assert + assert response.status_code == 400 + assert "Missing MCP-Protocol-Version header" in response.body.decode() + + # Verify early exit - no JSON parsing or stream creation + request.json.assert_not_awaited() + transport._stream_manager.create_stream.assert_not_awaited() + + # Verify message queue is empty (never got that far) + assert transport._message_queue.empty() + + async def test_malformed_json_returns_400(self, transport): + """Test POST with malformed JSON returns 400.""" + # Arrange + client_id, session_id = transport._session_manager.create_session() + + # Create request with valid headers but malformed JSON body + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + "Mcp-Session-Id": session_id, + } + request = Mock(spec=Request) + request.headers = headers + # Mock request.json() to raise JSONDecodeError + request.json = AsyncMock( + side_effect=json.JSONDecodeError("Expecting value", "doc", 0) + ) + + # Act + response = await transport._handle_post_request(request) + + # Assert + assert response.status_code == 400 + assert "Invalid JSON" in response.body.decode() + + # Verify no stream creation attempted + transport._stream_manager.create_stream.assert_not_awaited() + + # Verify message queue is empty (never got that far) + assert transport._message_queue.empty() + + async def test_non_dict_json_returns_400(self, transport): + """Test POST with non-dict JSON returns 400.""" + # Arrange + client_id, session_id = transport._session_manager.create_session() + + # Create request with valid headers but non-dict JSON body + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + "Mcp-Session-Id": session_id, + } + request = Mock(spec=Request) + request.headers = headers + # Valid JSON but not a dict + request.json = AsyncMock(return_value=["valid", "json", "array"]) + + # Act + response = await transport._handle_post_request(request) + + # Assert + assert response.status_code == 400 + assert "Invalid JSON" in response.body.decode() + + # Verify no stream creation attempted + transport._stream_manager.create_stream.assert_not_awaited() + + # Verify message queue is empty (never got that far) + assert transport._message_queue.empty() + + async def test_invalid_jsonrpc_message_returns_400(self, transport): + """Test POST with invalid JSON-RPC message returns 400.""" + # Arrange + client_id, session_id = transport._session_manager.create_session() + + # Create message that's valid JSON dict but invalid JSON-RPC + message_data = { + "not_jsonrpc": "2.0", + "invalid": "structure", + # Missing required jsonrpc field, method/result, etc. + } + + # Create request with valid headers and invalid JSON-RPC body + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + "Mcp-Session-Id": session_id, + } + request = Mock(spec=Request) + request.headers = headers + request.json = AsyncMock(return_value=message_data) + + # Act + response = await transport._handle_post_request(request) + + # Assert + assert response.status_code == 400 + assert "Invalid JSON-RPC message" in response.body.decode() + + # Verify no stream creation attempted + transport._stream_manager.create_stream.assert_not_awaited() + + # Verify message queue is empty (never got that far) + assert transport._message_queue.empty() + + +class TestSessionValidation: + """Test session validation logic in POST handler.""" + + @pytest.fixture + def transport(self): + """Create transport instance with mocked stream manager.""" + # Arrange + transport = HttpServerTransport() + transport._stream_manager = AsyncMock() + return transport + + async def test_initialize_request_without_session_succeeds(self, transport): + """Test POST initialize request without session ID succeeds.""" + # Arrange + # Create dummy initialize request (no session ID should be provided) + message_data = { + "jsonrpc": "2.0", + "method": "initialize", + "id": "init-123", + "params": {"protocolVersion": PROTOCOL_VERSION}, + } + + # Create request with valid headers but NO session ID + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + # Note: No Mcp-Session-Id header for initialize + } + request = Mock(spec=Request) + request.headers = headers + request.json = AsyncMock(return_value=message_data) + + # Mock stream creation for the successful case + mock_stream = Mock() + + async def mock_event_generator(): + yield "data: test\n\n" + + mock_stream.event_generator = Mock(return_value=mock_event_generator()) + transport._stream_manager.create_stream = AsyncMock(return_value=mock_stream) + + # Act + response = await transport._handle_post_request(request) + + # Assert + # Should succeed and create a stream (initialize is an MCP request) + assert isinstance(response, StreamingResponse) + assert response.media_type == "text/event-stream" + + # Should have created a new session + assert "Mcp-Session-Id" in response.headers + new_session_id = response.headers["Mcp-Session-Id"] + + # Verify stream creation with new client + transport._stream_manager.create_stream.assert_awaited_once() + + # Verify message was queued with new client_id + assert not transport._message_queue.empty() + queued_message = await transport._message_queue.get() + assert queued_message.payload == message_data + + async def test_initialize_request_with_session_returns_400(self, transport): + """Test POST initialize request with session ID returns 400.""" + # Arrange + client_id, session_id = transport._session_manager.create_session() + + # Create dummy initialize request (should NOT have session ID) + message_data = { + "jsonrpc": "2.0", + "method": "initialize", + "id": "init-123", + "params": {"protocolVersion": PROTOCOL_VERSION}, + } + + # Create request with session ID (this is the error) + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + "Mcp-Session-Id": session_id, # This should NOT be here for initialize + } + request = Mock(spec=Request) + request.headers = headers + request.json = AsyncMock(return_value=message_data) + + # Act + response = await transport._handle_post_request(request) + + # Assert + assert response.status_code == 400 + assert ( + "Initialize request must not include session ID" in response.body.decode() + ) + + # Verify no stream creation attempted + transport._stream_manager.create_stream.assert_not_awaited() + + # Verify message queue is empty (never got that far) + assert transport._message_queue.empty() + + async def test_non_initialize_request_missing_session_returns_400(self, transport): + # Arrange + # Create non-initialize request (requires session ID) + message_data = { + "jsonrpc": "2.0", + "method": "tools/list", + "id": "req-123", + } + + # Create request without session ID (this is the error) + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + # Note: Missing Mcp-Session-Id header + } + request = Mock(spec=Request) + request.headers = headers + request.json = AsyncMock(return_value=message_data) + + # Act + response = await transport._handle_post_request(request) + + # Assert + assert response.status_code == 400 + assert "Missing Mcp-Session-Id header" in response.body.decode() + + # Verify no stream creation attempted + transport._stream_manager.create_stream.assert_not_awaited() + + # Verify message queue is empty (never got that far) + assert transport._message_queue.empty() + + async def test_non_initialize_request_invalid_session_returns_404(self, transport): + # Arrange + # Create non-initialize request + message_data = { + "jsonrpc": "2.0", + "method": "tools/list", + "id": "req-123", + } + + # Create request with non-existent session ID + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + "Mcp-Session-Id": "non-existent-session-456", + } + request = Mock(spec=Request) + request.headers = headers + request.json = AsyncMock(return_value=message_data) + + # Act + response = await transport._handle_post_request(request) + + # Assert + assert response.status_code == 404 + assert "Invalid or expired Mcp-Session-Id" in response.body.decode() + + # Verify no stream creation attempted + transport._stream_manager.create_stream.assert_not_awaited() + + # Verify message queue is empty (never got that far) + assert transport._message_queue.empty() + + async def test_initialize_creates_new_client_and_session(self, transport): + """Test POST initialize request creates new client and session.""" + # Arrange + message_data = { + "jsonrpc": "2.0", + "method": "initialize", + "id": "init-123", + "params": {"protocolVersion": PROTOCOL_VERSION}, + } + + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + # No session ID for initialize + } + request = Mock(spec=Request) + request.headers = headers + request.json = AsyncMock(return_value=message_data) + + # Mock stream creation + mock_stream = Mock() + + async def mock_event_generator(): + yield "data: test\n\n" + + mock_stream.event_generator = Mock(return_value=mock_event_generator()) + transport._stream_manager.create_stream = AsyncMock(return_value=mock_stream) + + # Act + response = await transport._handle_post_request(request) + + # Assert + # Verify new session was created and returned + assert "Mcp-Session-Id" in response.headers + new_session_id = response.headers["Mcp-Session-Id"] + + # Verify the session actually exists in the session manager + assert transport._session_manager.session_exists(new_session_id) + + # Verify we can get the client ID for this session + client_id = transport._session_manager.get_client_id(new_session_id) + assert client_id is not None + + async def test_post_request_stream_creation_failure_returns_500(self, transport): + """Test POST MCP request returns 500 when stream creation fails.""" + # Arrange + client_id, session_id = transport._session_manager.create_session() + + # Mock stream creation to raise an exception + transport._stream_manager.create_stream.side_effect = Exception( + "Stream creation failed" + ) + + message_data = {"jsonrpc": "2.0", "method": "tools/list", "id": "req-123"} + + headers = { + "MCP-Protocol-Version": PROTOCOL_VERSION, + "Accept": "text/event-stream, application/json", + "Origin": "https://example.com", + "Mcp-Session-Id": session_id, + } + request = Mock(spec=Request) + request.headers = headers + request.json = AsyncMock(return_value=message_data) + + # Act + response = await transport._handle_post_request(request) + + # Assert + assert response.status_code == 500 + assert "Internal server error" in response.body.decode() + + # Verify stream creation was attempted + transport._stream_manager.create_stream.assert_awaited_once_with( + client_id, "req-123" + ) diff --git a/tests/transport/streamable_http/server/test_send.py b/tests/transport/streamable_http/server/test_send.py new file mode 100644 index 0000000..b4df3e4 --- /dev/null +++ b/tests/transport/streamable_http/server/test_send.py @@ -0,0 +1,118 @@ +"""Tests for HttpServerTransport.send method.""" + +from unittest.mock import AsyncMock, Mock + +import pytest + +from conduit.transport.server import TransportContext +from conduit.transport.streamable_http.server.transport import HttpServerTransport + + +class TestHttpServerTransportSend: + """Orchestration tests for HttpServerTransport.send.""" + + @pytest.fixture + def transport(self): + """Create transport instance with mocked dependencies.""" + # Arrange + transport = HttpServerTransport() + transport._session_manager = Mock() + transport._stream_manager = AsyncMock() + return transport + + async def test_send_successful_with_originating_request_id(self, transport): + """Test successful message send with originating request ID.""" + # Arrange + client_id = "test-client-123" + message = {"jsonrpc": "2.0", "method": "test", "id": 1} + transport_context = TransportContext(originating_request_id="req-456") + + transport._session_manager.get_session_id.return_value = "session-789" + transport._stream_manager.send_to_existing_stream.return_value = True + + # Act + await transport.send(client_id, message, transport_context) + + # Assert + transport._session_manager.get_session_id.assert_called_once_with(client_id) + transport._stream_manager.send_to_existing_stream.assert_awaited_once_with( + client_id, message, "req-456" + ) + + async def test_send_successful_without_transport_context(self, transport): + """Test successful message send without transport context.""" + # Arrange + client_id = "test-client-123" + message = {"jsonrpc": "2.0", "method": "test", "id": 1} + + transport._session_manager.get_session_id.return_value = "session-789" + transport._stream_manager.send_to_existing_stream.return_value = True + + # Act + await transport.send(client_id, message) + + # Assert + transport._session_manager.get_session_id.assert_called_once_with(client_id) + transport._stream_manager.send_to_existing_stream.assert_awaited_once_with( + client_id, message, None + ) + + async def test_send_raises_value_error_for_nonexistent_client(self, transport): + """Test that send raises ValueError when client doesn't exist.""" + # Arrange + client_id = "nonexistent-client" + message = {"jsonrpc": "2.0", "method": "test", "id": 1} + + transport._session_manager.get_session_id.return_value = None + + # Act & Assert + with pytest.raises(ValueError): + await transport.send(client_id, message) + + transport._session_manager.get_session_id.assert_called_once_with(client_id) + transport._stream_manager.send_to_existing_stream.assert_not_awaited() + + async def test_send_raises_connection_error_when_no_streams_available( + self, transport + ): + # Arrange + client_id = "test-client-123" + message = {"jsonrpc": "2.0", "method": "test", "id": 1} + + transport._session_manager.get_session_id.return_value = "session-789" + transport._stream_manager.send_to_existing_stream.return_value = False + + # Act & Assert + with pytest.raises(ConnectionError): + await transport.send(client_id, message) + + transport._session_manager.get_session_id.assert_called_once_with(client_id) + transport._stream_manager.send_to_existing_stream.assert_awaited_once_with( + client_id, message, None + ) + + async def test_send_with_different_message_types(self, transport): + # Arrange + client_id = "test-client-123" + transport._session_manager.get_session_id.return_value = "session-789" + transport._stream_manager.send_to_existing_stream.return_value = True + + test_messages = [ + # Request + {"jsonrpc": "2.0", "method": "test_request", "id": 1}, + # Response + {"jsonrpc": "2.0", "result": {"data": "test"}, "id": 1}, + # Notification + {"jsonrpc": "2.0", "method": "test_notification"}, + # Error response + {"jsonrpc": "2.0", "error": {"code": -1, "message": "Test error"}, "id": 1}, + ] + + # Act & Assert + for message in test_messages: + await transport.send(client_id, message) + + # Verify all messages were sent + assert transport._stream_manager.send_to_existing_stream.await_count == len( + test_messages + ) diff --git a/tests/transport/streamable_http/server/test_session_manager.py b/tests/transport/streamable_http/server/test_session_manager.py new file mode 100644 index 0000000..1927655 --- /dev/null +++ b/tests/transport/streamable_http/server/test_session_manager.py @@ -0,0 +1,176 @@ +import uuid + +from conduit.transport.streamable_http.server.session_manager import SessionManager + + +class TestSessionManager: + def test_empty_manager_state(self): + # Arrange + manager = SessionManager() + fake_id = str(uuid.uuid4()) + + # Act & Assert + assert manager.get_client_id(fake_id) is None + assert manager.get_session_id(fake_id) is None + assert manager.session_exists(fake_id) is False + assert manager.terminate_session(fake_id) is False + + # Should not raise errors + manager.terminate_all_sessions() + + def test_create_session_happy_path(self): + # Arrange + manager = SessionManager() + + # Act + client_id, session_id = manager.create_session() + + # Assert + assert client_id is not None + assert session_id is not None + assert isinstance(client_id, str) + assert isinstance(session_id, str) + + # Verify both IDs are valid UUIDs + uuid.UUID(client_id) # Will raise ValueError if invalid + uuid.UUID(session_id) # Will raise ValueError if invalid + + # Verify bidirectional lookup works + assert manager.get_client_id(session_id) == client_id + assert manager.get_session_id(client_id) == session_id + assert manager.session_exists(session_id) is True + + def test_create_multiple_sessions(self): + # Arrange + manager = SessionManager() + + # Act + client_id1, session_id1 = manager.create_session() + client_id2, session_id2 = manager.create_session() + client_id3, session_id3 = manager.create_session() + + # Assert + # All IDs should be unique + assert client_id1 != client_id2 != client_id3 + assert session_id1 != session_id2 != session_id3 + + # All sessions should exist and map correctly + assert manager.get_client_id(session_id1) == client_id1 + assert manager.get_client_id(session_id2) == client_id2 + assert manager.get_client_id(session_id3) == client_id3 + + assert manager.get_session_id(client_id1) == session_id1 + assert manager.get_session_id(client_id2) == session_id2 + assert manager.get_session_id(client_id3) == session_id3 + + def test_get_session_id(self): + # Arrange + manager = SessionManager() + client_id, session_id = manager.create_session() + fake_client_id = str(uuid.uuid4()) + + # Act + real_session = manager.get_session_id(client_id) + fake_session = manager.get_session_id(fake_client_id) + + # Assert + assert real_session == session_id + assert fake_session is None + + def test_session_exists(self): + # Arrange + manager = SessionManager() + _, session_id = manager.create_session() + + # Act + real_session = manager.session_exists(session_id) + fake_session = manager.session_exists(str(uuid.uuid4())) + + # Assert + assert real_session is True + assert fake_session is False + + def test_terminate_session_existing(self): + # Arrange + manager = SessionManager() + client_id, session_id = manager.create_session() + + # Act + result = manager.terminate_session(session_id) + + # Assert + assert result is True + assert manager.session_exists(session_id) is False + assert manager.get_client_id(session_id) is None + assert manager.get_session_id(client_id) is None + + def test_terminate_session_nonexistent(self): + # Arrange + manager = SessionManager() + fake_session_id = str(uuid.uuid4()) + + # Act + result = manager.terminate_session(fake_session_id) + + # Assert: should return False + assert result is False + + def test_terminate_session_leaves_others_intact(self): + # Arrange + manager = SessionManager() + client_id1, session_id1 = manager.create_session() + client_id2, session_id2 = manager.create_session() + client_id3, session_id3 = manager.create_session() + + # Act + result = manager.terminate_session(session_id2) + + # Assert + assert result is True + + # Session 2 should be gone + assert manager.session_exists(session_id2) is False + assert manager.get_client_id(session_id2) is None + assert manager.get_session_id(client_id2) is None + + # Sessions 1 and 3 should remain + assert manager.session_exists(session_id1) is True + assert manager.session_exists(session_id3) is True + assert manager.get_client_id(session_id1) == client_id1 + assert manager.get_client_id(session_id3) == client_id3 + assert manager.get_session_id(client_id1) == session_id1 + assert manager.get_session_id(client_id3) == session_id3 + + def test_terminate_all_sessions_empty_manager(self): + # Arrange + manager = SessionManager() + + # Act + manager.terminate_all_sessions() + + # Assert + # Should not raise any errors + assert True + + def test_terminate_all_sessions_with_sessions(self): + # Arrange + manager = SessionManager() + client_id1, session_id1 = manager.create_session() + client_id2, session_id2 = manager.create_session() + client_id3, session_id3 = manager.create_session() + + # Act + manager.terminate_all_sessions() + + # Assert + assert manager.session_exists(session_id1) is False + assert manager.session_exists(session_id2) is False + assert manager.session_exists(session_id3) is False + + assert manager.get_client_id(session_id1) is None + assert manager.get_client_id(session_id2) is None + assert manager.get_client_id(session_id3) is None + + assert manager.get_session_id(client_id1) is None + assert manager.get_session_id(client_id2) is None + assert manager.get_session_id(client_id3) is None diff --git a/tests/transport/streamable_http/server/test_sse_stream.py b/tests/transport/streamable_http/server/test_sse_stream.py new file mode 100644 index 0000000..1d4d574 --- /dev/null +++ b/tests/transport/streamable_http/server/test_sse_stream.py @@ -0,0 +1,105 @@ +import json + +import pytest + +from conduit.transport.streamable_http.server.streams import SSEStream + + +class TestSSEStream: + async def test_send_message_yields_event(self): + # Arrange + stream = SSEStream("test-stream", "client-123", "request-456") + test_message = {"id": "req-1", "method": "test", "params": {}} + + # Act + await stream.send_message(test_message) + event_gen = stream.event_generator() + + # Get the first event + event = await event_gen.__anext__() + + # Assert + assert event is not None + assert isinstance(event, str) + assert event.startswith("data: ") + assert event.endswith("\n\n") + + # Verify the JSON content + event_data = event[6:-2] # Strip "data: " and "\n\n" + parsed_message = json.loads(event_data) + assert parsed_message == test_message + + # Verify stream properties + assert stream.stream_id == "test-stream" + assert stream.client_id == "client-123" + assert stream.request_id == "request-456" + + async def test_explicit_close_stops_event_generator(self): + """Test that explicitly closing the stream stops the event generator.""" + # Arrange + stream = SSEStream("test-stream", "client-123", "request-456") + test_message = {"id": "req-1", "method": "test", "params": {}} + + # Act + await stream.send_message(test_message) + await stream.close() # Explicit close + + event_gen = stream.event_generator() + + # Get the first event (our test message) + event = await event_gen.__anext__() + + # Assert first event is our message + event_data = event[6:-2] # Strip "data: " and "\n\n" + parsed_message = json.loads(event_data) + assert parsed_message == test_message + + # Generator should stop after close sentinel + with pytest.raises(StopAsyncIteration): + await event_gen.__anext__() + + async def test_response_message_auto_closes_event_generator(self): + """Test that sending a response message auto-closes the generator.""" + # Arrange + stream = SSEStream("test-stream", "client-123", "request-456") + response_message = {"id": "req-1", "result": {"success": True}} + + # Act + await stream.send_message(response_message) + event_gen = stream.event_generator() + + # Get the response event + event = await event_gen.__anext__() + + # Assert response was sent + event_data = event[6:-2] + parsed_message = json.loads(event_data) + assert parsed_message == response_message + + # Generator should auto-close after response + with pytest.raises(StopAsyncIteration): + await event_gen.__anext__() + + async def test_is_response_detection(self): + """Test response detection logic.""" + # Arrange + stream = SSEStream("test-stream", "client-123", "request-456") + + # Act & Assert + # Valid responses + assert stream.is_response({"id": "req-1", "result": {"data": "test"}}) is True + assert ( + stream.is_response( + {"id": "req-1", "error": {"code": -1, "message": "fail"}} + ) + is True + ) + + # Invalid responses (not responses) + assert stream.is_response({"method": "test", "params": {}}) is False + assert stream.is_response({"id": "req-1"}) is False # Missing result/error + assert ( + stream.is_response({"id": "req-1", "result": {}, "error": {}}) is False + ) # Both result and error + assert stream.is_response({"result": {"data": "test"}}) is False # Missing id + assert stream.is_response({"id": True, "result": {}}) is False # Boolean id diff --git a/tests/transport/streamable_http/server/test_stream_manager.py b/tests/transport/streamable_http/server/test_stream_manager.py new file mode 100644 index 0000000..51dec00 --- /dev/null +++ b/tests/transport/streamable_http/server/test_stream_manager.py @@ -0,0 +1,128 @@ +import json + +import pytest + +from conduit.transport.streamable_http.server.streams import StreamManager + + +class TestStreamManager: + async def test_create_stream_and_send_message_happy_path(self): + """Test creating a stream and sending a message successfully.""" + # Arrange + manager = StreamManager() + client_id = "client-123" + request_id = "req-456" + test_message = {"id": "msg-1", "method": "test", "params": {"data": "hello"}} + + # Act + stream = await manager.create_stream(client_id, request_id) + success = await manager.send_to_existing_stream( + client_id, test_message, originating_request_id=request_id + ) + + # Assert + assert success is True + assert stream is not None + assert stream.client_id == client_id + assert stream.request_id == request_id + assert isinstance(stream.stream_id, str) + + # Verify the message was queued by getting it from the event generator + event_gen = stream.event_generator() + event = await event_gen.__anext__() + + assert event.startswith("data: ") + assert event.endswith("\n\n") + + # Verify the JSON content + event_data = event[6:-2] # Strip "data: " and "\n\n" + parsed_message = json.loads(event_data) + assert parsed_message == test_message + + async def test_cleanup_client_streams(self): + """Test cleaning up all streams for a client.""" + # Arrange + manager = StreamManager() + client_id = "client-123" + other_client_id = "client-456" + + # Create multiple streams for the target client + stream1 = await manager.create_stream(client_id, "req-1") + stream2 = await manager.create_stream(client_id, "req-2") + stream3 = await manager.create_stream(client_id) # No request_id + + # Create a stream for another client (should not be affected) + other_stream = await manager.create_stream(other_client_id, "req-3") + + # Act + await manager.cleanup_client_streams(client_id) + + # Assert + # Target client streams should be cleaned up + assert manager.get_stream_by_id(stream1.stream_id) is None + assert manager.get_stream_by_id(stream2.stream_id) is None + assert manager.get_stream_by_id(stream3.stream_id) is None + + # Other client stream should remain + assert manager.get_stream_by_id(other_stream.stream_id) is not None + + # Verify streams are actually closed by trying to get events + # (should get StopAsyncIteration immediately due to close sentinel) + with pytest.raises(StopAsyncIteration): + event_gen = stream1.event_generator() + await event_gen.__anext__() + + async def test_cleanup_nonexistent_client(self): + """Test cleaning up streams for a client that doesn't exist.""" + # Arrange + manager = StreamManager() + fake_client_id = "nonexistent-client" + + # Act & Assert - should not raise any errors + await manager.cleanup_client_streams(fake_client_id) + + async def test_close_all_streams(self): + """Test closing all streams in the manager.""" + # Arrange + manager = StreamManager() + + # Create streams for multiple clients + stream1 = await manager.create_stream("client-1", "req-1") + stream2 = await manager.create_stream("client-2", "req-2") + stream3 = await manager.create_stream("client-1") # Another for client-1 + + # Act + await manager.close_all_streams() + + # Assert + # All streams should be cleaned up + assert manager.get_stream_by_id(stream1.stream_id) is None + assert manager.get_stream_by_id(stream2.stream_id) is None + assert manager.get_stream_by_id(stream3.stream_id) is None + + # All streams should be closed + for stream in [stream1, stream2, stream3]: + with pytest.raises(StopAsyncIteration): + event_gen = stream.event_generator() + await event_gen.__anext__() + + async def test_send_to_nonexistent_stream(self): + """Test sending to streams that don't exist.""" + # Arrange + manager = StreamManager() + client_id = "client-123" + message = {"id": "msg-1", "method": "test"} + + # Act & Assert + # No streams exist at all + result = await manager.send_to_existing_stream(client_id, message, "req-1") + assert result is False + + # Stream exists but wrong request_id + await manager.create_stream(client_id, "req-2") + result = await manager.send_to_existing_stream(client_id, message, "req-1") + assert result is False + + # Client has no streams for fallback case + result = await manager.send_to_existing_stream("nonexistent-client", message) + assert result is False