|
24 | 24 | from mcp.client.streamable_http import streamablehttp_client |
25 | 25 | from mcp.server.fastmcp import FastMCP |
26 | 26 | from mcp.server.fastmcp.resources import FunctionResource |
27 | | -from mcp.server.fastmcp.server import Context |
28 | 27 | from mcp.shared.context import RequestContext |
29 | 28 | from mcp.types import ( |
30 | 29 | CreateMessageRequestParams, |
@@ -196,6 +195,33 @@ def complex_prompt(user_query: str, context: str = "general") -> str: |
196 | 195 | # Since FastMCP doesn't support system messages in the same way |
197 | 196 | return f"Context: {context}. Query: {user_query}" |
198 | 197 |
|
| 198 | + # Tool that echoes request headers from context |
| 199 | + @mcp.tool(description="Echo request headers from context") |
| 200 | + def echo_headers(ctx: Context[Any, Any, Request]) -> str: |
| 201 | + """Returns the request headers as JSON.""" |
| 202 | + headers_info = {} |
| 203 | + if ctx.request_context.request: |
| 204 | + # Now the type system knows request is a Starlette Request object |
| 205 | + headers_info = dict(ctx.request_context.request.headers) |
| 206 | + return json.dumps(headers_info) |
| 207 | + |
| 208 | + # Tool that returns full request context |
| 209 | + @mcp.tool(description="Echo request context with custom data") |
| 210 | + def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str: |
| 211 | + """Returns request context including headers and custom data.""" |
| 212 | + context_data = { |
| 213 | + "custom_request_id": custom_request_id, |
| 214 | + "headers": {}, |
| 215 | + "method": None, |
| 216 | + "path": None, |
| 217 | + } |
| 218 | + if ctx.request_context.request: |
| 219 | + request = ctx.request_context.request |
| 220 | + context_data["headers"] = dict(request.headers) |
| 221 | + context_data["method"] = request.method |
| 222 | + context_data["path"] = request.url.path |
| 223 | + return json.dumps(context_data) |
| 224 | + |
199 | 225 | return mcp |
200 | 226 |
|
201 | 227 |
|
@@ -432,174 +458,6 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None: |
432 | 458 | assert tool_result.content[0].text == "Echo: hello" |
433 | 459 |
|
434 | 460 |
|
435 | | -def make_fastmcp_with_context_app(): |
436 | | - """Create a FastMCP server that can access request context.""" |
437 | | - |
438 | | - mcp = FastMCP(name="ContextServer") |
439 | | - |
440 | | - # Tool that echoes request headers |
441 | | - @mcp.tool(description="Echo request headers from context") |
442 | | - def echo_headers(ctx: Context[Any, Any, Request]) -> str: |
443 | | - """Returns the request headers as JSON.""" |
444 | | - headers_info = {} |
445 | | - if ctx.request_context.request: |
446 | | - # Now the type system knows request is a Starlette Request object |
447 | | - headers_info = dict(ctx.request_context.request.headers) |
448 | | - return json.dumps(headers_info) |
449 | | - |
450 | | - # Tool that returns full request context |
451 | | - @mcp.tool(description="Echo request context with custom data") |
452 | | - def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str: |
453 | | - """Returns request context including headers and custom data.""" |
454 | | - context_data = { |
455 | | - "custom_request_id": custom_request_id, |
456 | | - "headers": {}, |
457 | | - "method": None, |
458 | | - "path": None, |
459 | | - } |
460 | | - if ctx.request_context.request: |
461 | | - request = ctx.request_context.request |
462 | | - context_data["headers"] = dict(request.headers) |
463 | | - context_data["method"] = request.method |
464 | | - context_data["path"] = request.url.path |
465 | | - return json.dumps(context_data) |
466 | | - |
467 | | - # Create the SSE app |
468 | | - app = mcp.sse_app() |
469 | | - return mcp, app |
470 | | - |
471 | | - |
472 | | -def run_context_server(server_port: int) -> None: |
473 | | - """Run the context-aware FastMCP server.""" |
474 | | - _, app = make_fastmcp_with_context_app() |
475 | | - server = uvicorn.Server( |
476 | | - config=uvicorn.Config( |
477 | | - app=app, host="127.0.0.1", port=server_port, log_level="error" |
478 | | - ) |
479 | | - ) |
480 | | - print(f"Starting context server on port {server_port}") |
481 | | - server.run() |
482 | | - |
483 | | - |
484 | | -@pytest.fixture() |
485 | | -def context_aware_server(server_port: int) -> Generator[None, None, None]: |
486 | | - """Start the context-aware server in a separate process.""" |
487 | | - proc = multiprocessing.Process( |
488 | | - target=run_context_server, args=(server_port,), daemon=True |
489 | | - ) |
490 | | - print("Starting context-aware server process") |
491 | | - proc.start() |
492 | | - |
493 | | - # Wait for server to be running |
494 | | - max_attempts = 20 |
495 | | - attempt = 0 |
496 | | - print("Waiting for context-aware server to start") |
497 | | - while attempt < max_attempts: |
498 | | - try: |
499 | | - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: |
500 | | - s.connect(("127.0.0.1", server_port)) |
501 | | - break |
502 | | - except ConnectionRefusedError: |
503 | | - time.sleep(0.1) |
504 | | - attempt += 1 |
505 | | - else: |
506 | | - raise RuntimeError( |
507 | | - f"Context server failed to start after {max_attempts} attempts" |
508 | | - ) |
509 | | - |
510 | | - yield |
511 | | - |
512 | | - print("Killing context-aware server") |
513 | | - proc.kill() |
514 | | - proc.join(timeout=2) |
515 | | - if proc.is_alive(): |
516 | | - print("Context server process failed to terminate") |
517 | | - |
518 | | - |
519 | | -@pytest.mark.anyio |
520 | | -async def test_fast_mcp_with_request_context( |
521 | | - context_aware_server: None, server_url: str |
522 | | -) -> None: |
523 | | - """Test that FastMCP properly propagates request context to tools.""" |
524 | | - # Test with custom headers |
525 | | - custom_headers = { |
526 | | - "Authorization": "Bearer fastmcp-test-token", |
527 | | - "X-Custom-Header": "fastmcp-value", |
528 | | - "X-Request-Id": "req-123", |
529 | | - } |
530 | | - |
531 | | - async with sse_client(server_url + "/sse", headers=custom_headers) as streams: |
532 | | - async with ClientSession(*streams) as session: |
533 | | - # Initialize the session |
534 | | - result = await session.initialize() |
535 | | - assert isinstance(result, InitializeResult) |
536 | | - assert result.serverInfo.name == "ContextServer" |
537 | | - |
538 | | - # Test 1: Call tool that echoes headers |
539 | | - headers_result = await session.call_tool("echo_headers", {}) |
540 | | - assert len(headers_result.content) == 1 |
541 | | - assert isinstance(headers_result.content[0], TextContent) |
542 | | - |
543 | | - headers_data = json.loads(headers_result.content[0].text) |
544 | | - assert headers_data.get("authorization") == "Bearer fastmcp-test-token" |
545 | | - assert headers_data.get("x-custom-header") == "fastmcp-value" |
546 | | - assert headers_data.get("x-request-id") == "req-123" |
547 | | - |
548 | | - # Test 2: Call tool that returns full context |
549 | | - context_result = await session.call_tool( |
550 | | - "echo_context", {"custom_request_id": "test-123"} |
551 | | - ) |
552 | | - assert len(context_result.content) == 1 |
553 | | - assert isinstance(context_result.content[0], TextContent) |
554 | | - |
555 | | - context_data = json.loads(context_result.content[0].text) |
556 | | - assert context_data["custom_request_id"] == "test-123" |
557 | | - assert ( |
558 | | - context_data["headers"].get("authorization") |
559 | | - == "Bearer fastmcp-test-token" |
560 | | - ) |
561 | | - assert context_data["method"] == "POST" # |
562 | | - |
563 | | - |
564 | | -@pytest.mark.anyio |
565 | | -async def test_fast_mcp_request_context_isolation( |
566 | | - context_aware_server: None, server_url: str |
567 | | -) -> None: |
568 | | - """Test that request contexts are isolated between different FastMCP clients.""" |
569 | | - contexts = [] |
570 | | - |
571 | | - # Create multiple clients with different headers |
572 | | - for i in range(3): |
573 | | - headers = { |
574 | | - "Authorization": f"Bearer token-{i}", |
575 | | - "X-Request-Id": f"fastmcp-req-{i}", |
576 | | - "X-Custom-Value": f"value-{i}", |
577 | | - } |
578 | | - |
579 | | - async with sse_client(server_url + "/sse", headers=headers) as streams: |
580 | | - async with ClientSession(*streams) as session: |
581 | | - await session.initialize() |
582 | | - |
583 | | - # Call the tool that returns context |
584 | | - tool_result = await session.call_tool( |
585 | | - "echo_context", {"custom_request_id": f"test-req-{i}"} |
586 | | - ) |
587 | | - |
588 | | - # Parse and store the result |
589 | | - assert len(tool_result.content) == 1 |
590 | | - assert isinstance(tool_result.content[0], TextContent) |
591 | | - context_data = json.loads(tool_result.content[0].text) |
592 | | - contexts.append(context_data) |
593 | | - |
594 | | - # Verify each request had its own isolated context |
595 | | - assert len(contexts) == 3 |
596 | | - for i, ctx in enumerate(contexts): |
597 | | - assert ctx["custom_request_id"] == f"test-req-{i}" |
598 | | - assert ctx["headers"].get("authorization") == f"Bearer token-{i}" |
599 | | - assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}" |
600 | | - assert ctx["headers"].get("x-custom-value") == f"value-{i}" |
601 | | - |
602 | | - |
603 | 461 | @pytest.mark.anyio |
604 | 462 | async def test_fastmcp_streamable_http( |
605 | 463 | streamable_http_server: None, http_server_url: str |
@@ -967,6 +825,30 @@ async def progress_callback( |
967 | 825 | assert isinstance(complex_result, GetPromptResult) |
968 | 826 | assert len(complex_result.messages) >= 1 |
969 | 827 |
|
| 828 | + # Test request context propagation (only works when headers are available) |
| 829 | + |
| 830 | + headers_result = await session.call_tool("echo_headers", {}) |
| 831 | + assert len(headers_result.content) == 1 |
| 832 | + assert isinstance(headers_result.content[0], TextContent) |
| 833 | + |
| 834 | + # If we got headers, verify they exist |
| 835 | + headers_data = json.loads(headers_result.content[0].text) |
| 836 | + # The headers depend on the transport and test setup |
| 837 | + print(f"Received headers: {headers_data}") |
| 838 | + |
| 839 | + # Test 6: Call tool that returns full context |
| 840 | + context_result = await session.call_tool( |
| 841 | + "echo_context", {"custom_request_id": "test-123"} |
| 842 | + ) |
| 843 | + assert len(context_result.content) == 1 |
| 844 | + assert isinstance(context_result.content[0], TextContent) |
| 845 | + |
| 846 | + context_data = json.loads(context_result.content[0].text) |
| 847 | + assert context_data["custom_request_id"] == "test-123" |
| 848 | + # The method should be POST for most transports |
| 849 | + if context_data["method"]: |
| 850 | + assert context_data["method"] == "POST" |
| 851 | + |
970 | 852 |
|
971 | 853 | async def sampling_callback( |
972 | 854 | context: RequestContext[ClientSession, None], |
|
0 commit comments