Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 38 additions & 22 deletions src/fastmcp/server/mixins/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,30 +136,46 @@ async def _docket_lifespan(self: FastMCP) -> AsyncIterator[None]:

@asynccontextmanager
async def _lifespan_manager(self: FastMCP) -> AsyncIterator[None]:
if self._lifespan_result_set:
yield
async with self._lifespan_lock:
if self._lifespan_result_set:
self._lifespan_ref_count += 1
should_enter_lifespan = False
else:
self._lifespan_ref_count = 1
should_enter_lifespan = True

if not should_enter_lifespan:
try:
yield
finally:
async with self._lifespan_lock:
self._lifespan_ref_count -= 1
return
Comment on lines +149 to 156

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Clear lifespan state when non-owner context is last to exit

When the first-entered session exits before a later overlapping session, teardown runs in the owner branch, but this non-owner branch only decrements _lifespan_ref_count and returns. If that decrement reaches zero, _lifespan_result_set remains True (and _lifespan_result stays stale), so subsequent sessions take the fast path and skip re-entering lifespan entirely even though resources were already torn down; this leaves future requests running against closed/uninitialized lifespan state.

Useful? React with 👍 / 👎.


async with (
self._lifespan(self) as user_lifespan_result,
self._docket_lifespan(),
):
self._lifespan_result = user_lifespan_result
self._lifespan_result_set = True

async with AsyncExitStack[bool | None]() as stack:
# Start lifespans for all providers
for provider in self.providers:
await stack.enter_async_context(provider.lifespan())

self._started.set()
try:
yield
finally:
self._started.clear()

self._lifespan_result_set = False
self._lifespan_result = None
try:
async with (
self._lifespan(self) as user_lifespan_result,
self._docket_lifespan(),
):
self._lifespan_result = user_lifespan_result
self._lifespan_result_set = True

async with AsyncExitStack[bool | None]() as stack:
# Start lifespans for all providers
for provider in self.providers:
await stack.enter_async_context(provider.lifespan())

self._started.set()
try:
yield
finally:
self._started.clear()
finally:
async with self._lifespan_lock:
self._lifespan_ref_count -= 1
if self._lifespan_ref_count == 0:
self._lifespan_result_set = False
self._lifespan_result = None

def _setup_task_protocol_handlers(self: FastMCP) -> None:
"""Register SEP-1686 task protocol handlers with SDK.
Expand Down
2 changes: 2 additions & 0 deletions src/fastmcp/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ def __init__(
self._lifespan = cast(LifespanCallable[LifespanResultT], default_lifespan)
self._lifespan_result: LifespanResultT | None = None
self._lifespan_result_set: bool = False
self._lifespan_ref_count: int = 0
self._lifespan_lock: asyncio.Lock = asyncio.Lock()
self._started: asyncio.Event = asyncio.Event()

# Generate random ID if no name provided
Expand Down
38 changes: 38 additions & 0 deletions tests/server/test_server_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,44 @@ def get_value() -> str:
# when the client session closes
assert lifespan_events == ["enter", "exit"]

async def test_server_lifespan_overlapping_sessions(self):
"""Test that overlapping sessions keep lifespan active until all sessions close."""
lifespan_events: list[str] = []

resource_state = "missing"

@asynccontextmanager
async def server_lifespan(mcp: FastMCP) -> AsyncIterator[dict[str, Any]]:
nonlocal resource_state
lifespan_events.append("enter")
resource_state = "open"
try:
yield {"initialized": True}
finally:
resource_state = "closed"
lifespan_events.append("exit")

mcp = FastMCP("TestServer", lifespan=server_lifespan)

@mcp.tool
def get_resource_state() -> str:
return resource_state

async with Client(mcp) as client1:
result1 = await client1.call_tool("get_resource_state", {})
assert result1.data == "open"

async with Client(mcp) as client2:
result2 = await client2.call_tool("get_resource_state", {})
assert result2.data == "open"

# client2 exited while client1 is still active; lifespan should remain open
result3 = await client1.call_tool("get_resource_state", {})
assert result3.data == "open"
assert lifespan_events == ["enter"]

assert lifespan_events == ["enter", "exit"]

async def test_server_lifespan_context_available(self):
"""Test that server_lifespan context is available to tools."""

Expand Down
Loading