Skip to content

Commit 6e05ad5

Browse files
committed
feat: use contextvar to pass process for easier access
1 parent 70115b9 commit 6e05ad5

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

src/mcp/client/stdio/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextvars
12
import os
23
import sys
34
from contextlib import asynccontextmanager
@@ -6,6 +7,7 @@
67

78
import anyio
89
import anyio.lowlevel
10+
from anyio.abc import Process
911
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1012
from anyio.streams.text import TextReceiveStream
1113
from pydantic import BaseModel, Field
@@ -92,6 +94,9 @@ class StdioServerParameters(BaseModel):
9294
"""
9395

9496

97+
PROCESS_VAR: contextvars.ContextVar[Process] = contextvars.ContextVar("process")
98+
99+
95100
@asynccontextmanager
96101
async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr):
97102
"""
@@ -169,9 +174,13 @@ async def stdin_writer():
169174
):
170175
tg.start_soon(stdout_reader)
171176
tg.start_soon(stdin_writer)
177+
token = None
172178
try:
179+
token = PROCESS_VAR.set(process)
173180
yield read_stream, write_stream
174181
finally:
182+
if token is not None:
183+
PROCESS_VAR.reset(token)
175184
# Clean up process to prevent any dangling orphaned processes
176185
if sys.platform == "win32":
177186
await terminate_windows_process(process)

tests/client/test_stdio.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from mcp.client.stdio import StdioServerParameters, stdio_client
5+
from mcp.client.stdio import PROCESS_VAR, StdioServerParameters, stdio_client
66
from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse
77

88
tee: str = shutil.which("tee") # type: ignore
@@ -34,6 +34,10 @@ async def test_stdio_client():
3434
if len(read_messages) == 2:
3535
break
3636

37+
process = PROCESS_VAR.get()
38+
assert process is not None
39+
assert process.returncode is None
40+
3741
assert len(read_messages) == 2
3842
assert read_messages[0] == JSONRPCMessage(
3943
root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")

0 commit comments

Comments
 (0)