diff --git a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py index 5e5099426a06..de4f6fafc3cb 100644 --- a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py @@ -1,5 +1,5 @@ import json -from typing import Optional, Tuple +from typing import Optional from urllib.parse import urlencode, urlparse, urlunparse from fastapi import APIRouter, Form, HTTPException, Request @@ -19,32 +19,47 @@ ) -def encode_state_with_base_url(base_url: str, original_state: str) -> str: +def encode_state_with_base_url( + base_url: str, + original_state: str, + code_challenge: Optional[str] = None, + code_challenge_method: Optional[str] = None, + client_redirect_uri: Optional[str] = None, +) -> str: """ - Encode the base_url and original state using encryption. + Encode the base_url, original state, and PKCE parameters using encryption. Args: base_url: The base URL to encode original_state: The original state parameter + code_challenge: PKCE code challenge from client + code_challenge_method: PKCE code challenge method from client + client_redirect_uri: Original redirect_uri from client Returns: - An encrypted string that encodes both values + An encrypted string that encodes all values """ - state_data = {"base_url": base_url, "original_state": original_state} + state_data = { + "base_url": base_url, + "original_state": original_state, + "code_challenge": code_challenge, + "code_challenge_method": code_challenge_method, + "client_redirect_uri": client_redirect_uri, + } state_json = json.dumps(state_data, sort_keys=True) encrypted_state = encrypt_value_helper(state_json) return encrypted_state -def decode_state_hash(encrypted_state: str) -> Tuple[str, str]: +def decode_state_hash(encrypted_state: str) -> dict: """ - Decode an encrypted state to retrieve the base_url and original state. + Decode an encrypted state to retrieve all OAuth session data. Args: encrypted_state: The encrypted string to decode Returns: - A tuple of (base_url, original_state) + A dict containing base_url, original_state, and optional PKCE parameters Raises: Exception: If decryption fails or data is malformed @@ -54,7 +69,7 @@ def decode_state_hash(encrypted_state: str) -> Tuple[str, str]: raise ValueError("Failed to decrypt state parameter") state_data = json.loads(decrypted_json) - return state_data["base_url"], state_data["original_state"] + return state_data @router.get("/{mcp_server_name}/authorize") @@ -65,8 +80,12 @@ async def authorize( redirect_uri: str, state: str = "", mcp_server_name: Optional[str] = None, + code_challenge: Optional[str] = None, + code_challenge_method: Optional[str] = None, + response_type: Optional[str] = None, + scope: Optional[str] = None, ): - # Redirect to real GitHub OAuth + # Redirect to real OAuth provider with PKCE support from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( global_mcp_server_manager, ) @@ -90,15 +109,30 @@ async def authorize( base_url = urlunparse(parsed._replace(query="")) request_base_url = str(request.base_url).rstrip("/") - # Encode the base_url and original state in a unique hash - encoded_state = encode_state_with_base_url(base_url, state) + # Encode the base_url, original state, PKCE params, and client redirect_uri in encrypted state + encoded_state = encode_state_with_base_url( + base_url=base_url, + original_state=state, + code_challenge=code_challenge, + code_challenge_method=code_challenge_method, + client_redirect_uri=redirect_uri, + ) + # Build params for upstream OAuth provider params = { "client_id": mcp_server.client_id, "redirect_uri": f"{request_base_url}/callback", - "scope": " ".join(mcp_server.scopes), + "scope": scope or " ".join(mcp_server.scopes), "state": encoded_state, + "response_type": response_type or "code", } + + # Forward PKCE parameters if present + if code_challenge: + params["code_challenge"] = code_challenge + if code_challenge_method: + params["code_challenge_method"] = code_challenge_method + return RedirectResponse(f"{mcp_server.authorization_url}?{urlencode(params)}") @@ -110,15 +144,16 @@ async def token_endpoint( redirect_uri: str = Form(None), client_id: str = Form(...), client_secret: str = Form(...), + code_verifier: str = Form(None), ): """ - Accept the authorization code from Claude and exchange it for GitHub token. - Forward the GitHub token back to Claude in standard OAuth format. + Accept the authorization code from client and exchange it for OAuth token. + Supports PKCE flow by forwarding code_verifier to upstream provider. - 1. Call the token endpoint - 2. Store the user's PAT in the db - and generate a LiteLLM virtual key - 2. Return the token - 3. Return a virtual key in this response + 1. Call the token endpoint with PKCE parameters + 2. Store the user's token in the db - and generate a LiteLLM virtual key + 3. Return the token + 4. Return a virtual key in this response """ from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( global_mcp_server_manager, @@ -136,41 +171,60 @@ async def token_endpoint( proxy_base_url = str(request.base_url).rstrip("/") - # Exchange code for real GitHub token + # Build token request data + token_data = { + "grant_type": "authorization_code", + "client_id": mcp_server.client_id, + "client_secret": mcp_server.client_secret, + "code": code, + "redirect_uri": f"{proxy_base_url}/callback", + } + + # Forward PKCE code_verifier if present + if code_verifier: + token_data["code_verifier"] = code_verifier + + # Exchange code for real OAuth token async_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check) response = await async_client.post( mcp_server.token_url, headers={"Accept": "application/json"}, - data={ - "client_id": mcp_server.client_id, - "client_secret": mcp_server.client_secret, - "code": code, - "redirect_uri": f"{proxy_base_url}/callback", - }, + data=token_data, ) response.raise_for_status() - github_token = response.json()["access_token"] - - # Return to Claude in expected OAuth 2 format + token_response = response.json() + access_token = token_response["access_token"] + + # Return to client in expected OAuth 2 format + # Only include fields that have values + result = { + "access_token": access_token, + "token_type": token_response.get("token_type", "Bearer"), + "expires_in": token_response.get("expires_in", 3600), + } - ### return a virtual key in this response + # Add optional fields only if they exist + if "refresh_token" in token_response and token_response["refresh_token"]: + result["refresh_token"] = token_response["refresh_token"] + if "scope" in token_response and token_response["scope"]: + result["scope"] = token_response["scope"] - return JSONResponse( - {"access_token": github_token, "token_type": "Bearer", "expires_in": 3600} - ) + return JSONResponse(result) @router.get("/callback") async def callback(code: str, state: str): try: - # Decode the state hash to get base_url and original state - base_url, original_state = decode_state_hash(state) + # Decode the state hash to get base_url, original state, and PKCE params + state_data = decode_state_hash(state) + base_url = state_data["base_url"] + original_state = state_data["original_state"] - # Exchange code for token with GitHub + # Forward code and original state back to client params = {"code": code, "state": original_state} - # Forward token to Claude ephemeral endpoint + # Forward to client's callback endpoint complete_returned_url = f"{base_url}?{urlencode(params)}" return RedirectResponse(url=complete_returned_url, status_code=302) diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py new file mode 100644 index 000000000000..afa7627952ea --- /dev/null +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py @@ -0,0 +1,228 @@ +"""Tests for MCP OAuth discoverable endpoints""" +import pytest +from unittest.mock import MagicMock, patch + + +@pytest.mark.asyncio +async def test_authorize_endpoint_includes_response_type(): + """Test that authorize endpoint includes response_type=code parameter (fixes #15684)""" + try: + from litellm.proxy._experimental.mcp_server.discoverable_endpoints import ( + authorize, + ) + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + from litellm.types.mcp import MCPAuth + from litellm.types.mcp_server.mcp_server_manager import MCPServer + from litellm.proxy._types import MCPTransport + from fastapi import Request + except ImportError: + pytest.skip("MCP discoverable endpoints not available") + + # Clear registry + global_mcp_server_manager.registry.clear() + + # Create mock OAuth2 server + oauth2_server = MCPServer( + server_id="test_oauth_server", + name="test_oauth", + server_name="test_oauth", + alias="test_oauth", + transport=MCPTransport.http, + auth_type=MCPAuth.oauth2, + client_id="test_client_id", + client_secret="test_client_secret", + authorization_url="https://provider.com/oauth/authorize", + token_url="https://provider.com/oauth/token", + scopes=["read", "write"], + ) + global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server + + # Mock request + mock_request = MagicMock(spec=Request) + mock_request.base_url = "https://litellm.example.com/" + mock_request.headers = {} + + # Mock the encryption functions to avoid needing a signing key + with patch( + "litellm.proxy._experimental.mcp_server.discoverable_endpoints.encrypt_value_helper" + ) as mock_encrypt: + mock_encrypt.return_value = "mocked_encrypted_state" + + # Call authorize endpoint + response = await authorize( + request=mock_request, + client_id="test_oauth", + redirect_uri="https://client.example.com/callback", + state="test_state", + ) + + # Verify response is a redirect + assert response.status_code == 307 # FastAPI RedirectResponse default + + # Verify response_type is in the redirect URL + assert "response_type=code" in response.headers["location"] + assert "https://provider.com/oauth/authorize" in response.headers["location"] + assert "client_id=test_client_id" in response.headers["location"] + assert "scope=read+write" in response.headers["location"] + + +@pytest.mark.asyncio +async def test_authorize_endpoint_forwards_pkce_parameters(): + """Test that authorize endpoint forwards PKCE parameters (code_challenge and code_challenge_method)""" + try: + from litellm.proxy._experimental.mcp_server.discoverable_endpoints import ( + authorize, + ) + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + from litellm.types.mcp import MCPAuth + from litellm.types.mcp_server.mcp_server_manager import MCPServer + from litellm.proxy._types import MCPTransport + from fastapi import Request + except ImportError: + pytest.skip("MCP discoverable endpoints not available") + + # Clear registry + global_mcp_server_manager.registry.clear() + + # Create mock OAuth2 server (simulating Google OAuth) + oauth2_server = MCPServer( + server_id="google_mcp", + name="google_mcp", + server_name="google_mcp", + alias="google_mcp", + transport=MCPTransport.http, + auth_type=MCPAuth.oauth2, + client_id="669428968603-test.apps.googleusercontent.com", + client_secret="GOCSPX-test_secret", + authorization_url="https://accounts.google.com/o/oauth2/v2/auth", + token_url="https://oauth2.googleapis.com/token", + scopes=["https://www.googleapis.com/auth/drive", "openid", "email"], + ) + global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server + + # Mock request + mock_request = MagicMock(spec=Request) + mock_request.base_url = "https://litellm-proxy.example.com/" + mock_request.headers = {} + + # Mock the encryption function + with patch( + "litellm.proxy._experimental.mcp_server.discoverable_endpoints.encrypt_value_helper" + ) as mock_encrypt: + mock_encrypt.return_value = "mocked_encrypted_state_with_pkce" + + # Call authorize endpoint with PKCE parameters + response = await authorize( + request=mock_request, + client_id="google_mcp", + redirect_uri="http://localhost:60108/callback", + state="test_client_state", + code_challenge="x6YH_qgwbvOzbsHDuL1sW9gYkR9-gObUiIB5RkPwxDk", + code_challenge_method="S256", + ) + + # Verify response is a redirect + assert response.status_code == 307 + + # Verify PKCE parameters are included in the redirect URL + location = response.headers["location"] + assert "https://accounts.google.com/o/oauth2/v2/auth" in location + assert "code_challenge=x6YH_qgwbvOzbsHDuL1sW9gYkR9-gObUiIB5RkPwxDk" in location + assert "code_challenge_method=S256" in location + assert "client_id=669428968603-test.apps.googleusercontent.com" in location + assert "response_type=code" in location + + +@pytest.mark.asyncio +async def test_token_endpoint_forwards_code_verifier(): + """Test that token endpoint forwards code_verifier for PKCE flow""" + try: + from litellm.proxy._experimental.mcp_server.discoverable_endpoints import ( + token_endpoint, + ) + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + from litellm.types.mcp import MCPAuth + from litellm.types.mcp_server.mcp_server_manager import MCPServer + from litellm.proxy._types import MCPTransport + from fastapi import Request + import httpx + except ImportError: + pytest.skip("MCP discoverable endpoints not available") + + # Clear registry + global_mcp_server_manager.registry.clear() + + # Create mock OAuth2 server + oauth2_server = MCPServer( + server_id="google_mcp", + name="google_mcp", + server_name="google_mcp", + alias="google_mcp", + transport=MCPTransport.http, + auth_type=MCPAuth.oauth2, + client_id="669428968603-test.apps.googleusercontent.com", + client_secret="GOCSPX-test_secret", + authorization_url="https://accounts.google.com/o/oauth2/v2/auth", + token_url="https://oauth2.googleapis.com/token", + scopes=["https://www.googleapis.com/auth/drive", "openid", "email"], + ) + global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server + + # Mock request + mock_request = MagicMock(spec=Request) + mock_request.base_url = "https://litellm-proxy.example.com/" + + # Mock httpx client response + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "ya29.test_access_token", + "token_type": "Bearer", + "expires_in": 3599, + "scope": "openid email https://www.googleapis.com/auth/drive", + } + mock_response.raise_for_status = MagicMock() + + # Mock the async httpx client with AsyncMock for async methods + from unittest.mock import AsyncMock + with patch( + "litellm.proxy._experimental.mcp_server.discoverable_endpoints.get_async_httpx_client" + ) as mock_get_client: + mock_async_client = MagicMock() + # Use AsyncMock for the async post method + mock_async_client.post = AsyncMock(return_value=mock_response) + mock_get_client.return_value = mock_async_client + + # Call token endpoint with code_verifier + response = await token_endpoint( + request=mock_request, + grant_type="authorization_code", + code="4/test_authorization_code", + redirect_uri="http://localhost:60108/callback", + client_id="google_mcp", + client_secret="dummy", + code_verifier="test_code_verifier_from_client", + ) + + # Verify that the token endpoint was called with code_verifier + mock_async_client.post.assert_called_once() + call_args = mock_async_client.post.call_args + + # Check the data parameter includes code_verifier + assert call_args[1]["data"]["code_verifier"] == "test_code_verifier_from_client" + assert call_args[1]["data"]["code"] == "4/test_authorization_code" + assert call_args[1]["data"]["client_id"] == "669428968603-test.apps.googleusercontent.com" + assert call_args[1]["data"]["client_secret"] == "GOCSPX-test_secret" + assert call_args[1]["data"]["grant_type"] == "authorization_code" + + # Verify response + response_data = response.body + import json + token_data = json.loads(response_data) + assert token_data["access_token"] == "ya29.test_access_token" + assert token_data["token_type"] == "Bearer"