Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 39 additions & 56 deletions mcpgateway/translate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-

'''Location: ./mcpgateway/translate.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Expand Down Expand Up @@ -123,7 +124,7 @@
import shlex
import signal
import sys
from typing import Any, AsyncIterator, cast, Dict, List, Optional, Sequence, Tuple
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple
from urllib.parse import urlencode
import uuid

Expand All @@ -146,6 +147,7 @@
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.applications import Starlette
from starlette.routing import Route
from starlette.types import Receive, Scope, Send

# First-Party
from mcpgateway.services.logging_service import LoggingService
Expand Down Expand Up @@ -380,6 +382,10 @@ async def start(self, additional_env_vars: Optional[Dict[str, str]] = None) -> N
>>> asyncio.run(test_start()) # doctest: +SKIP
True
"""
# Stop existing subprocess before starting a new one
if self._proc is not None:
await self.stop()

LOGGER.info(f"Starting stdio subprocess: {self._cmd}")

# Build environment from base + configured + additional
Expand All @@ -388,12 +394,18 @@ async def start(self, additional_env_vars: Optional[Dict[str, str]] = None) -> N
if additional_env_vars:
env.update(additional_env_vars)

# System-critical environment variables that must never be cleared
system_critical_vars = {"PATH", "HOME", "TMPDIR", "TEMP", "TMP", "USER", "LOGNAME", "SHELL", "LANG", "LC_ALL", "LC_CTYPE", "PYTHONHOME", "PYTHONPATH"}

# Clear any mapped env vars that weren't provided in headers to avoid inheritance
if self._header_mappings:
for env_var_name in self._header_mappings.values():
if env_var_name not in (additional_env_vars or {}):
env[env_var_name] = ""
if env_var_name not in (additional_env_vars or {}) and env_var_name not in system_critical_vars:
# Delete the variable instead of setting to empty string to avoid
# breaking subprocess initialization
env.pop(env_var_name, None)

LOGGER.debug(f"Subprocess environment variables: {list(env.keys())}")
self._proc = await asyncio.create_subprocess_exec(
*shlex.split(self._cmd),
stdin=asyncio.subprocess.PIPE,
Expand All @@ -406,6 +418,8 @@ async def start(self, additional_env_vars: Optional[Dict[str, str]] = None) -> N
if not self._proc.stdin or not self._proc.stdout:
raise RuntimeError(f"Failed to create subprocess with stdin/stdout pipes for command: {self._cmd}")

LOGGER.debug("Subprocess started successfully")

self._stdin = self._proc.stdin
self._pump_task = asyncio.create_task(self._pump_stdout())

Expand Down Expand Up @@ -677,7 +691,7 @@ def _build_fastapi(
# Add CORS middleware if origins specified
if cors_origins:
app.add_middleware(
cast("type", CORSMiddleware),
CORSMiddleware,
allow_origins=cors_origins,
allow_credentials=True,
allow_methods=["*"],
Expand Down Expand Up @@ -1073,7 +1087,7 @@ async def _run_stdio_to_sse(
log_level=log_level,
lifespan="off",
)
server = uvicorn.Server(config)
uvicorn_server = uvicorn.Server(config)

shutting_down = asyncio.Event() # 🔄 make shutdown idempotent

Expand Down Expand Up @@ -1103,24 +1117,15 @@ async def _shutdown() -> None:
await stdio.stop()
# Graceful shutdown by setting the shutdown event
# Use getattr to safely access should_exit attribute
setattr(server, "should_exit", getattr(server, "should_exit", False) or True)
setattr(uvicorn_server, "should_exit", getattr(uvicorn_server, "should_exit", False) or True)

loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
with suppress(NotImplementedError): # Windows lacks add_signal_handler

def shutdown_handler(*args): # pylint: disable=unused-argument
"""Handle shutdown signal by creating shutdown task.

Args:
*args: Signal handler arguments (unused).
"""
asyncio.create_task(_shutdown())

loop.add_signal_handler(sig, shutdown_handler)
loop.add_signal_handler(sig, lambda *_: asyncio.create_task(_shutdown()))

LOGGER.info(f"Bridge ready → http://{host}:{port}{sse_path}")
await server.serve()
await uvicorn_server.serve()
await _shutdown() # final cleanup


Expand Down Expand Up @@ -1377,7 +1382,7 @@ async def _run_stdio_to_streamable_http(
LOGGER.info(f"Starting stdio to streamable HTTP bridge for command: {cmd}")

# Create a simple MCP server that will proxy to stdio subprocess
server = MCPServer(name="stdio-proxy")
mcp_server = MCPServer(name="stdio-proxy")

# Create subprocess for stdio communication
process = await asyncio.create_subprocess_exec(
Expand All @@ -1392,13 +1397,13 @@ async def _run_stdio_to_streamable_http(

# Set up the streamable HTTP session manager with the server
session_manager = StreamableHTTPSessionManager(
app=server,
app=mcp_server,
stateless=stateless,
json_response=json_response,
)

# Create Starlette app to host the streamable HTTP endpoint
async def handle_mcp(request) -> None:
async def handle_mcp(request: Request) -> None:
"""Handle MCP requests via streamable HTTP.

Args:
Expand All @@ -1418,8 +1423,8 @@ async def handle_mcp(request) -> None:
>>> asyncio.run(test_handle())
True
"""
# The session manager handles all the protocol details
await session_manager.handle_request(request.scope, request.receive, request.send)
# The session manager handles all the protocol details - Note: I don't like accessing _send directly -JPS
await session_manager.handle_request(request.scope, request.receive, request._send) # pylint: disable=W0212

routes = [
Route("/mcp", handle_mcp, methods=["GET", "POST"]),
Expand All @@ -1430,12 +1435,8 @@ async def handle_mcp(request) -> None:

# Add CORS middleware if specified
if cors:
# Import here to avoid unnecessary dependency when CORS not used
# Third-Party
from starlette.middleware.cors import CORSMiddleware as StarletteCORS # pylint: disable=import-outside-toplevel

app.add_middleware(
cast("type", StarletteCORS),
CORSMiddleware,
allow_origins=cors,
allow_credentials=True,
allow_methods=["*"],
Expand All @@ -1450,7 +1451,7 @@ async def handle_mcp(request) -> None:
log_level=log_level,
lifespan="off",
)
server = uvicorn.Server(config)
uvicorn_server = uvicorn.Server(config)

shutting_down = asyncio.Event()

Expand All @@ -1466,21 +1467,12 @@ async def _shutdown() -> None:
await asyncio.wait_for(process.wait(), 5)
# Graceful shutdown by setting the shutdown event
# Use getattr to safely access should_exit attribute
setattr(server, "should_exit", getattr(server, "should_exit", False) or True)
setattr(uvicorn_server, "should_exit", getattr(uvicorn_server, "should_exit", False) or True)

loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
with suppress(NotImplementedError): # Windows lacks add_signal_handler

def shutdown_handler(*args): # pylint: disable=unused-argument
"""Handle shutdown signal by creating shutdown task.

Args:
*args: Signal handler arguments (unused).
"""
asyncio.create_task(_shutdown())

loop.add_signal_handler(sig, shutdown_handler)
loop.add_signal_handler(sig, lambda *_: asyncio.create_task(_shutdown()))

# Pump messages between stdio and HTTP
async def pump_stdio_to_http() -> None:
Expand Down Expand Up @@ -1537,7 +1529,7 @@ async def pump_http_to_stdio(data: str) -> None:

try:
LOGGER.info(f"Streamable HTTP bridge ready → http://{host}:{port}/mcp")
await server.serve()
await uvicorn_server.serve()
finally:
pump_task.cancel()
await _shutdown()
Expand Down Expand Up @@ -1816,7 +1808,7 @@ async def _run_multi_protocol_server( # pylint: disable=too-many-positional-arg
# Add CORS middleware if specified
if cors:
app.add_middleware(
cast("type", CORSMiddleware),
CORSMiddleware,
allow_origins=cors,
allow_credentials=True,
allow_methods=["*"],
Expand Down Expand Up @@ -2060,7 +2052,7 @@ async def mcp_post(request: Request) -> Response:
return PlainTextResponse("accepted", status_code=status.HTTP_202_ACCEPTED)

# ASGI wrapper to route GET/other /mcp scopes to streamable_manager.handle_request
async def mcp_asgi_wrapper(scope, receive, send):
async def mcp_asgi_wrapper(scope: Scope, receive: Receive, send: Send) -> None:
"""
ASGI middleware that intercepts HTTP requests to the `/mcp` endpoint.

Expand All @@ -2069,9 +2061,9 @@ async def mcp_asgi_wrapper(scope, receive, send):
passed to the original FastAPI application.

Args:
scope (dict): The ASGI scope dictionary containing request metadata.
receive (Callable): An awaitable that yields incoming ASGI events.
send (Callable): An awaitable used to send ASGI events.
scope (Scope): The ASGI scope dictionary containing request metadata.
receive (Receive): An awaitable that yields incoming ASGI events.
send (Send): An awaitable used to send ASGI events.
"""
if scope.get("type") == "http" and scope.get("path") == "/mcp" and streamable_manager:
# Let StreamableHTTPSessionManager handle session-oriented streaming
Expand All @@ -2082,7 +2074,7 @@ async def mcp_asgi_wrapper(scope, receive, send):
await original_app(scope, receive, send)

# Replace the app used by uvicorn with the ASGI wrapper
app = mcp_asgi_wrapper
app = mcp_asgi_wrapper # type: ignore[assignment]

# ---------------------- Server lifecycle ----------------------
config = uvicorn.Config(
Expand Down Expand Up @@ -2112,16 +2104,7 @@ async def _shutdown() -> None:
loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
with suppress(NotImplementedError):

def shutdown_handler(*args): # pylint: disable=unused-argument
"""Handle shutdown signal by creating shutdown task.

Args:
*args: Signal handler arguments (unused).
"""
asyncio.create_task(_shutdown())

loop.add_signal_handler(sig, shutdown_handler)
loop.add_signal_handler(sig, lambda *_: asyncio.create_task(_shutdown()))

# If we have a streamable manager, start its context so it can accept ASGI /mcp
if streamable_manager:
Expand Down
48 changes: 23 additions & 25 deletions tests/unit/mcpgateway/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,8 +958,12 @@ def _fake_main(argv=None):
assert executed == ["main_called"]


@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_main_function_stdio(monkeypatch, translate):
"""Test main() function with --stdio argument."""
"""Test main() function with --stdio argument.

Note: This test closes coroutines which may generate RuntimeWarnings during garbage collection.
"""
executed: list[str] = []

async def _fake_stdio_runner(*args):
Expand All @@ -982,8 +986,12 @@ def _fake_asyncio_run(coro):
assert "asyncio_run" in executed


@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_main_function_sse(monkeypatch, translate):
"""Test main() function with --sse argument."""
"""Test main() function with --sse argument.

Note: This test closes coroutines which may generate RuntimeWarnings during garbage collection.
"""
executed: list[str] = []

async def _fake_sse_runner(*args):
Expand All @@ -1003,8 +1011,13 @@ def _fake_asyncio_run(coro):
assert "asyncio_run" in executed


@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_main_function_keyboard_interrupt(monkeypatch, translate, capsys):
"""Test main() function handles KeyboardInterrupt gracefully."""
"""Test main() function handles KeyboardInterrupt gracefully.

Note: This test raises KeyboardInterrupt which prevents the coroutine from being awaited,
resulting in a RuntimeWarning during garbage collection. This is expected behavior.
"""

def _raise_keyboard_interrupt(*args):
raise KeyboardInterrupt()
Expand All @@ -1019,8 +1032,13 @@ def _raise_keyboard_interrupt(*args):
assert captured.out == "\n" # Should print newline to restore shell prompt


@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_main_function_not_implemented_error(monkeypatch, translate, capsys):
"""Test main() function handles NotImplementedError."""
"""Test main() function handles NotImplementedError.

Note: This test raises NotImplementedError which prevents the coroutine from being awaited,
resulting in a RuntimeWarning during garbage collection. This is expected behavior.
"""

# def _raise_not_implemented(coro, *a, **kw):
# # close the coroutine if the autouse fixture didn't remove it
Expand Down Expand Up @@ -1405,29 +1423,9 @@ def __init__(self, routes=None):
def add_middleware(self, middleware_class, **kwargs):
calls.append(f"add_middleware_{middleware_class.__name__}")

# Mock Starlette CORS middleware import
class MockCORSMiddleware:
def __init__(self, **kwargs):
pass

# Mock the import path for CORS middleware
# Standard
import types

cors_module = types.ModuleType("cors")
cors_module.CORSMiddleware = MockCORSMiddleware
middleware_module = types.ModuleType("middleware")
middleware_module.cors = cors_module
starlette_module = types.ModuleType("starlette")
starlette_module.middleware = middleware_module

# Standard
import sys

sys.modules["starlette"] = starlette_module
sys.modules["starlette.middleware"] = middleware_module
sys.modules["starlette.middleware.cors"] = cors_module

class MockTask:
def cancel(self):
pass
Expand Down Expand Up @@ -1470,7 +1468,7 @@ async def mock_shutdown():
await translate._run_stdio_to_streamable_http("echo test", 8000, "info", cors=["http://example.com"])

# Verify CORS middleware was added (using our Mock class name)
assert "add_middleware_MockCORSMiddleware" in calls
assert "add_middleware_CORSMiddleware" in calls
finally:
# Clean up sys.modules to avoid affecting other tests
sys.modules.pop("starlette", None)
Expand Down
Loading
Loading