Skip to content

Commit d1bd44b

Browse files
committed
terminate on close instead of callback to terminate
1 parent dacd294 commit d1bd44b

File tree

2 files changed

+20
-19
lines changed

2 files changed

+20
-19
lines changed

src/mcp/client/streamable_http.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
SessionMessageOrError = SessionMessage | Exception
3636
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
3737
StreamReader = MemoryObjectReceiveStream[SessionMessage]
38-
TerminateCallback = Callable[[], Awaitable[None]]
3938
GetSessionIdCallback = Callable[[], str | None]
4039

4140
MCP_SESSION_ID = "mcp-session-id"
@@ -413,11 +412,11 @@ async def streamablehttp_client(
413412
headers: dict[str, Any] | None = None,
414413
timeout: timedelta = timedelta(seconds=30),
415414
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
415+
terminate_on_close: bool = True,
416416
) -> AsyncGenerator[
417417
tuple[
418418
MemoryObjectReceiveStream[SessionMessage | Exception],
419419
MemoryObjectSendStream[SessionMessage],
420-
TerminateCallback,
421420
GetSessionIdCallback,
422421
],
423422
None,
@@ -432,7 +431,6 @@ async def streamablehttp_client(
432431
Tuple containing:
433432
- read_stream: Stream for reading messages from the server
434433
- write_stream: Stream for sending messages to the server
435-
- terminate_callback: Async function to terminate the session - send DELETE
436434
- get_session_id_callback: Function to retrieve the current session ID
437435
"""
438436
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
@@ -461,9 +459,6 @@ def start_get_stream() -> None:
461459
transport.handle_get_stream, client, read_stream_writer
462460
)
463461

464-
async def terminate_session() -> None:
465-
await transport.terminate_session(client)
466-
467462
tg.start_soon(
468463
transport.post_writer,
469464
client,
@@ -477,10 +472,11 @@ async def terminate_session() -> None:
477472
yield (
478473
read_stream,
479474
write_stream,
480-
terminate_session,
481475
transport.get_session_id,
482476
)
483477
finally:
478+
if transport.session_id and terminate_on_close:
479+
await transport.terminate_session(client)
484480
tg.cancel_scope.cancel()
485481
finally:
486482
await read_stream_writer.aclose()

tests/shared/test_streamable_http.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,6 @@ async def initialized_client_session(basic_server, basic_server_url):
779779
read_stream,
780780
write_stream,
781781
_,
782-
_,
783782
):
784783
async with ClientSession(
785784
read_stream,
@@ -796,7 +795,6 @@ async def test_streamablehttp_client_basic_connection(basic_server, basic_server
796795
read_stream,
797796
write_stream,
798797
_,
799-
_,
800798
):
801799
async with ClientSession(
802800
read_stream,
@@ -854,7 +852,6 @@ async def test_streamablehttp_client_session_persistence(
854852
read_stream,
855853
write_stream,
856854
_,
857-
_,
858855
):
859856
async with ClientSession(
860857
read_stream,
@@ -885,7 +882,6 @@ async def test_streamablehttp_client_json_response(
885882
read_stream,
886883
write_stream,
887884
_,
888-
_,
889885
):
890886
async with ClientSession(
891887
read_stream,
@@ -928,7 +924,6 @@ async def message_handler(
928924
read_stream,
929925
write_stream,
930926
_,
931-
_,
932927
):
933928
async with ClientSession(
934929
read_stream, write_stream, message_handler=message_handler
@@ -961,24 +956,36 @@ async def test_streamablehttp_client_session_termination(
961956
):
962957
"""Test client session termination functionality."""
963958

959+
captured_session_id = None
960+
964961
# Create the streamablehttp_client with a custom httpx client to capture headers
965962
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
966963
read_stream,
967964
write_stream,
968-
terminate_session,
969-
_,
965+
get_session_id,
970966
):
971967
async with ClientSession(read_stream, write_stream) as session:
972968
# Initialize the session
973969
result = await session.initialize()
974970
assert isinstance(result, InitializeResult)
971+
captured_session_id = get_session_id()
972+
assert captured_session_id is not None
975973

976974
# Make a request to confirm session is working
977975
tools = await session.list_tools()
978976
assert len(tools.tools) == 3
979977

980-
# After exiting ClientSession context, explicitly terminate the session
981-
await terminate_session()
978+
headers = {}
979+
if captured_session_id:
980+
headers[MCP_SESSION_ID_HEADER] = captured_session_id
981+
982+
async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as (
983+
read_stream,
984+
write_stream,
985+
_,
986+
):
987+
async with ClientSession(read_stream, write_stream) as session:
988+
# Attempt to make a request after termination
982989
with pytest.raises(
983990
McpError,
984991
match="Session terminated",
@@ -1015,10 +1022,9 @@ async def on_resumption_token_update(token: str) -> None:
10151022
captured_resumption_token = token
10161023

10171024
# First, start the client session and begin the long-running tool
1018-
async with streamablehttp_client(f"{server_url}/mcp") as (
1025+
async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as (
10191026
read_stream,
10201027
write_stream,
1021-
_,
10221028
get_session_id,
10231029
):
10241030
async with ClientSession(
@@ -1072,7 +1078,6 @@ async def run_tool():
10721078
read_stream,
10731079
write_stream,
10741080
_,
1075-
_,
10761081
):
10771082
async with ClientSession(
10781083
read_stream, write_stream, message_handler=message_handler

0 commit comments

Comments
 (0)