diff --git a/README.md b/README.md index 5dbc4bd9d..5cbb6510f 100644 --- a/README.md +++ b/README.md @@ -208,7 +208,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession # Mock database class for example @@ -254,7 +254,7 @@ mcp = FastMCP("My App", lifespan=app_lifespan) # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context[ServerSession, AppContext]) -> str: +def query_db(ctx: Context[ServerTransportSession, AppContext]) -> str: """Tool that uses initialized resources.""" db = ctx.request_context.lifespan_context.db return db.query() @@ -326,13 +326,13 @@ Tools can optionally receive a Context object by including a parameter with the ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -674,13 +674,13 @@ The Context object provides the following capabilities: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") @@ -798,7 +798,7 @@ Request additional information from users. This example shows an Elicitation dur from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") @@ -814,7 +814,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerTransportSession, None]) -> str: """Book a table with date availability check.""" # Check if date is available if date == "2024-12-25": @@ -888,13 +888,13 @@ Tools can send logs and notifications through the context: ```python from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context[ServerTransportSession, None]) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") @@ -2037,7 +2037,7 @@ import os from pydantic import AnyUrl -from mcp import ClientSession, StdioServerParameters, types +from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext @@ -2051,7 +2051,7 @@ server_params = StdioServerParameters( # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams + context: RequestContext[ClientTransportSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( @@ -2167,7 +2167,7 @@ cd to the `examples/snippets` directory and run: import asyncio import os -from mcp import ClientSession, StdioServerParameters +from mcp import ClientSession, ClientTransportSession, StdioServerParameters from mcp.client.stdio import stdio_client from mcp.shared.metadata_utils import get_display_name @@ -2179,7 +2179,7 @@ server_params = StdioServerParameters( ) -async def display_tools(session: ClientSession): +async def display_tools(session: ClientTransportSession): """Display available tools with human-readable names""" tools_response = await session.list_tools() @@ -2191,7 +2191,7 @@ async def display_tools(session: ClientSession): print(f" {tool.description}") -async def display_resources(session: ClientSession): +async def display_resources(session: ClientTransportSession): """Display available resources with human-readable names""" resources_response = await session.list_resources() diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 5987a878e..6c7201e04 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -17,7 +17,7 @@ from urllib.parse import parse_qs, urlparse from mcp.client.auth import OAuthClientProvider, TokenStorage -from mcp.client.session import ClientSession +from mcp.client.session import ClientSession, ClientTransportSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken @@ -153,7 +153,7 @@ class SimpleAuthClient: def __init__(self, server_url: str, transport_type: str = "streamable-http"): self.server_url = server_url self.transport_type = transport_type - self.session: ClientSession | None = None + self.session: ClientTransportSession | None = None async def connect(self): """Connect to the MCP server.""" diff --git a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py index 78a81a4d9..3a9d201b1 100644 --- a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py +++ b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py @@ -10,6 +10,7 @@ from dotenv import load_dotenv from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.client.transport_session import ClientTransportSession # Configure logging logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") @@ -67,7 +68,7 @@ def __init__(self, name: str, config: dict[str, Any]) -> None: self.name: str = name self.config: dict[str, Any] = config self.stdio_context: Any | None = None - self.session: ClientSession | None = None + self.session: ClientTransportSession | None = None self._cleanup_lock: asyncio.Lock = asyncio.Lock() self.exit_stack: AsyncExitStack = AsyncExitStack() diff --git a/examples/snippets/clients/display_utilities.py b/examples/snippets/clients/display_utilities.py index 5f1d50510..b8ad7dffc 100644 --- a/examples/snippets/clients/display_utilities.py +++ b/examples/snippets/clients/display_utilities.py @@ -6,7 +6,7 @@ import asyncio import os -from mcp import ClientSession, StdioServerParameters +from mcp import ClientSession, ClientTransportSession, StdioServerParameters from mcp.client.stdio import stdio_client from mcp.shared.metadata_utils import get_display_name @@ -18,7 +18,7 @@ ) -async def display_tools(session: ClientSession): +async def display_tools(session: ClientTransportSession): """Display available tools with human-readable names""" tools_response = await session.list_tools() @@ -30,7 +30,7 @@ async def display_tools(session: ClientSession): print(f" {tool.description}") -async def display_resources(session: ClientSession): +async def display_resources(session: ClientTransportSession): """Display available resources with human-readable names""" resources_response = await session.list_resources() diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index ac978035d..90f9fdff9 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -8,7 +8,7 @@ from pydantic import AnyUrl -from mcp import ClientSession, StdioServerParameters, types +from mcp import ClientSession, ClientTransportSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext @@ -22,7 +22,7 @@ # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams + context: RequestContext[ClientTransportSession, None], params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( diff --git a/examples/snippets/servers/elicitation.py b/examples/snippets/servers/elicitation.py index 2c8a3b35a..45f2cb68b 100644 --- a/examples/snippets/servers/elicitation.py +++ b/examples/snippets/servers/elicitation.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Elicitation Example") @@ -17,7 +17,7 @@ class BookingPreferences(BaseModel): @mcp.tool() -async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str: +async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerTransportSession, None]) -> str: """Book a table with date availability check.""" # Check if date is available if date == "2024-12-25": diff --git a/examples/snippets/servers/lifespan_example.py b/examples/snippets/servers/lifespan_example.py index 62278b6aa..46f01f427 100644 --- a/examples/snippets/servers/lifespan_example.py +++ b/examples/snippets/servers/lifespan_example.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession # Mock database class for example @@ -51,7 +51,7 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: # Access type-safe lifespan context in tools @mcp.tool() -def query_db(ctx: Context[ServerSession, AppContext]) -> str: +def query_db(ctx: Context[ServerTransportSession, AppContext]) -> str: """Tool that uses initialized resources.""" db = ctx.request_context.lifespan_context.db return db.query() diff --git a/examples/snippets/servers/notifications.py b/examples/snippets/servers/notifications.py index 833bc8905..995ecd817 100644 --- a/examples/snippets/servers/notifications.py +++ b/examples/snippets/servers/notifications.py @@ -1,11 +1,11 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Notifications Example") @mcp.tool() -async def process_data(data: str, ctx: Context[ServerSession, None]) -> str: +async def process_data(data: str, ctx: Context[ServerTransportSession, None]) -> str: """Process data with logging.""" # Different log levels await ctx.debug(f"Debug: Processing '{data}'") diff --git a/examples/snippets/servers/tool_progress.py b/examples/snippets/servers/tool_progress.py index 2ac458f6a..a0f62fda6 100644 --- a/examples/snippets/servers/tool_progress.py +++ b/examples/snippets/servers/tool_progress.py @@ -1,11 +1,11 @@ from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession mcp = FastMCP(name="Progress Example") @mcp.tool() -async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str: +async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str: """Execute a task with progress updates.""" await ctx.info(f"Starting: {task_name}") diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index e93b95c90..93ef8acdf 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -1,8 +1,10 @@ from .client.session import ClientSession from .client.session_group import ClientSessionGroup from .client.stdio import StdioServerParameters, stdio_client +from .client.transport_session import ClientTransportSession from .server.session import ServerSession from .server.stdio import stdio_server +from .server.transport_session import ServerTransportSession from .shared.exceptions import McpError from .types import ( CallToolRequest, @@ -113,4 +115,6 @@ "stdio_server", "CompleteRequest", "JSONRPCResponse", + "ClientTransportSession", + "ServerTransportSession", ] diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3835a2a57..0bd4e9608 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -9,6 +9,7 @@ from typing_extensions import deprecated import mcp.types as types +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder @@ -22,7 +23,7 @@ class SamplingFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: ... # pragma: no branch @@ -30,14 +31,14 @@ async def __call__( class ElicitationFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch class ListRootsFnT(Protocol): async def __call__( - self, context: RequestContext["ClientSession", Any] + self, context: RequestContext["ClientTransportSession", Any] ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch @@ -62,7 +63,7 @@ async def _default_message_handler( async def _default_sampling_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.ErrorData( @@ -72,7 +73,7 @@ async def _default_sampling_callback( async def _default_elicitation_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: return types.ErrorData( # pragma: no cover @@ -82,7 +83,7 @@ async def _default_elicitation_callback( async def _default_list_roots_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], ) -> types.ListRootsResult | types.ErrorData: return types.ErrorData( code=types.INVALID_REQUEST, @@ -100,13 +101,14 @@ async def _default_logging_callback( class ClientSession( + ClientTransportSession, BaseSession[ types.ClientRequest, types.ClientNotification, types.ClientResult, types.ServerRequest, types.ServerNotification, - ] + ], ): def __init__( self, @@ -508,7 +510,7 @@ async def send_roots_list_changed(self) -> None: # pragma: no cover await self.send_notification(types.ClientNotification(types.RootsListChangedNotification())) async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: - ctx = RequestContext[ClientSession, Any]( + ctx = RequestContext[ClientTransportSession, Any]( request_id=responder.request_id, meta=responder.request_meta, session=self, diff --git a/src/mcp/client/transport_session.py b/src/mcp/client/transport_session.py new file mode 100644 index 000000000..07389d59a --- /dev/null +++ b/src/mcp/client/transport_session.py @@ -0,0 +1,130 @@ +from abc import ABC, abstractmethod +from datetime import timedelta +from typing import Any + +from pydantic import AnyUrl + +import mcp.types as types +from mcp.shared.session import ProgressFnT + + +class ClientTransportSession(ABC): + """Abstract base class for communication transports.""" + + @abstractmethod + async def initialize(self) -> types.InitializeResult: + """Send an initialize request.""" + raise NotImplementedError + + @abstractmethod + async def send_ping(self) -> types.EmptyResult: + """Send a ping request.""" + raise NotImplementedError + + @abstractmethod + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + ) -> None: + """Send a progress notification.""" + raise NotImplementedError + + @abstractmethod + async def set_logging_level( + self, + level: types.LoggingLevel, + ) -> types.EmptyResult: + """Send a logging/setLevel request.""" + raise NotImplementedError + + @abstractmethod + async def list_resources( + self, + cursor: str | None = None, + ) -> types.ListResourcesResult: + """Send a resources/list request.""" + raise NotImplementedError + + @abstractmethod + async def list_resource_templates( + self, + cursor: str | None = None, + ) -> types.ListResourceTemplatesResult: + """Send a resources/templates/list request.""" + raise NotImplementedError + + @abstractmethod + async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + """Send a resources/read request.""" + raise NotImplementedError + + @abstractmethod + async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/subscribe request.""" + raise NotImplementedError + + @abstractmethod + async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/unsubscribe request.""" + raise NotImplementedError + + @abstractmethod + async def call_tool( + self, + name: str, + arguments: Any | None = None, + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + ) -> types.CallToolResult: + """Send a tools/call request with optional progress callback support.""" + raise NotImplementedError + + @abstractmethod + async def list_prompts( + self, + cursor: str | None = None, + ) -> types.ListPromptsResult: + """Send a prompts/list request.""" + raise NotImplementedError + + @abstractmethod + async def get_prompt( + self, + name: str, + arguments: dict[str, str] | None = None, + ) -> types.GetPromptResult: + """Send a prompts/get request.""" + raise NotImplementedError + + @abstractmethod + async def complete( + self, + ref: types.ResourceTemplateReference | types.PromptReference, + argument: dict[str, str], + context_arguments: dict[str, str] | None = None, + ) -> types.CompleteResult: + """Send a completion/complete request.""" + raise NotImplementedError + + @abstractmethod + async def list_tools( + self, + cursor: str | None = None, + *, + params: types.PaginatedRequestParams | None = None, + ) -> types.ListToolsResult: + """Send a tools/list request. + + Args: + cursor: Simple cursor string for pagination (deprecated, use params instead) + params: Full pagination parameters including cursor and any future fields + """ + raise NotImplementedError + + @abstractmethod + async def send_roots_list_changed(self) -> None: + """Send a roots/list_changed notification.""" + raise NotImplementedError diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index bba988f49..b2f33ec7c 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.types import RequestId ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) @@ -74,7 +74,7 @@ def _is_primitive_field(field_info: FieldInfo) -> bool: async def elicit_with_validation( - session: ServerSession, + session: ServerTransportSession, message: str, schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 865b8e7e7..cc05403dd 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -54,12 +54,13 @@ from mcp.server.lowlevel.server import LifespanResultT from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan -from mcp.server.session import ServerSession, ServerSessionT +from mcp.server.session import ServerSessionT from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import LifespanContextT, RequestContext, RequestT from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations from mcp.types import Prompt as MCPPrompt @@ -315,7 +316,7 @@ async def list_tools(self) -> list[MCPTool]: for info in tools ] - def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: + def get_context(self) -> Context[ServerTransportSession, LifespanResultT, Request]: """ Returns a Context object. Note that the context will only be valid during a request; outside a request, most methods will error. diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 49d289fb7..85846afc6 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -86,6 +86,7 @@ async def main(): from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ServerMessageMetadata, SessionMessage @@ -102,7 +103,9 @@ async def main(): CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] # This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx") +request_ctx: contextvars.ContextVar[RequestContext[ServerTransportSession, Any, Any]] = contextvars.ContextVar( + "request_ctx" +) class NotificationOptions: @@ -231,7 +234,7 @@ def get_capabilities( @property def request_context( self, - ) -> RequestContext[ServerSession, LifespanResultT, RequestT]: + ) -> RequestContext[ServerTransportSession, LifespanResultT, RequestT]: """If called outside of a request context, this will raise a LookupError.""" return request_ctx.get() diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index a1bfadc9f..9456ebf9f 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -47,6 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.models import InitializationOptions +from mcp.server.transport_session import ServerTransportSession from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, @@ -61,7 +62,7 @@ class InitializationState(Enum): Initialized = 3 -ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") +ServerSessionT = TypeVar("ServerSessionT", bound="ServerTransportSession") ServerRequestResponder = ( RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception @@ -69,13 +70,14 @@ class InitializationState(Enum): class ServerSession( + ServerTransportSession, BaseSession[ types.ServerRequest, types.ServerNotification, types.ServerResult, types.ClientRequest, types.ClientNotification, - ] + ], ): _initialized: InitializationState = InitializationState.NotInitialized _client_params: types.InitializeRequestParams | None = None diff --git a/src/mcp/server/transport_session.py b/src/mcp/server/transport_session.py new file mode 100644 index 000000000..bf3f6a1d1 --- /dev/null +++ b/src/mcp/server/transport_session.py @@ -0,0 +1,75 @@ +"""Abstract base class for transport sessions.""" + +from abc import ABC, abstractmethod +from typing import Any + +from pydantic import AnyUrl + +import mcp.types as types + + +class ServerTransportSession(ABC): + """Abstract base class for transport sessions.""" + + @abstractmethod + async def send_log_message( + self, + level: types.LoggingLevel, + data: Any, + logger: str | None = None, + related_request_id: types.RequestId | None = None, + ) -> None: + """Send a log message notification.""" + raise NotImplementedError + + @abstractmethod + async def send_resource_updated(self, uri: AnyUrl) -> None: + """Send a resource updated notification.""" + raise NotImplementedError + + @abstractmethod + async def list_roots(self) -> types.ListRootsResult: + """Send a roots/list request.""" + raise NotImplementedError + + @abstractmethod + async def elicit( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send an elicitation/create request.""" + raise NotImplementedError + + @abstractmethod + async def send_ping(self) -> types.EmptyResult: + """Send a ping request.""" + raise NotImplementedError + + @abstractmethod + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + related_request_id: str | None = None, + ) -> None: + """Send a progress notification.""" + raise NotImplementedError + + @abstractmethod + async def send_resource_list_changed(self) -> None: + """Send a resource list changed notification.""" + raise NotImplementedError + + @abstractmethod + async def send_tool_list_changed(self) -> None: + """Send a tool list changed notification.""" + raise NotImplementedError + + @abstractmethod + async def send_prompt_list_changed(self) -> None: + """Send a prompt list changed notification.""" + raise NotImplementedError diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index f3006e7d5..7267f4954 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -1,12 +1,17 @@ from dataclasses import dataclass -from typing import Any, Generic +from typing import TYPE_CHECKING, Any, Generic from typing_extensions import TypeVar from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParams -SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) +if TYPE_CHECKING: + from mcp import ClientTransportSession, ServerTransportSession + +SessionT = TypeVar( + "SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession" +) LifespanContextT = TypeVar("LifespanContextT") RequestT = TypeVar("RequestT", default=Any) diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 06d404e31..2d203d743 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -13,7 +13,15 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.types as types -from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.session import ( + ClientSession, + ElicitationFnT, + ListRootsFnT, + LoggingFnT, + MessageHandlerFnT, + SamplingFnT, +) +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.fastmcp import FastMCP from mcp.shared.message import SessionMessage @@ -57,7 +65,7 @@ async def create_connected_server_and_client_session( client_info: types.Implementation | None = None, raise_exceptions: bool = False, elicitation_callback: ElicitationFnT | None = None, -) -> AsyncGenerator[ClientSession, None]: +) -> AsyncGenerator[ClientTransportSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" # TODO(Marcelo): we should have a proper `Client` that can use this "in-memory transport", diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 0da0fff07..5acb3b21a 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -1,7 +1,7 @@ import pytest from pydantic import FileUrl -from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp.server import Context from mcp.server.session import ServerSession from mcp.shared.context import RequestContext @@ -31,7 +31,7 @@ async def test_list_roots_callback(): ) async def list_roots_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientTransportSession, None], ) -> ListRootsResult: return callback_return diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index a3f6affda..9fb6e29c7 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,6 +1,7 @@ import pytest -from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession +from mcp.server.session import ServerSession from mcp.shared.context import RequestContext from mcp.shared.memory import ( create_connected_server_and_client_session as create_session, @@ -27,14 +28,16 @@ async def test_sampling_callback(): ) async def sampling_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientTransportSession, None], params: CreateMessageRequestParams, ) -> CreateMessageResult: return callback_return @server.tool("test_sampling") async def test_sampling_tool(message: str): - value = await server.get_context().session.create_message( + session = server.get_context().session + assert isinstance(session, ServerSession) + value = await session.create_message( messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 8d0ef68a9..bd51e4e10 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -5,6 +5,7 @@ import mcp.types as types from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder @@ -427,7 +428,7 @@ async def test_client_capabilities_with_custom_callbacks(): received_capabilities = None async def custom_sampling_callback( # pragma: no cover - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.CreateMessageResult( @@ -437,7 +438,7 @@ async def custom_sampling_callback( # pragma: no cover ) async def custom_list_roots_callback( # pragma: no cover - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientTransportSession", Any], ) -> types.ListRootsResult | types.ErrorData: return types.ListRootsResult(roots=[]) diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py index 2c74d0e88..52e6799b7 100644 --- a/tests/server/fastmcp/test_elicitation.py +++ b/tests/server/fastmcp/test_elicitation.py @@ -7,9 +7,10 @@ import pytest from pydantic import BaseModel, Field -from mcp.client.session import ClientSession, ElicitationFnT +from mcp.client.session import ElicitationFnT +from mcp.client.transport_session import ClientTransportSession from mcp.server.fastmcp import Context, FastMCP -from mcp.server.session import ServerSession +from mcp.server.transport_session import ServerTransportSession from mcp.shared.context import RequestContext from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -24,7 +25,7 @@ def create_ask_user_tool(mcp: FastMCP): """Create a standard ask_user tool that handles all elicitation responses.""" @mcp.tool(description="A tool that uses elicitation") - async def ask_user(prompt: str, ctx: Context[ServerSession, None]) -> str: + async def ask_user(prompt: str, ctx: Context[ServerTransportSession, None]) -> str: result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) if result.action == "accept" and result.data: @@ -72,7 +73,7 @@ async def test_stdio_elicitation(): # Create a custom handler for elicitation requests async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams ): # pragma: no cover if params.message == "Tool wants to ask: What is your name?": return ElicitResult(action="accept", content={"answer": "Test User"}) @@ -90,7 +91,7 @@ async def test_stdio_elicitation_decline(): mcp = FastMCP(name="StdioElicitationDeclineServer") create_ask_user_tool(mcp) - async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def elicitation_callback(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): return ElicitResult(action="decline") await call_tool_and_assert( @@ -105,7 +106,7 @@ async def test_elicitation_schema_validation(): def create_validation_tool(name: str, schema_class: type[BaseModel]): @mcp.tool(name=name, description=f"Tool testing {name}") - async def tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover + async def tool(ctx: Context[ServerTransportSession, None]) -> str: # pragma: no cover try: await ctx.elicit(message="This should fail validation", schema=schema_class) return "Should not reach here" @@ -129,7 +130,7 @@ class InvalidNestedSchema(BaseModel): # Dummy callback (won't be called due to validation failure) async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams ): # pragma: no cover return ElicitResult(action="accept", content={}) @@ -159,7 +160,7 @@ class OptionalSchema(BaseModel): subscribe: bool | None = Field(default=False, description="Subscribe to newsletter?") @mcp.tool(description="Tool with optional fields") - async def optional_tool(ctx: Context[ServerSession, None]) -> str: + async def optional_tool(ctx: Context[ServerTransportSession, None]) -> str: result = await ctx.elicit(message="Please provide your information", schema=OptionalSchema) if result.action == "accept" and result.data: @@ -189,7 +190,7 @@ async def optional_tool(ctx: Context[ServerSession, None]) -> str: for content, expected in test_cases: - async def callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): return ElicitResult(action="accept", content=content) await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected) @@ -200,7 +201,7 @@ class InvalidOptionalSchema(BaseModel): optional_list: list[str] | None = Field(default=None, description="Invalid optional list") @mcp.tool(description="Tool with invalid optional field") - async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: # pragma: no cover + async def invalid_optional_tool(ctx: Context[ServerTransportSession, None]) -> str: # pragma: no cover try: await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema) return "Should not reach here" @@ -208,7 +209,7 @@ async def invalid_optional_tool(ctx: Context[ServerSession, None]) -> str: # pr return f"Validation failed: {str(e)}" async def elicitation_callback( - context: RequestContext[ClientSession, None], params: ElicitRequestParams + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams ): # pragma: no cover return ElicitResult(action="accept", content={}) @@ -233,7 +234,7 @@ class DefaultsSchema(BaseModel): email: str = Field(description="Email address (required)") @mcp.tool(description="Tool with default values") - async def defaults_tool(ctx: Context[ServerSession, None]) -> str: + async def defaults_tool(ctx: Context[ServerTransportSession, None]) -> str: result = await ctx.elicit(message="Please provide your information", schema=DefaultsSchema) if result.action == "accept" and result.data: @@ -245,7 +246,9 @@ async def defaults_tool(ctx: Context[ServerSession, None]) -> str: return f"User {result.action}" # First verify that defaults are present in the JSON schema sent to clients - async def callback_schema_verify(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback_schema_verify( + context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams + ): # Verify the schema includes defaults schema = params.requestedSchema props = schema["properties"] @@ -266,7 +269,7 @@ async def callback_schema_verify(context: RequestContext[ClientSession, None], p ) # Test overriding defaults - async def callback_override(context: RequestContext[ClientSession, None], params: ElicitRequestParams): + async def callback_override(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): return ElicitResult( action="accept", content={"email": "john@example.com", "name": "John", "age": 25, "subscribe": False} ) diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index b1cefca29..d95d3a380 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -35,6 +35,7 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.client.transport_session import ClientTransportSession from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder @@ -212,7 +213,7 @@ def unpack_streams( # Callback functions for testing async def sampling_callback( - context: RequestContext[ClientSession, None], params: CreateMessageRequestParams + context: RequestContext[ClientTransportSession, None], params: CreateMessageRequestParams ) -> CreateMessageResult: """Sampling callback for tests.""" return CreateMessageResult( @@ -225,7 +226,7 @@ async def sampling_callback( ) -async def elicitation_callback(context: RequestContext[ClientSession, None], params: ElicitRequestParams): +async def elicitation_callback(context: RequestContext[ClientTransportSession, None], params: ElicitRequestParams): """Elicitation callback for tests.""" # For restaurant booking test if "No tables available" in params.message: diff --git a/tests/server/test_cancel_handling.py b/tests/server/test_cancel_handling.py index 47c49bb62..b1f825933 100644 --- a/tests/server/test_cancel_handling.py +++ b/tests/server/test_cancel_handling.py @@ -6,6 +6,7 @@ import pytest import mcp.types as types +from mcp.client.session import ClientSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_connected_server_and_client_session @@ -56,6 +57,8 @@ async def handle_call_tool(name: str, arguments: dict[str, Any] | None) -> list[ async with create_connected_server_and_client_session(server) as client: # First request (will be cancelled) + assert isinstance(client, ClientSession) + async def first_request(): try: await client.send_request( diff --git a/tests/shared/test_memory.py b/tests/shared/test_memory.py index ca4368e9f..56e0b98e7 100644 --- a/tests/shared/test_memory.py +++ b/tests/shared/test_memory.py @@ -3,6 +3,7 @@ from typing_extensions import AsyncGenerator from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.shared.memory import create_connected_server_and_client_session from mcp.types import EmptyResult, Resource @@ -28,7 +29,7 @@ async def handle_list_resources(): # pragma: no cover @pytest.fixture async def client_connected_to_server( mcp_server: Server, -) -> AsyncGenerator[ClientSession, None]: +) -> AsyncGenerator[ClientTransportSession, None]: async with create_connected_server_and_client_session(mcp_server) as client_session: yield client_session diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 1552711d2..25afd7f32 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -370,6 +370,7 @@ async def handle_list_tools() -> list[types.Tool]: with patch("mcp.shared.session.logging.error", side_effect=mock_log_error): async with create_connected_server_and_client_session(server) as client_session: # Send a request with a failing progress callback + assert isinstance(client_session, ClientSession) result = await client_session.send_request( types.ClientRequest( types.CallToolRequest( diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 313ec9926..a056f705b 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -6,6 +6,7 @@ import mcp.types as types from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.server.lowlevel.server import Server from mcp.shared.exceptions import McpError from mcp.shared.memory import create_client_server_memory_streams, create_connected_server_and_client_session @@ -27,19 +28,20 @@ def mcp_server() -> Server: @pytest.fixture async def client_connected_to_server( mcp_server: Server, -) -> AsyncGenerator[ClientSession, None]: +) -> AsyncGenerator[ClientTransportSession, None]: async with create_connected_server_and_client_session(mcp_server) as client_session: yield client_session @pytest.mark.anyio async def test_in_flight_requests_cleared_after_completion( - client_connected_to_server: ClientSession, + client_connected_to_server: ClientTransportSession, ): """Verify that _in_flight is empty after all requests complete.""" # Send a request and wait for response response = await client_connected_to_server.send_ping() assert isinstance(response, EmptyResult) + assert isinstance(client_connected_to_server, ClientSession) # Verify _in_flight is empty assert len(client_connected_to_server._in_flight) == 0 @@ -101,6 +103,7 @@ async def make_request(client_session: ClientSession): async with create_connected_server_and_client_session(make_server()) as client_session: async with anyio.create_task_group() as tg: + assert isinstance(client_session, ClientSession) tg.start_soon(make_request, client_session) # Wait for the request to be in-flight diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 28ac07d09..0f850599a 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -19,6 +19,7 @@ import mcp.types as types from mcp.client.session import ClientSession from mcp.client.sse import sse_client +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings @@ -185,7 +186,7 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non @pytest.fixture -async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: +async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientTransportSession, None]: async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: async with ClientSession(*streams) as session: await session.initialize() diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 43b321d96..be80e3820 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -23,7 +23,9 @@ import mcp.types as types from mcp.client.session import ClientSession from mcp.client.streamable_http import streamablehttp_client +from mcp.client.transport_session import ClientTransportSession from mcp.server import Server +from mcp.server.session import ServerSession from mcp.server.streamable_http import ( MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, @@ -198,7 +200,9 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent] elif name == "test_sampling_tool": # Test sampling by requesting the client to sample a message - sampling_result = await ctx.session.create_message( + session = ctx.session + assert isinstance(session, ServerSession) + sampling_result = await session.create_message( messages=[ types.SamplingMessage( role="user", @@ -1233,7 +1237,7 @@ async def test_streamablehttp_server_sampling(basic_server: None, basic_server_u # Define sampling callback that returns a mock response async def sampling_callback( - context: RequestContext[ClientSession, Any], + context: RequestContext[ClientTransportSession, Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult: nonlocal sampling_callback_invoked, captured_message_params diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index f093cb492..107cd5589 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -13,6 +13,7 @@ from starlette.websockets import WebSocket from mcp.client.session import ClientSession +from mcp.client.transport_session import ClientTransportSession from mcp.client.websocket import websocket_client from mcp.server import Server from mcp.server.websocket import websocket_server @@ -125,7 +126,7 @@ def server(server_port: int) -> Generator[None, None, None]: @pytest.fixture() -async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: +async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientTransportSession, None]: """Create and initialize a WebSocket client session""" async with websocket_client(server_url + "/ws") as streams: async with ClientSession(*streams) as session: