Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
dd4939d
add transport abstraction
asheshvidyut Nov 6, 2025
11d1249
fix ruff
asheshvidyut Nov 6, 2025
c8f3a42
fix ruff format
asheshvidyut Nov 6, 2025
03cc6c5
add transport session for server
asheshvidyut Nov 6, 2025
1327a9c
clientsession and server session to implement abstract classes
asheshvidyut Nov 6, 2025
0018679
add raise not implemented
asheshvidyut Nov 6, 2025
af7ff5a
fix abstract server transport session
asheshvidyut Nov 6, 2025
7f468d0
removed unused import
asheshvidyut Nov 6, 2025
e895d90
fix type hints
asheshvidyut Nov 6, 2025
d01e477
revert type hints
asheshvidyut Nov 6, 2025
7bdafa3
fix import
asheshvidyut Nov 6, 2025
e9f63dd
fix import
asheshvidyut Nov 6, 2025
5b156a1
fix ruff format
asheshvidyut Nov 6, 2025
f26d861
request context as optional param
asheshvidyut Nov 6, 2025
3097cb3
fix format
asheshvidyut Nov 6, 2025
9e8dca3
ruff check --fix
asheshvidyut Nov 6, 2025
5b7b458
fix pyright
asheshvidyut Nov 6, 2025
8ca511e
ruff fix
asheshvidyut Nov 6, 2025
53e02fe
removed fat abstract class
asheshvidyut Nov 6, 2025
cf0f152
removed client a thin interface
asheshvidyut Nov 6, 2025
ccbdde8
add description
asheshvidyut Nov 6, 2025
380710e
revert context change in this pr
asheshvidyut Nov 6, 2025
3f977b3
rename classes
asheshvidyut Nov 7, 2025
ec7b6d6
ruff fix
asheshvidyut Nov 7, 2025
0359aa8
merge main
asheshvidyut Nov 12, 2025
b733fcf
fix type hints for serversession
asheshvidyut Nov 7, 2025
cdc39f4
fix ruff
asheshvidyut Nov 7, 2025
65a3b0f
uv run scripts/update_readme_snippets.py
asheshvidyut Nov 7, 2025
f34e8fe
some fixes
asheshvidyut Nov 7, 2025
1bfc086
fix ruff
asheshvidyut Nov 7, 2025
481f7ea
fix type hints without cast
asheshvidyut Nov 7, 2025
6b8f737
fix ruff
asheshvidyut Nov 7, 2025
99856e8
remove overload
asheshvidyut Nov 7, 2025
ea8a33c
revert client session group
asheshvidyut Nov 7, 2025
5bcfe62
fix ruff pyright
asheshvidyut Nov 7, 2025
af6be96
fix ruff
asheshvidyut Nov 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass

from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession
from mcp.server.session import ServerTransportSession


# Mock database class for example
Expand Down Expand Up @@ -254,7 +254,7 @@ mcp = FastMCP("My App", lifespan=app_lifespan)

# Access type-safe lifespan context in tools
@mcp.tool()
def query_db(ctx: Context[ServerSession, AppContext]) -> str:
def query_db(ctx: Context[ServerTransportSession, AppContext]) -> str:
"""Tool that uses initialized resources."""
db = ctx.request_context.lifespan_context.db
return db.query()
Expand Down Expand Up @@ -326,13 +326,13 @@ Tools can optionally receive a Context object by including a parameter with the
<!-- snippet-source examples/snippets/servers/tool_progress.py -->
```python
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession
from mcp.server.session import ServerTransportSession

mcp = FastMCP(name="Progress Example")


@mcp.tool()
async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str:
async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str:
"""Execute a task with progress updates."""
await ctx.info(f"Starting: {task_name}")

Expand Down Expand Up @@ -674,13 +674,13 @@ The Context object provides the following capabilities:
<!-- snippet-source examples/snippets/servers/tool_progress.py -->
```python
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession
from mcp.server.session import ServerTransportSession

mcp = FastMCP(name="Progress Example")


@mcp.tool()
async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str:
async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str:
"""Execute a task with progress updates."""
await ctx.info(f"Starting: {task_name}")

Expand Down Expand Up @@ -798,7 +798,7 @@ Request additional information from users. This example shows an Elicitation dur
from pydantic import BaseModel, Field

from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession
from mcp.server.session import ServerTransportSession

mcp = FastMCP(name="Elicitation Example")

