diff --git a/src/fastmcp/server/mixins/lifespan.py b/src/fastmcp/server/mixins/lifespan.py index 267736a83..0f2589272 100644 --- a/src/fastmcp/server/mixins/lifespan.py +++ b/src/fastmcp/server/mixins/lifespan.py @@ -136,30 +136,49 @@ async def _docket_lifespan(self: FastMCP) -> AsyncIterator[None]: @asynccontextmanager async def _lifespan_manager(self: FastMCP) -> AsyncIterator[None]: - if self._lifespan_result_set: - yield + async with self._lifespan_lock: + if self._lifespan_result_set: + self._lifespan_ref_count += 1 + should_enter_lifespan = False + else: + self._lifespan_ref_count = 1 + should_enter_lifespan = True + + if not should_enter_lifespan: + try: + yield + finally: + async with self._lifespan_lock: + self._lifespan_ref_count -= 1 + if self._lifespan_ref_count == 0: + self._lifespan_result_set = False + self._lifespan_result = None return - async with ( - self._lifespan(self) as user_lifespan_result, - self._docket_lifespan(), - ): - self._lifespan_result = user_lifespan_result - self._lifespan_result_set = True - - async with AsyncExitStack[bool | None]() as stack: - # Start lifespans for all providers - for provider in self.providers: - await stack.enter_async_context(provider.lifespan()) - - self._started.set() - try: - yield - finally: - self._started.clear() - - self._lifespan_result_set = False - self._lifespan_result = None + try: + async with ( + self._lifespan(self) as user_lifespan_result, + self._docket_lifespan(), + ): + self._lifespan_result = user_lifespan_result + self._lifespan_result_set = True + + async with AsyncExitStack[bool | None]() as stack: + # Start lifespans for all providers + for provider in self.providers: + await stack.enter_async_context(provider.lifespan()) + + self._started.set() + try: + yield + finally: + self._started.clear() + finally: + async with self._lifespan_lock: + self._lifespan_ref_count -= 1 + if self._lifespan_ref_count == 0: + self._lifespan_result_set = False + self._lifespan_result = None def _setup_task_protocol_handlers(self: FastMCP) -> None: """Register SEP-1686 task protocol handlers with SDK. diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py index 0dc47f143..a55aa8187 100644 --- a/src/fastmcp/server/server.py +++ b/src/fastmcp/server/server.py @@ -300,6 +300,8 @@ def __init__( self._lifespan = cast(LifespanCallable[LifespanResultT], default_lifespan) self._lifespan_result: LifespanResultT | None = None self._lifespan_result_set: bool = False + self._lifespan_ref_count: int = 0 + self._lifespan_lock: asyncio.Lock = asyncio.Lock() self._started: asyncio.Event = asyncio.Event() # Generate random ID if no name provided diff --git a/tests/server/test_server_lifespan.py b/tests/server/test_server_lifespan.py index ef52e4e22..2d2b3446e 100644 --- a/tests/server/test_server_lifespan.py +++ b/tests/server/test_server_lifespan.py @@ -54,6 +54,44 @@ def get_value() -> str: # when the client session closes assert lifespan_events == ["enter", "exit"] + async def test_server_lifespan_overlapping_sessions(self): + """Test that overlapping sessions keep lifespan active until all sessions close.""" + lifespan_events: list[str] = [] + + resource_state = "missing" + + @asynccontextmanager + async def server_lifespan(mcp: FastMCP) -> AsyncIterator[dict[str, Any]]: + nonlocal resource_state + lifespan_events.append("enter") + resource_state = "open" + try: + yield {"initialized": True} + finally: + resource_state = "closed" + lifespan_events.append("exit") + + mcp = FastMCP("TestServer", lifespan=server_lifespan) + + @mcp.tool + def get_resource_state() -> str: + return resource_state + + async with Client(mcp) as client1: + result1 = await client1.call_tool("get_resource_state", {}) + assert result1.data == "open" + + async with Client(mcp) as client2: + result2 = await client2.call_tool("get_resource_state", {}) + assert result2.data == "open" + + # client2 exited while client1 is still active; lifespan should remain open + result3 = await client1.call_tool("get_resource_state", {}) + assert result3.data == "open" + assert lifespan_events == ["enter"] + + assert lifespan_events == ["enter", "exit"] + async def test_server_lifespan_context_available(self): """Test that server_lifespan context is available to tools."""