Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 91 additions & 37 deletions litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Optional, Tuple
from typing import Optional
from urllib.parse import urlencode, urlparse, urlunparse

from fastapi import APIRouter, Form, HTTPException, Request
Expand All @@ -19,32 +19,47 @@
)


def encode_state_with_base_url(base_url: str, original_state: str) -> str:
def encode_state_with_base_url(
base_url: str,
original_state: str,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
client_redirect_uri: Optional[str] = None,
) -> str:
"""
Encode the base_url and original state using encryption.
Encode the base_url, original state, and PKCE parameters using encryption.

Args:
base_url: The base URL to encode
original_state: The original state parameter
code_challenge: PKCE code challenge from client
code_challenge_method: PKCE code challenge method from client
client_redirect_uri: Original redirect_uri from client

Returns:
An encrypted string that encodes both values
An encrypted string that encodes all values
"""
state_data = {"base_url": base_url, "original_state": original_state}
state_data = {
"base_url": base_url,
"original_state": original_state,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"client_redirect_uri": client_redirect_uri,
}
state_json = json.dumps(state_data, sort_keys=True)
encrypted_state = encrypt_value_helper(state_json)
return encrypted_state


def decode_state_hash(encrypted_state: str) -> Tuple[str, str]:
def decode_state_hash(encrypted_state: str) -> dict:
"""
Decode an encrypted state to retrieve the base_url and original state.
Decode an encrypted state to retrieve all OAuth session data.

Args:
encrypted_state: The encrypted string to decode

Returns:
A tuple of (base_url, original_state)
A dict containing base_url, original_state, and optional PKCE parameters

Raises:
Exception: If decryption fails or data is malformed
Expand All @@ -54,7 +69,7 @@ def decode_state_hash(encrypted_state: str) -> Tuple[str, str]:
raise ValueError("Failed to decrypt state parameter")

state_data = json.loads(decrypted_json)
return state_data["base_url"], state_data["original_state"]
return state_data


@router.get("/{mcp_server_name}/authorize")
Expand All @@ -65,8 +80,12 @@ async def authorize(
redirect_uri: str,
state: str = "",
mcp_server_name: Optional[str] = None,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
response_type: Optional[str] = None,
scope: Optional[str] = None,
):
# Redirect to real GitHub OAuth
# Redirect to real OAuth provider with PKCE support
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
Expand All @@ -90,15 +109,30 @@ async def authorize(
base_url = urlunparse(parsed._replace(query=""))
request_base_url = str(request.base_url).rstrip("/")

# Encode the base_url and original state in a unique hash
encoded_state = encode_state_with_base_url(base_url, state)
# Encode the base_url, original state, PKCE params, and client redirect_uri in encrypted state
encoded_state = encode_state_with_base_url(
base_url=base_url,
original_state=state,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
client_redirect_uri=redirect_uri,
)

# Build params for upstream OAuth provider
params = {
"client_id": mcp_server.client_id,
"redirect_uri": f"{request_base_url}/callback",
"scope": " ".join(mcp_server.scopes),
"scope": scope or " ".join(mcp_server.scopes),
"state": encoded_state,
"response_type": response_type or "code",
}

# Forward PKCE parameters if present
if code_challenge:
params["code_challenge"] = code_challenge
if code_challenge_method:
params["code_challenge_method"] = code_challenge_method

return RedirectResponse(f"{mcp_server.authorization_url}?{urlencode(params)}")


Expand All @@ -110,15 +144,16 @@ async def token_endpoint(
redirect_uri: str = Form(None),
client_id: str = Form(...),
client_secret: str = Form(...),
code_verifier: str = Form(None),
):
"""
Accept the authorization code from Claude and exchange it for GitHub token.
Forward the GitHub token back to Claude in standard OAuth format.
Accept the authorization code from client and exchange it for OAuth token.
Supports PKCE flow by forwarding code_verifier to upstream provider.

1. Call the token endpoint
2. Store the user's PAT in the db - and generate a LiteLLM virtual key
2. Return the token
3. Return a virtual key in this response
1. Call the token endpoint with PKCE parameters
2. Store the user's token in the db - and generate a LiteLLM virtual key
3. Return the token
4. Return a virtual key in this response
"""
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
Expand All @@ -136,41 +171,60 @@ async def token_endpoint(

proxy_base_url = str(request.base_url).rstrip("/")

# Exchange code for real GitHub token
# Build token request data
token_data = {
"grant_type": "authorization_code",
"client_id": mcp_server.client_id,
"client_secret": mcp_server.client_secret,
"code": code,
"redirect_uri": f"{proxy_base_url}/callback",
}

# Forward PKCE code_verifier if present
if code_verifier:
token_data["code_verifier"] = code_verifier

# Exchange code for real OAuth token
async_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
response = await async_client.post(
mcp_server.token_url,
headers={"Accept": "application/json"},
data={
"client_id": mcp_server.client_id,
"client_secret": mcp_server.client_secret,
"code": code,
"redirect_uri": f"{proxy_base_url}/callback",
},
data=token_data,
)

response.raise_for_status()
github_token = response.json()["access_token"]

# Return to Claude in expected OAuth 2 format
token_response = response.json()
access_token = token_response["access_token"]

# Return to client in expected OAuth 2 format
# Only include fields that have values
result = {
"access_token": access_token,
"token_type": token_response.get("token_type", "Bearer"),
"expires_in": token_response.get("expires_in", 3600),
}

### return a virtual key in this response
# Add optional fields only if they exist
if "refresh_token" in token_response and token_response["refresh_token"]:
result["refresh_token"] = token_response["refresh_token"]
if "scope" in token_response and token_response["scope"]:
result["scope"] = token_response["scope"]

return JSONResponse(
{"access_token": github_token, "token_type": "Bearer", "expires_in": 3600}
)
return JSONResponse(result)


@router.get("/callback")
async def callback(code: str, state: str):
try:
# Decode the state hash to get base_url and original state
base_url, original_state = decode_state_hash(state)
# Decode the state hash to get base_url, original state, and PKCE params
state_data = decode_state_hash(state)
base_url = state_data["base_url"]
original_state = state_data["original_state"]

# Exchange code for token with GitHub
# Forward code and original state back to client
params = {"code": code, "state": original_state}

# Forward token to Claude ephemeral endpoint
# Forward to client's callback endpoint
complete_returned_url = f"{base_url}?{urlencode(params)}"
return RedirectResponse(url=complete_returned_url, status_code=302)

Expand Down
Loading
Loading