|
7 | 7 | """ |
8 | 8 |
|
9 | 9 | import logging |
10 | | -from collections.abc import Awaitable, Callable |
| 10 | +from collections.abc import AsyncGenerator, Awaitable, Callable |
11 | 11 | from contextlib import asynccontextmanager |
12 | 12 | from dataclasses import dataclass |
13 | 13 | from datetime import timedelta |
|
35 | 35 | SessionMessageOrError = SessionMessage | Exception |
36 | 36 | StreamWriter = MemoryObjectSendStream[SessionMessageOrError] |
37 | 37 | StreamReader = MemoryObjectReceiveStream[SessionMessage] |
38 | | - |
| 38 | +TerminateCallback = Callable[[], Awaitable[None]] |
| 39 | +GetSessionIdCallback = Callable[[], str | None] |
39 | 40 |
|
40 | 41 | MCP_SESSION_ID = "mcp-session-id" |
41 | 42 | LAST_EVENT_ID = "last-event-id" |
@@ -412,16 +413,27 @@ async def streamablehttp_client( |
412 | 413 | headers: dict[str, Any] | None = None, |
413 | 414 | timeout: timedelta = timedelta(seconds=30), |
414 | 415 | sse_read_timeout: timedelta = timedelta(seconds=60 * 5), |
415 | | -): |
| 416 | +) -> AsyncGenerator[ |
| 417 | + tuple[ |
| 418 | + MemoryObjectReceiveStream[SessionMessage | Exception], |
| 419 | + MemoryObjectSendStream[SessionMessage], |
| 420 | + TerminateCallback, |
| 421 | + GetSessionIdCallback, |
| 422 | + ], |
| 423 | + None, |
| 424 | +]: |
416 | 425 | """ |
417 | 426 | Client transport for StreamableHTTP. |
418 | 427 |
|
419 | 428 | `sse_read_timeout` determines how long (in seconds) the client will wait for a new |
420 | 429 | event before disconnecting. All other HTTP operations are controlled by `timeout`. |
421 | 430 |
|
422 | 431 | Yields: |
423 | | - Tuple of (read_stream, write_stream, terminate_callback, |
424 | | - get_session_id_callback) |
| 432 | + Tuple containing: |
| 433 | + - read_stream: Stream for reading messages from the server |
| 434 | + - write_stream: Stream for sending messages to the server |
| 435 | + - terminate_callback: Async function to terminate the session - send DELETE |
| 436 | + - get_session_id_callback: Function to retrieve the current session ID |
425 | 437 | """ |
426 | 438 | transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout) |
427 | 439 |
|
|
0 commit comments