Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions src/fastmcp/client/oauth_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ def create_oauth_callback_server(
Configured uvicorn Server instance (not yet running)
"""

def store_result_once(
*,
code: str | None = None,
state: str | None = None,
error: Exception | None = None,
) -> None:
"""Store the first callback result and ignore subsequent requests."""
if result_container is None or result_ready is None or result_ready.is_set():
return

result_container.code = code
result_container.state = state
result_container.error = error
result_ready.set()

async def callback_handler(request: Request):
"""Handle OAuth callback requests with proper HTML responses."""
query_params = dict(request.query_params)
Expand All @@ -136,9 +151,7 @@ async def callback_handler(request: Request):
user_message = f"Authorization failed: {error_desc}"

# Store error and signal completion if result tracking provided
if result_container is not None and result_ready is not None:
result_container.error = RuntimeError(user_message)
result_ready.set()
store_result_once(error=RuntimeError(user_message))

return create_secure_html_response(
create_callback_html(
Expand All @@ -152,9 +165,7 @@ async def callback_handler(request: Request):
user_message = "No authorization code was received from the server."

# Store error and signal completion if result tracking provided
if result_container is not None and result_ready is not None:
result_container.error = RuntimeError(user_message)
result_ready.set()
store_result_once(error=RuntimeError(user_message))

return create_secure_html_response(
create_callback_html(
Expand All @@ -171,9 +182,7 @@ async def callback_handler(request: Request):
)

# Store error and signal completion if result tracking provided
if result_container is not None and result_ready is not None:
result_container.error = RuntimeError(user_message)
result_ready.set()
store_result_once(error=RuntimeError(user_message))

return create_secure_html_response(
create_callback_html(
Expand All @@ -184,10 +193,10 @@ async def callback_handler(request: Request):
)

# Success case - store result and signal completion if result tracking provided
if result_container is not None and result_ready is not None:
result_container.code = callback_response.code
result_container.state = callback_response.state
result_ready.set()
store_result_once(
code=callback_response.code,
state=callback_response.state,
)

return create_secure_html_response(
create_callback_html("", is_success=True, server_url=server_url)
Expand Down
44 changes: 44 additions & 0 deletions tests/client/test_oauth_callback_race.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import anyio
import httpx

from fastmcp.client.oauth_callback import (
OAuthCallbackResult,
create_oauth_callback_server,
)
from fastmcp.utilities.http import find_available_port


async def test_oauth_callback_result_ignores_subsequent_callbacks():
"""Only the first callback should be captured in shared OAuth callback state."""
port = find_available_port()
result = OAuthCallbackResult()
result_ready = anyio.Event()
server = create_oauth_callback_server(
port=port,
result_container=result,
result_ready=result_ready,
)

async with anyio.create_task_group() as tg:
tg.start_soon(server.serve)

await anyio.sleep(0.05)

async with httpx.AsyncClient() as client:
first = await client.get(
f"http://127.0.0.1:{port}/callback?code=good&state=s1"
)
assert first.status_code == 200

await result_ready.wait()

second = await client.get(
f"http://127.0.0.1:{port}/callback?code=evil&state=s2"
)
assert second.status_code == 200

assert result.error is None
assert result.code == "good"
assert result.state == "s1"

tg.cancel_scope.cancel()
Loading