1111from contextlib import asynccontextmanager
1212from dataclasses import dataclass
1313from datetime import timedelta
14- from typing import Any
1514
1615import anyio
1716import httpx
5251class StreamableHTTPError (Exception ):
5352 """Base exception for StreamableHTTP transport errors."""
5453
55- pass
56-
5754
5855class ResumptionError (StreamableHTTPError ):
5956 """Raised when resumption request is invalid."""
6057
61- pass
62-
6358
6459@dataclass
6560class RequestContext :
@@ -71,7 +66,7 @@ class RequestContext:
7166 session_message : SessionMessage
7267 metadata : ClientMessageMetadata | None
7368 read_stream_writer : StreamWriter
74- sse_read_timeout : timedelta
69+ sse_read_timeout : float
7570
7671
7772class StreamableHTTPTransport :
@@ -80,9 +75,9 @@ class StreamableHTTPTransport:
8075 def __init__ (
8176 self ,
8277 url : str ,
83- headers : dict [str , Any ] | None = None ,
84- timeout : timedelta = timedelta ( seconds = 30 ) ,
85- sse_read_timeout : timedelta = timedelta ( seconds = 60 * 5 ) ,
78+ headers : dict [str , str ] | None = None ,
79+ timeout : float | timedelta = 30 ,
80+ sse_read_timeout : float | timedelta = 60 * 5 ,
8681 auth : httpx .Auth | None = None ,
8782 ) -> None :
8883 """Initialize the StreamableHTTP transport.
@@ -96,10 +91,12 @@ def __init__(
9691 """
9792 self .url = url
9893 self .headers = headers or {}
99- self .timeout = timeout
100- self .sse_read_timeout = sse_read_timeout
94+ self .timeout = timeout .total_seconds () if isinstance (timeout , timedelta ) else timeout
95+ self .sse_read_timeout = (
96+ sse_read_timeout .total_seconds () if isinstance (sse_read_timeout , timedelta ) else sse_read_timeout
97+ )
10198 self .auth = auth
102- self .session_id : str | None = None
99+ self .session_id = None
103100 self .request_headers = {
104101 ACCEPT : f"{ JSON } , { SSE } " ,
105102 CONTENT_TYPE : JSON ,
@@ -160,7 +157,7 @@ async def _handle_sse_event(
160157 return isinstance (message .root , JSONRPCResponse | JSONRPCError )
161158
162159 except Exception as exc :
163- logger .error ( f "Error parsing SSE message: { exc } " )
160+ logger .exception ( "Error parsing SSE message" )
164161 await read_stream_writer .send (exc )
165162 return False
166163 else :
@@ -184,10 +181,7 @@ async def handle_get_stream(
184181 "GET" ,
185182 self .url ,
186183 headers = headers ,
187- timeout = httpx .Timeout (
188- self .timeout .total_seconds (),
189- read = self .sse_read_timeout .total_seconds (),
190- ),
184+ timeout = httpx .Timeout (self .timeout , read = self .sse_read_timeout ),
191185 ) as event_source :
192186 event_source .response .raise_for_status ()
193187 logger .debug ("GET SSE connection established" )
@@ -216,10 +210,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
216210 "GET" ,
217211 self .url ,
218212 headers = headers ,
219- timeout = httpx .Timeout (
220- self .timeout .total_seconds (),
221- read = ctx .sse_read_timeout .total_seconds (),
222- ),
213+ timeout = httpx .Timeout (self .timeout , read = self .sse_read_timeout ),
223214 ) as event_source :
224215 event_source .response .raise_for_status ()
225216 logger .debug ("Resumption GET SSE connection established" )
@@ -412,9 +403,9 @@ def get_session_id(self) -> str | None:
412403@asynccontextmanager
413404async def streamablehttp_client (
414405 url : str ,
415- headers : dict [str , Any ] | None = None ,
416- timeout : timedelta = timedelta ( seconds = 30 ) ,
417- sse_read_timeout : timedelta = timedelta ( seconds = 60 * 5 ) ,
406+ headers : dict [str , str ] | None = None ,
407+ timeout : float | timedelta = 30 ,
408+ sse_read_timeout : float | timedelta = 60 * 5 ,
418409 terminate_on_close : bool = True ,
419410 httpx_client_factory : McpHttpClientFactory = create_mcp_http_client ,
420411 auth : httpx .Auth | None = None ,
@@ -449,10 +440,7 @@ async def streamablehttp_client(
449440
450441 async with httpx_client_factory (
451442 headers = transport .request_headers ,
452- timeout = httpx .Timeout (
453- transport .timeout .total_seconds (),
454- read = transport .sse_read_timeout .total_seconds (),
455- ),
443+ timeout = httpx .Timeout (transport .timeout , read = transport .sse_read_timeout ),
456444 auth = transport .auth ,
457445 ) as client :
458446 # Define callbacks that need access to tg
0 commit comments