diff --git a/src/conduit/client/session.py b/src/conduit/client/session.py index 8eeb490..d1752d2 100644 --- a/src/conduit/client/session.py +++ b/src/conduit/client/session.py @@ -172,6 +172,8 @@ async def disconnect_all_servers(self) -> None: for server_id in self.server_manager.get_server_ids(): await self.disconnect_server(server_id) + await self.transport.close() + # ================================ # Initialization # ================================ diff --git a/src/conduit/server/session.py b/src/conduit/server/session.py index 360b2a2..6535837 100644 --- a/src/conduit/server/session.py +++ b/src/conduit/server/session.py @@ -172,6 +172,8 @@ async def disconnect_all_clients(self) -> None: for client_id in self.client_manager.get_client_ids(): await self.disconnect_client(client_id) + await self.transport.close() + # ================================ # Initialization # ================================ diff --git a/src/conduit/transport/client.py b/src/conduit/transport/client.py index 3e3daba..16eab76 100644 --- a/src/conduit/transport/client.py +++ b/src/conduit/transport/client.py @@ -68,3 +68,12 @@ async def disconnect_server(self, server_id: str) -> None: server_id: Server connection ID to disconnect """ ... + + @abstractmethod + async def close(self) -> None: + """Close the transport and clean up all resources. + + Should be called when the transport is no longer needed. + Safe to call multiple times - subsequent calls are no-ops. + """ + ... diff --git a/src/conduit/transport/server.py b/src/conduit/transport/server.py index 2a50374..f2b163c 100644 --- a/src/conduit/transport/server.py +++ b/src/conduit/transport/server.py @@ -66,3 +66,12 @@ async def disconnect_client(self, client_id: str) -> None: client_id: Client connection ID to disconnect """ ... + + @abstractmethod + async def close(self) -> None: + """Close the transport and clean up all resources. + + Should be called when the transport is no longer needed. + Safe to call multiple times - subsequent calls are no-ops. + """ + ... diff --git a/src/conduit/transport/stdio/client.py b/src/conduit/transport/stdio/client.py index 974bc77..eb151b1 100644 --- a/src/conduit/transport/stdio/client.py +++ b/src/conduit/transport/stdio/client.py @@ -345,3 +345,8 @@ async def _shutdown_server_process( logger.error(f"Error during shutdown of server '{server_id}': {e}") finally: server_process.process = None + + async def close(self) -> None: + """Close the transport and clean up all resources.""" + for server_id in list(self._servers): + await self.disconnect_server(server_id) diff --git a/src/conduit/transport/stdio/server.py b/src/conduit/transport/stdio/server.py index 1439a2b..5e2646b 100644 --- a/src/conduit/transport/stdio/server.py +++ b/src/conduit/transport/stdio/server.py @@ -101,3 +101,12 @@ async def disconnect_client(self, client_id: str) -> None: pass sys.exit(0) + + async def close(self) -> None: + """Close the transport and clean up all resources.""" + try: + sys.stdout.close() + except Exception: + pass + + sys.exit(0) diff --git a/src/conduit/transport/streamable_http/client.py b/src/conduit/transport/streamable_http/client.py index da9e19f..670d755 100644 --- a/src/conduit/transport/streamable_http/client.py +++ b/src/conduit/transport/streamable_http/client.py @@ -8,6 +8,7 @@ from conduit.protocol.base import PROTOCOL_VERSION from conduit.transport.client import ClientTransport, ServerMessage +from conduit.transport.streamable_http.client_stream_manager import ClientStreamManager logger = logging.getLogger(__name__) @@ -15,8 +16,11 @@ class StreamableHttpClientTransport(ClientTransport): """HTTP client transport supporting multiple server connections. - Phase 1: Basic HTTP POST/JSON response with session management. - Future phases will add SSE streams, resumability, etc. + Implements the Streamable HTTP transport specification, supporting: + - HTTP POST for sending messages to servers + - SSE streams for receiving messages from servers + - Session management with Mcp-Session-Id headers + - Multiple concurrent server connections """ def __init__(self) -> None: @@ -25,10 +29,15 @@ def __init__(self) -> None: self._sessions: dict[str, str] = {} # server_id -> session_id self._http_client = httpx.AsyncClient() self._message_queue: asyncio.Queue[ServerMessage] = asyncio.Queue() + # server_id -> stream manager + self._stream_managers: dict[str, ClientStreamManager] = {} async def add_server(self, server_id: str, connection_info: dict[str, Any]) -> None: """Register HTTP server endpoint. + Stores the server configuration without establishing a connection. + Connection will be established when first message is sent. + Args: server_id: Unique identifier for this server connection connection_info: HTTP connection details @@ -59,7 +68,8 @@ async def add_server(self, server_id: str, connection_info: dict[str, Any]) -> N async def send(self, server_id: str, message: dict[str, Any]) -> None: """Send message to server via HTTP POST. - Phase 1: Always expects immediate JSON response. + All JSON-RPC messages are sent as HTTP POST requests to the MCP endpoint. + Handles both immediate JSON responses and SSE streams according to the spec. Args: server_id: Target server connection ID @@ -74,18 +84,7 @@ async def send(self, server_id: str, message: dict[str, Any]) -> None: server_config = self._servers[server_id] endpoint = server_config["endpoint"] - - # Build headers - headers = { - "Content-Type": "application/json", - "Accept": "application/json, text/event-stream", - "MCP-Protocol-Version": PROTOCOL_VERSION, - **server_config["headers"], - } - - # Add session ID if we have one - if server_id in self._sessions: - headers["Mcp-Session-Id"] = self._sessions[server_id] + headers = self._build_headers(server_id, server_config) try: response = await self._http_client.post( @@ -95,13 +94,39 @@ async def send(self, server_id: str, message: dict[str, Any]) -> None: timeout=30.0, # TODO: Make configurable ) - # Handle session management + # Handle session management for initialize responses await self._handle_session_management(server_id, message, response) - # Handle response - if response.status_code == 200: + # Handle different response types based on content-type + await self._handle_response(server_id, response) + + except httpx.RequestError as e: + raise ConnectionError( + f"HTTP request failed for server '{server_id}': {e}" + ) from e + + async def _handle_response(self, server_id: str, response: httpx.Response) -> None: + """Handle the HTTP response based on content type and status. + + According to the spec: + - 200 with application/json: Single JSON response + - 200 with text/event-stream: SSE stream + - 202: Accepted (for notifications/responses) + - 404: Session expired (if we sent a session ID) + - Other errors: Raise ConnectionError + + Args: + server_id: Server ID to handle response for + response: The HTTP response from the server + + Raises: + ConnectionError: If the session expired or other HTTP errors + """ + if response.status_code == 200: + content_type = response.headers.get("content-type", "") + + if "application/json" in content_type: response_data = response.json() - # Put response in queue for server_messages() server_message = ServerMessage( server_id=server_id, payload=response_data, @@ -109,25 +134,62 @@ async def send(self, server_id: str, message: dict[str, Any]) -> None: ) await self._message_queue.put(server_message) - elif response.status_code == 404: - # Session expired - clear session ID + elif "text/event-stream" in content_type: + # Get or create stream manager for this server + if server_id not in self._stream_managers: + self._stream_managers[server_id] = ClientStreamManager( + server_id, self._http_client + ) + + stream_manager = self._stream_managers[server_id] + await stream_manager.create_response_stream( + response, self._message_queue + ) + + elif response.status_code == 202: + logger.debug(f"Server '{server_id}' accepted message (202)") + + elif response.status_code == 404: + # Check if we sent a session ID - if so, this means session expired + request_had_session = "Mcp-Session-Id" in response.request.headers + + if request_had_session: + # Session expired - clear session ID per spec if server_id in self._sessions: + expired_session = self._sessions[server_id] del self._sessions[server_id] - raise ConnectionError(f"Session expired for server '{server_id}'") - + logger.info( + f"Session '{expired_session}' expired for server '{server_id}'" + " - cleared session ID" + ) + + raise ConnectionError( + f"Session expired for server '{server_id}'. " + "Must re-initialize with a new InitializeRequest." + ) else: - response.raise_for_status() + # Regular 404 - not session related + raise ConnectionError( + f"Server '{server_id}' returned 404: {response.text or 'Not Found'}" + ) - except httpx.RequestError as e: - raise ConnectionError( - f"HTTP request failed for server '{server_id}': {e}" - ) from e + else: + # Other HTTP errors + response.raise_for_status() async def _handle_session_management( self, server_id: str, message: dict[str, Any], response: httpx.Response ) -> None: - """Handle session ID extraction from InitializeResult.""" - # Check if this was an InitializeRequest and response contains session ID + """Handle session ID extraction from InitializeResult responses. + + According to the spec, servers MAY include an Mcp-Session-Id header + in the response to an initialize request to establish a session. + + Args: + server_id: Server ID to handle session management for + message: The message that was sent to the server + response: The response from the server + """ if ( message.get("method") == "initialize" and response.status_code == 200 @@ -135,59 +197,244 @@ async def _handle_session_management( ): session_id = response.headers["Mcp-Session-Id"] self._sessions[server_id] = session_id - logger.debug(f"Stored session ID for server '{server_id}': {session_id}") - - def server_messages(self) -> AsyncIterator[ServerMessage]: - """Stream of messages from all servers. - - Phase 1: Simple queue-based approach. - """ - return self._message_queue_iterator() - - async def _message_queue_iterator(self) -> AsyncIterator[ServerMessage]: - """Async iterator that yields messages from the queue.""" - while True: - try: - message = await self._message_queue.get() - yield message - except Exception as e: - logger.error(f"Error reading from message queue: {e}") - break + logger.debug(f"Established session for server '{server_id}': {session_id}") async def disconnect_server(self, server_id: str) -> None: """Disconnect from specific server. + Attempts graceful session termination via DELETE request if we have a session, + cancels all active SSE streams, then cleans up all local state. + Safe to call multiple times. + Args: server_id: Server connection ID to disconnect """ if server_id not in self._servers: return + # Cancel all streams for this server + if server_id in self._stream_managers: + self._stream_managers[server_id].cancel_all_streams() + del self._stream_managers[server_id] + # If we have a session, try to terminate it gracefully if server_id in self._sessions: try: session_id = self._sessions[server_id] endpoint = self._servers[server_id]["endpoint"] - headers = {"Mcp-Session-Id": session_id} + + # Build headers for DELETE request + headers = { + "Mcp-Session-Id": session_id, + "MCP-Protocol-Version": PROTOCOL_VERSION, + } + # Add any custom headers from server config + headers.update(self._servers[server_id]["headers"]) # Attempt graceful session termination - await self._http_client.delete(endpoint, headers=headers, timeout=5.0) + response = await self._http_client.delete( + endpoint, headers=headers, timeout=5.0 + ) + + if response.status_code == 200: + logger.debug( + f"Successfully terminated session '{session_id}' for " + f"server '{server_id}'" + ) + elif response.status_code == 405: + logger.debug( + f"Server '{server_id}' does not support session termination " + "(405)" + ) + else: + logger.warning( + f"Unexpected response {response.status_code} when terminating " + f"session for '{server_id}'" + ) except Exception as e: logger.debug( f"Failed to gracefully terminate session for '{server_id}': {e}" ) finally: + # Always clean up session state del self._sessions[server_id] # Remove server configuration del self._servers[server_id] logger.debug(f"Disconnected from server '{server_id}'") - async def __aenter__(self): - """Async context manager entry.""" - return self + async def close(self) -> None: + """Close the HTTP client and clean up all resources. + + Cancels all active streams and closes the underlying HTTP client. + Safe to call multiple times. + """ + # Cancel all active streams across all servers + for stream_manager in list(self._stream_managers.values()): + stream_manager.cancel_all_streams() + self._stream_managers.clear() + + # Clear all state + self._servers.clear() + self._sessions.clear() + + # Close the HTTP client + if not self._http_client.is_closed: + await self._http_client.aclose() + logger.debug("HTTP client closed") + + def _build_headers( + self, server_id: str, server_config: dict[str, Any] + ) -> dict[str, str]: + """Build HTTP headers for requests to the server. + + Constructs headers according to the Streamable HTTP spec: + - Content-Type: application/json (for POST body) + - Accept: application/json, text/event-stream (support both response types) + - MCP-Protocol-Version: current protocol version + - Mcp-Session-Id: session ID if we have one for this server + - Any custom headers from server config + + Args: + server_id: Server ID to build headers for + server_config: Server configuration containing custom headers + + Returns: + Complete headers dict for the HTTP request + """ + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "MCP-Protocol-Version": PROTOCOL_VERSION, + } + + # Add session ID if we have one for this server + if server_id in self._sessions: + headers["Mcp-Session-Id"] = self._sessions[server_id] + + # Add any custom headers from server config + headers.update(server_config["headers"]) + + return headers + + async def start_server_stream(self, server_id: str) -> None: + """Start a server-initiated message stream via HTTP GET. + + Opens an SSE stream that allows the server to send requests and + notifications without the client first sending a message. + + Args: + server_id: Server to start stream for + + Raises: + ValueError: If server_id is not registered + ConnectionError: If the server doesn't support server streams (405) + or other HTTP errors occur + """ + if server_id not in self._servers: + raise ValueError(f"Server '{server_id}' is not registered") + + server_config = self._servers[server_id] + endpoint = server_config["endpoint"] - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Async context manager exit - cleanup HTTP client.""" - await self._http_client.aclose() + # Build headers for GET request - only Accept text/event-stream + headers = { + "Accept": "text/event-stream", + "MCP-Protocol-Version": PROTOCOL_VERSION, + } + + # Add session ID if we have one + if server_id in self._sessions: + headers["Mcp-Session-Id"] = self._sessions[server_id] + + # Add custom headers from server config + headers.update(server_config["headers"]) + + try: + # Test if server supports GET streams + response = await self._http_client.get( + endpoint, + headers=headers, + timeout=10.0, # Shorter timeout for initial connection + ) + + if response.status_code == 405: + raise ConnectionError( + f"Server '{server_id}' does not support server-initiated streams " + "(405)" + ) + elif response.status_code == 404: + # Handle session expiry same as in _handle_response + if "Mcp-Session-Id" in headers: + if server_id in self._sessions: + expired_session = self._sessions[server_id] + del self._sessions[server_id] + logger.info( + f"Session '{expired_session}' expired for server " + f"'{server_id}' during GET stream setup - cleared session " + "ID" + ) + raise ConnectionError( + f"Session expired for server '{server_id}'. " + "Must re-initialize with a new InitializeRequest." + ) + else: + raise ConnectionError( + f"Server '{server_id}' returned 404 for GET stream" + ) + elif response.status_code != 200: + response.raise_for_status() + + # Check content type + content_type = response.headers.get("content-type", "") + if "text/event-stream" not in content_type: + raise ConnectionError( + f"Server '{server_id}' returned non-SSE content type: " + f"{content_type}" + ) + + # Create stream manager if needed and start the server stream + if server_id not in self._stream_managers: + self._stream_managers[server_id] = ClientStreamManager( + server_id, self._http_client + ) + + stream_manager = self._stream_managers[server_id] + await stream_manager.create_server_stream( + endpoint, headers, self._message_queue + ) + + logger.debug(f"Started server stream for '{server_id}'") + + except Exception as e: + if isinstance(e, ConnectionError): + raise + raise ConnectionError( + f"Failed to start server stream for '{server_id}': {e}" + ) from e + + def server_messages(self) -> AsyncIterator[ServerMessage]: + """Stream of messages from all servers with explicit server context. + + Yields messages from the internal queue as they arrive from HTTP responses + and SSE streams. This is the main way consumers get messages from servers. + + Yields: + ServerMessage: Message with server ID and metadata + """ + return self._message_queue_iterator() + + async def _message_queue_iterator(self) -> AsyncIterator[ServerMessage]: + """Async iterator that yields messages from the queue. + + Continuously reads from the message queue until the transport is closed. + This runs indefinitely - consumers should break out of the loop when done. + """ + while True: + try: + message = await self._message_queue.get() + yield message + except Exception as e: + logger.error(f"Error reading from message queue: {e}") + break diff --git a/src/conduit/transport/streamable_http/client_stream_manager.py b/src/conduit/transport/streamable_http/client_stream_manager.py new file mode 100644 index 0000000..ff682a9 --- /dev/null +++ b/src/conduit/transport/streamable_http/client_stream_manager.py @@ -0,0 +1,212 @@ +"""Client-side stream management for HTTP transport.""" + +import asyncio +import json +import logging +from typing import Any + +import httpx +from httpx_sse import aconnect_sse + +from conduit.transport.client import ServerMessage + +logger = logging.getLogger(__name__) + + +class ClientStreamManager: + """Manages SSE streams for a single server connection. + + Handles multiple concurrent streams per server, including: + - Request streams (from POST responses) + - Server streams (from GET requests) + - Proper lifecycle management and cleanup + """ + + def __init__(self, server_id: str, http_client: httpx.AsyncClient) -> None: + """Initialize stream manager for a specific server. + + Args: + server_id: ID of the server this manager handles + http_client: HTTP client to use for connections + """ + self.server_id = server_id + self._http_client = http_client + self._active_streams: set[asyncio.Task] = set() + self._stream_counter = 0 + + async def create_response_stream( + self, response: httpx.Response, message_queue: asyncio.Queue[ServerMessage] + ) -> None: + """Create a new SSE stream from an HTTP response. + + Spawns a background task to handle the stream and tracks it for cleanup. + + Args: + response: HTTP response containing the SSE stream + message_queue: Queue to put parsed messages into + """ + self._stream_counter += 1 + stream_id = f"{self.server_id}-response-{self._stream_counter}" + + stream_task = asyncio.create_task( + self._handle_sse_stream(stream_id, response, message_queue), name=stream_id + ) + + self._active_streams.add(stream_task) + stream_task.add_done_callback(self._active_streams.discard) + + logger.debug(f"Created response stream {stream_id}") + + async def create_server_stream( + self, + endpoint: str, + headers: dict[str, str], + message_queue: asyncio.Queue[ServerMessage], + ) -> None: + """Create a new server-initiated SSE stream via GET request. + + Args: + endpoint: Server endpoint URL + headers: Headers to send with GET request + message_queue: Queue to put parsed messages into + """ + self._stream_counter += 1 + stream_id = f"{self.server_id}-server-{self._stream_counter}" + + stream_task = asyncio.create_task( + self._handle_get_stream(stream_id, endpoint, headers, message_queue), + name=stream_id, + ) + + self._active_streams.add(stream_task) + stream_task.add_done_callback(self._active_streams.discard) + + logger.debug(f"Created server stream {stream_id}") + + async def _handle_sse_stream( + self, + stream_id: str, + response: httpx.Response, + message_queue: asyncio.Queue[ServerMessage], + ) -> None: + """Handle an SSE stream from an HTTP response.""" + try: + async with aconnect_sse( + self._http_client, + response.request.method, + str(response.url), + headers=dict(response.request.headers), + ) as event_source: + async for sse_event in event_source.aiter_sse(): + if sse_event.data: + await self._process_sse_event( + stream_id, sse_event, message_queue + ) + + except asyncio.CancelledError: + logger.debug(f"Stream {stream_id} was cancelled") + raise + except Exception as e: + logger.error(f"Stream {stream_id} error: {e}") + finally: + logger.debug(f"Stream {stream_id} closed") + + async def _handle_get_stream( + self, + stream_id: str, + endpoint: str, + headers: dict[str, str], + message_queue: asyncio.Queue[ServerMessage], + ) -> None: + """Handle a server-initiated SSE stream via GET request.""" + try: + async with self._http_client.stream( + "GET", endpoint, headers=headers + ) as response: + if response.status_code != 200: + logger.warning( + f"GET stream {stream_id} failed: {response.status_code}" + ) + return + + if "text/event-stream" not in response.headers.get("content-type", ""): + logger.warning( + f"GET stream {stream_id} not SSE: " + f"{response.headers.get('content-type')}" + ) + return + + async with aconnect_sse( + self._http_client, "GET", endpoint, headers=headers + ) as event_source: + async for sse_event in event_source.aiter_sse(): + if sse_event.data: + await self._process_sse_event( + stream_id, sse_event, message_queue + ) + + except asyncio.CancelledError: + logger.debug(f"GET stream {stream_id} was cancelled") + raise + except Exception as e: + logger.error(f"GET stream {stream_id} error: {e}") + finally: + logger.debug(f"GET stream {stream_id} closed") + + async def _process_sse_event( + self, + stream_id: str, + sse_event: Any, + message_queue: asyncio.Queue[ServerMessage], + ) -> None: + """Process a single SSE event and queue the message.""" + try: + message_data = json.loads(sse_event.data) + + server_message = ServerMessage( + server_id=self.server_id, + payload=message_data, + timestamp=asyncio.get_event_loop().time(), + metadata={ + "stream_id": stream_id, + "sse_event_id": sse_event.id, + } + if sse_event.id + else {"stream_id": stream_id}, + ) + + await message_queue.put(server_message) + + logger.debug( + f"Stream {stream_id} received: {message_data.get('method', 'response')}" + ) + + except json.JSONDecodeError as e: + logger.warning(f"Stream {stream_id} JSON parse error: {e}") + + def cancel_all_streams(self) -> None: + """Cancel all active streams for this server.""" + if not self._active_streams: + return + + cancelled_count = 0 + for stream_task in list(self._active_streams): + if not stream_task.done(): + stream_task.cancel() + cancelled_count += 1 + + if cancelled_count > 0: + logger.debug( + f"Cancelled {cancelled_count} streams for server {self.server_id}" + ) + + @property + def active_stream_count(self) -> int: + """Number of currently active streams.""" + return len([task for task in self._active_streams if not task.done()]) + + def __repr__(self) -> str: + return ( + f"ClientStreamManager(server_id={self.server_id}, " + f"active_streams={self.active_stream_count})" + ) diff --git a/src/conduit/transport/streamable_http/server.py b/src/conduit/transport/streamable_http/server.py index 5ddd00d..bc18263 100644 --- a/src/conduit/transport/streamable_http/server.py +++ b/src/conduit/transport/streamable_http/server.py @@ -107,6 +107,18 @@ async def stop(self) -> None: self._server.should_exit = True await self._server.shutdown() + async def close(self) -> None: + """Close the transport and clean up all resources. + + For HTTP transport, this stops the HTTP server and cleans up all sessions. + Safe to call multiple times. + """ + await self.stop() + + # Clean up sessions and streams + self._session_manager.terminate_all_sessions() + await self._stream_manager.close_all_streams() + # ================================ # Server Transport Interface # ================================ diff --git a/src/conduit/transport/streamable_http/session_manager.py b/src/conduit/transport/streamable_http/session_manager.py index 12a2409..b50f6dd 100644 --- a/src/conduit/transport/streamable_http/session_manager.py +++ b/src/conduit/transport/streamable_http/session_manager.py @@ -69,3 +69,8 @@ def terminate_session(self, session_id: str) -> str | None: def _generate_session_id(self) -> str: """Generate cryptographically secure session ID.""" return secrets.token_urlsafe(32) + + def terminate_all_sessions(self) -> None: + """Terminate all sessions.""" + for session_id in list(self._sessions): + self.terminate_session(session_id) diff --git a/src/conduit/transport/streamable_http/streams.py b/src/conduit/transport/streamable_http/streams.py index 366cffd..0ec348f 100644 --- a/src/conduit/transport/streamable_http/streams.py +++ b/src/conduit/transport/streamable_http/streams.py @@ -201,3 +201,8 @@ def get_active_stream_count(self) -> int: def get_client_stream_count(self, client_id: str) -> int: """Get number of active streams for a client.""" return len(self._client_streams.get(client_id, set())) + + async def close_all_streams(self) -> None: + """Close all streams.""" + for stream_id in list(self._streams): + await self._cleanup_stream(stream_id) diff --git a/tests/client/coordinator/conftest.py b/tests/client/coordinator/conftest.py index 2d953f3..16365d2 100644 --- a/tests/client/coordinator/conftest.py +++ b/tests/client/coordinator/conftest.py @@ -78,6 +78,10 @@ def add_server_message(self, server_id: str, payload: dict[str, Any]) -> None: ) self.server_message_queue.put_nowait(message) + async def close(self) -> None: + """Close the transport and clean up all resources.""" + pass + @pytest.fixture async def mock_transport(): diff --git a/tests/client/session/conftest.py b/tests/client/session/conftest.py index 40e74ae..e63d013 100644 --- a/tests/client/session/conftest.py +++ b/tests/client/session/conftest.py @@ -77,6 +77,10 @@ def clear_sent_messages(self) -> None: """Clear all sent message history.""" self.sent_messages.clear() + async def close(self) -> None: + """Close the transport and clean up all resources.""" + pass + async def yield_to_event_loop(seconds: float = 0.01) -> None: """Let the event loop process pending tasks and callbacks. diff --git a/tests/server/coordinator/conftest.py b/tests/server/coordinator/conftest.py index f291642..cfd126d 100644 --- a/tests/server/coordinator/conftest.py +++ b/tests/server/coordinator/conftest.py @@ -56,6 +56,10 @@ async def _client_message_iterator(self) -> AsyncIterator[ClientMessage]: except asyncio.CancelledError: break + async def close(self) -> None: + """Close the transport and clean up all resources.""" + pass + # Test helpers def add_client_message(self, client_id: str, payload: dict[str, Any]) -> None: """Add a message to the client message queue for testing.""" diff --git a/tests/server/session/conftest.py b/tests/server/session/conftest.py index 4acac3a..1ff97ea 100644 --- a/tests/server/session/conftest.py +++ b/tests/server/session/conftest.py @@ -15,3 +15,7 @@ async def client_messages(self) -> AsyncIterator[ClientMessage]: async def disconnect_client(self, client_id: str) -> None: self.registered_clients.pop(client_id, None) + + async def close(self) -> None: + """Close the transport and clean up all resources.""" + pass