diff --git a/src/fastmcp/client/oauth_callback.py b/src/fastmcp/client/oauth_callback.py index ba5c93d54..6bfba59cb 100644 --- a/src/fastmcp/client/oauth_callback.py +++ b/src/fastmcp/client/oauth_callback.py @@ -121,6 +121,21 @@ def create_oauth_callback_server( Configured uvicorn Server instance (not yet running) """ + def store_result_once( + *, + code: str | None = None, + state: str | None = None, + error: Exception | None = None, + ) -> None: + """Store the first callback result and ignore subsequent requests.""" + if result_container is None or result_ready is None or result_ready.is_set(): + return + + result_container.code = code + result_container.state = state + result_container.error = error + result_ready.set() + async def callback_handler(request: Request): """Handle OAuth callback requests with proper HTML responses.""" query_params = dict(request.query_params) @@ -136,9 +151,7 @@ async def callback_handler(request: Request): user_message = f"Authorization failed: {error_desc}" # Store error and signal completion if result tracking provided - if result_container is not None and result_ready is not None: - result_container.error = RuntimeError(user_message) - result_ready.set() + store_result_once(error=RuntimeError(user_message)) return create_secure_html_response( create_callback_html( @@ -152,9 +165,7 @@ async def callback_handler(request: Request): user_message = "No authorization code was received from the server." # Store error and signal completion if result tracking provided - if result_container is not None and result_ready is not None: - result_container.error = RuntimeError(user_message) - result_ready.set() + store_result_once(error=RuntimeError(user_message)) return create_secure_html_response( create_callback_html( @@ -171,9 +182,7 @@ async def callback_handler(request: Request): ) # Store error and signal completion if result tracking provided - if result_container is not None and result_ready is not None: - result_container.error = RuntimeError(user_message) - result_ready.set() + store_result_once(error=RuntimeError(user_message)) return create_secure_html_response( create_callback_html( @@ -184,10 +193,10 @@ async def callback_handler(request: Request): ) # Success case - store result and signal completion if result tracking provided - if result_container is not None and result_ready is not None: - result_container.code = callback_response.code - result_container.state = callback_response.state - result_ready.set() + store_result_once( + code=callback_response.code, + state=callback_response.state, + ) return create_secure_html_response( create_callback_html("", is_success=True, server_url=server_url) diff --git a/tests/client/test_oauth_callback_race.py b/tests/client/test_oauth_callback_race.py new file mode 100644 index 000000000..7ebcb4534 --- /dev/null +++ b/tests/client/test_oauth_callback_race.py @@ -0,0 +1,44 @@ +import anyio +import httpx + +from fastmcp.client.oauth_callback import ( + OAuthCallbackResult, + create_oauth_callback_server, +) +from fastmcp.utilities.http import find_available_port + + +async def test_oauth_callback_result_ignores_subsequent_callbacks(): + """Only the first callback should be captured in shared OAuth callback state.""" + port = find_available_port() + result = OAuthCallbackResult() + result_ready = anyio.Event() + server = create_oauth_callback_server( + port=port, + result_container=result, + result_ready=result_ready, + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(server.serve) + + await anyio.sleep(0.05) + + async with httpx.AsyncClient() as client: + first = await client.get( + f"http://127.0.0.1:{port}/callback?code=good&state=s1" + ) + assert first.status_code == 200 + + await result_ready.wait() + + second = await client.get( + f"http://127.0.0.1:{port}/callback?code=evil&state=s2" + ) + assert second.status_code == 200 + + assert result.error is None + assert result.code == "good" + assert result.state == "s1" + + tg.cancel_scope.cancel()