Expand All @@ -814,7 +814,7 @@ class BookingPreferences(BaseModel):


@mcp.tool()
async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str:
async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerTransportSession, None]) -> str:
"""Book a table with date availability check."""
# Check if date is available
if date == "2024-12-25":
Expand Down Expand Up @@ -888,13 +888,13 @@ Tools can send logs and notifications through the context:
<!-- snippet-source examples/snippets/servers/notifications.py -->
```python
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession
from mcp.server.session import ServerTransportSession

mcp = FastMCP(name="Notifications Example")


@mcp.tool()
async def process_data(data: str, ctx: Context[ServerSession, None]) -> str:
async def process_data(data: str, ctx: Context[ServerTransportSession, None]) -> str:
"""Process data with logging."""
# Different log levels
await ctx.debug(f"Debug: Processing '{data}'")
Expand Down Expand Up @@ -2038,6 +2038,7 @@ import os
from pydantic import AnyUrl

from mcp import ClientSession, StdioServerParameters, types
from mcp.client.session import ClientTransportSession
from mcp.client.stdio import stdio_client
from mcp.shared.context import RequestContext

Expand All @@ -2051,7 +2052,7 @@ server_params = StdioServerParameters(

# Optional: create a sampling callback
async def handle_sampling_message(
context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams
context: RequestContext[ClientTransportSession, None], params: types.CreateMessageRequestParams
) -> types.CreateMessageResult:
print(f"Sampling request: {params.messages}")
return types.CreateMessageResult(
Expand Down Expand Up @@ -2169,6 +2170,7 @@ import os

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.transport_session import ClientTransportSession
from mcp.shared.metadata_utils import get_display_name

# Create server parameters for stdio connection
Expand All @@ -2179,7 +2181,7 @@ server_params = StdioServerParameters(
)


async def display_tools(session: ClientSession):
async def display_tools(session: ClientTransportSession):
"""Display available tools with human-readable names"""
tools_response = await session.list_tools()

Expand All @@ -2191,7 +2193,7 @@ async def display_tools(session: ClientSession):
print(f" {tool.description}")


async def display_resources(session: ClientSession):
async def display_resources(session: ClientTransportSession):
"""Display available resources with human-readable names"""
resources_response = await session.list_resources()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from urllib.parse import parse_qs, urlparse

from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.session import ClientSession
from mcp.client.session import ClientSession, ClientTransportSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
Expand Down Expand Up @@ -153,7 +153,7 @@ class SimpleAuthClient:
def __init__(self, server_url: str, transport_type: str = "streamable-http"):
self.server_url = server_url
self.transport_type = transport_type
self.session: ClientSession | None = None
self.session: ClientTransportSession | None = None

async def connect(self):
"""Connect to the MCP server."""
Expand Down
3 changes: 2 additions & 1 deletion examples/clients/simple-chatbot/mcp_simple_chatbot/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dotenv import load_dotenv
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.transport_session import ClientTransportSession

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
Expand Down Expand Up @@ -67,7 +68,7 @@ def __init__(self, name: str, config: dict[str, Any]) -> None:
self.name: str = name
self.config: dict[str, Any] = config
self.stdio_context: Any | None = None
self.session: ClientSession | None = None
self.session: ClientTransportSession | None = None
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
self.exit_stack: AsyncExitStack = AsyncExitStack()

Expand Down
5 changes: 3 additions & 2 deletions examples/snippets/clients/display_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.transport_session import ClientTransportSession
from mcp.shared.metadata_utils import get_display_name

# Create server parameters for stdio connection
Expand All @@ -18,7 +19,7 @@
)


async def display_tools(session: ClientSession):
async def display_tools(session: ClientTransportSession):
"""Display available tools with human-readable names"""
tools_response = await session.list_tools()

Expand All @@ -30,7 +31,7 @@ async def display_tools(session: ClientSession):
print(f" {tool.description}")


async def display_resources(session: ClientSession):
async def display_resources(session: ClientTransportSession):
"""Display available resources with human-readable names"""
resources_response = await session.list_resources()

Expand Down
3 changes: 2 additions & 1 deletion examples/snippets/clients/stdio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic import AnyUrl

from mcp import ClientSession, StdioServerParameters, types
from mcp.client.session import ClientTransportSession
from mcp.client.stdio import stdio_client
from mcp.shared.context import RequestContext

Expand All @@ -22,7 +23,7 @@

# Optional: create a sampling callback
async def handle_sampling_message(
context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams
context: RequestContext[ClientTransportSession, None], params: types.CreateMessageRequestParams
) -> types.CreateMessageResult:
print(f"Sampling request: {params.messages}")
return types.CreateMessageResult(
Expand Down
4 changes: 2 additions & 2 deletions examples/snippets/servers/elicitation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import BaseModel, Field

from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession
from mcp.server.session import ServerTransportSession

mcp = FastMCP(name="Elicitation Example")

Expand All @@ -17,7 +17,7 @@ class BookingPreferences(BaseModel):


@mcp.tool()
async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str:
async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerTransportSession, None]) -> str:
"""Book a table with date availability check."""
# Check if date is available
if date == "2024-12-25":
Expand Down
4 changes: 2 additions & 2 deletions examples/snippets/servers/lifespan_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass

from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession
from mcp.server.session import ServerTransportSession


# Mock database class for example
Expand Down Expand Up @@ -51,7 +51,7 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:

# Access type-safe lifespan context in tools
@mcp.tool()
def query_db(ctx: Context[ServerSession, AppContext]) -> str:
def query_db(ctx: Context[ServerTransportSession, AppContext]) -> str:
"""Tool that uses initialized resources."""
db = ctx.request_context.lifespan_context.db
return db.query()
4 changes: 2 additions & 2 deletions examples/snippets/servers/notifications.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession
from mcp.server.session import ServerTransportSession

mcp = FastMCP(name="Notifications Example")


@mcp.tool()
async def process_data(data: str, ctx: Context[ServerSession, None]) -> str:
async def process_data(data: str, ctx: Context[ServerTransportSession, None]) -> str:
"""Process data with logging."""
# Different log levels
await ctx.debug(f"Debug: Processing '{data}'")
Expand Down
4 changes: 2 additions & 2 deletions examples/snippets/servers/tool_progress.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.session import ServerSession
from mcp.server.session import ServerTransportSession

mcp = FastMCP(name="Progress Example")


@mcp.tool()
async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str:
async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str:
"""Execute a task with progress updates."""
await ctx.info(f"Starting: {task_name}")

Expand Down
4 changes: 4 additions & 0 deletions src/mcp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from .client.session import ClientSession
from .client.session_group import ClientSessionGroup
from .client.stdio import StdioServerParameters, stdio_client
from .client.transport_session import ClientTransportSession
from .server.session import ServerSession
from .server.stdio import stdio_server
from .server.transport_session import ServerTransportSession
from .shared.exceptions import McpError
from .types import (
CallToolRequest,
Expand Down Expand Up @@ -113,4 +115,6 @@
"stdio_server",
"CompleteRequest",
"JSONRPCResponse",
"ClientTransportSession",
"ServerTransportSession",
]
18 changes: 10 additions & 8 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing_extensions import deprecated

import mcp.types as types
from mcp.client.transport_session import ClientTransportSession
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
Expand All @@ -22,22 +23,22 @@
class SamplingFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext["ClientTransportSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData: ... # pragma: no branch


class ElicitationFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
context: RequestContext["ClientTransportSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch


class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
self, context: RequestContext["ClientTransportSession", Any]
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch


Expand All @@ -62,7 +63,7 @@ async def _default_message_handler(


async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
context: RequestContext["ClientTransportSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
return types.ErrorData(
Expand All @@ -72,7 +73,7 @@ async def _default_sampling_callback(


async def _default_elicitation_callback(
context: RequestContext["ClientSession", Any],
context: RequestContext["ClientTransportSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
return types.ErrorData( # pragma: no cover
Expand All @@ -82,7 +83,7 @@ async def _default_elicitation_callback(


async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
context: RequestContext["ClientTransportSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
Expand All @@ -100,13 +101,14 @@ async def _default_logging_callback(


class ClientSession(
ClientTransportSession,
BaseSession[
types.ClientRequest,
types.ClientNotification,
types.ClientResult,
types.ServerRequest,
types.ServerNotification,
]
],
):
def __init__(
self,
Expand Down Expand Up @@ -508,7 +510,7 @@ async def send_roots_list_changed(self) -> None: # pragma: no cover
await self.send_notification(types.ClientNotification(types.RootsListChangedNotification()))

async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
ctx = RequestContext[ClientSession, Any](
ctx = RequestContext[ClientTransportSession, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
Expand Down
Loading