Skip to content

Commit 46d55bd

Browse files
talalryzclaude
andauthored
fix: Add response_type + PKCE parameters to OAuth authorization endpoint (#15720)
* fix: Add response_type parameter to OAuth authorization endpoint Fixes #15684 OAuth providers like Google require the response_type parameter during the authorization flow. This commit adds response_type=code to the authorization redirect parameters, which is required by the OAuth 2.0 specification (RFC 6749 Section 4.1.1). Changes: - Added response_type=code to authorization params in discoverable_endpoints.py - Added test coverage for the response_type parameter 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]> * fix oauth flow by forwarding code_challenge and forwarding code_verifier --------- Co-authored-by: Claude <[email protected]>
1 parent 98f1d63 commit 46d55bd

File tree

2 files changed

+319
-37
lines changed

2 files changed

+319
-37
lines changed

litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py

Lines changed: 91 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Optional, Tuple
2+
from typing import Optional
33
from urllib.parse import urlencode, urlparse, urlunparse
44

55
from fastapi import APIRouter, Form, HTTPException, Request
@@ -19,32 +19,47 @@
1919
)
2020

2121

22-
def encode_state_with_base_url(base_url: str, original_state: str) -> str:
22+
def encode_state_with_base_url(
23+
base_url: str,
24+
original_state: str,
25+
code_challenge: Optional[str] = None,
26+
code_challenge_method: Optional[str] = None,
27+
client_redirect_uri: Optional[str] = None,
28+
) -> str:
2329
"""
24-
Encode the base_url and original state using encryption.
30+
Encode the base_url, original state, and PKCE parameters using encryption.
2531
2632
Args:
2733
base_url: The base URL to encode
2834
original_state: The original state parameter
35+
code_challenge: PKCE code challenge from client
36+
code_challenge_method: PKCE code challenge method from client
37+
client_redirect_uri: Original redirect_uri from client
2938
3039
Returns:
31-
An encrypted string that encodes both values
40+
An encrypted string that encodes all values
3241
"""
33-
state_data = {"base_url": base_url, "original_state": original_state}
42+
state_data = {
43+
"base_url": base_url,
44+
"original_state": original_state,
45+
"code_challenge": code_challenge,
46+
"code_challenge_method": code_challenge_method,
47+
"client_redirect_uri": client_redirect_uri,
48+
}
3449
state_json = json.dumps(state_data, sort_keys=True)
3550
encrypted_state = encrypt_value_helper(state_json)
3651
return encrypted_state
3752

3853

39-
def decode_state_hash(encrypted_state: str) -> Tuple[str, str]:
54+
def decode_state_hash(encrypted_state: str) -> dict:
4055
"""
41-
Decode an encrypted state to retrieve the base_url and original state.
56+
Decode an encrypted state to retrieve all OAuth session data.
4257
4358
Args:
4459
encrypted_state: The encrypted string to decode
4560
4661
Returns:
47-
A tuple of (base_url, original_state)
62+
A dict containing base_url, original_state, and optional PKCE parameters
4863
4964
Raises:
5065
Exception: If decryption fails or data is malformed
@@ -54,7 +69,7 @@ def decode_state_hash(encrypted_state: str) -> Tuple[str, str]:
5469
raise ValueError("Failed to decrypt state parameter")
5570

5671
state_data = json.loads(decrypted_json)
57-
return state_data["base_url"], state_data["original_state"]
72+
return state_data
5873

5974

6075
@router.get("/{mcp_server_name}/authorize")
@@ -65,8 +80,12 @@ async def authorize(
6580
redirect_uri: str,
6681
state: str = "",
6782
mcp_server_name: Optional[str] = None,
83+
code_challenge: Optional[str] = None,
84+
code_challenge_method: Optional[str] = None,
85+
response_type: Optional[str] = None,
86+
scope: Optional[str] = None,
6887
):
69-
# Redirect to real GitHub OAuth
88+
# Redirect to real OAuth provider with PKCE support
7089
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
7190
global_mcp_server_manager,
7291
)
@@ -90,15 +109,30 @@ async def authorize(
90109
base_url = urlunparse(parsed._replace(query=""))
91110
request_base_url = str(request.base_url).rstrip("/")
92111

93-
# Encode the base_url and original state in a unique hash
94-
encoded_state = encode_state_with_base_url(base_url, state)
112+
# Encode the base_url, original state, PKCE params, and client redirect_uri in encrypted state
113+
encoded_state = encode_state_with_base_url(
114+
base_url=base_url,
115+
original_state=state,
116+
code_challenge=code_challenge,
117+
code_challenge_method=code_challenge_method,
118+
client_redirect_uri=redirect_uri,
119+
)
95120

121+
# Build params for upstream OAuth provider
96122
params = {
97123
"client_id": mcp_server.client_id,
98124
"redirect_uri": f"{request_base_url}/callback",
99-
"scope": " ".join(mcp_server.scopes),
125+
"scope": scope or " ".join(mcp_server.scopes),
100126
"state": encoded_state,
127+
"response_type": response_type or "code",
101128
}
129+
130+
# Forward PKCE parameters if present
131+
if code_challenge:
132+
params["code_challenge"] = code_challenge
133+
if code_challenge_method:
134+
params["code_challenge_method"] = code_challenge_method
135+
102136
return RedirectResponse(f"{mcp_server.authorization_url}?{urlencode(params)}")
103137

104138

@@ -110,15 +144,16 @@ async def token_endpoint(
110144
redirect_uri: str = Form(None),
111145
client_id: str = Form(...),
112146
client_secret: str = Form(...),
147+
code_verifier: str = Form(None),
113148
):
114149
"""
115-
Accept the authorization code from Claude and exchange it for GitHub token.
116-
Forward the GitHub token back to Claude in standard OAuth format.
150+
Accept the authorization code from client and exchange it for OAuth token.
151+
Supports PKCE flow by forwarding code_verifier to upstream provider.
117152
118-
1. Call the token endpoint
119-
2. Store the user's PAT in the db - and generate a LiteLLM virtual key
120-
2. Return the token
121-
3. Return a virtual key in this response
153+
1. Call the token endpoint with PKCE parameters
154+
2. Store the user's token in the db - and generate a LiteLLM virtual key
155+
3. Return the token
156+
4. Return a virtual key in this response
122157
"""
123158
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
124159
global_mcp_server_manager,
@@ -136,41 +171,60 @@ async def token_endpoint(
136171

137172
proxy_base_url = str(request.base_url).rstrip("/")
138173

139-
# Exchange code for real GitHub token
174+
# Build token request data
175+
token_data = {
176+
"grant_type": "authorization_code",
177+
"client_id": mcp_server.client_id,
178+
"client_secret": mcp_server.client_secret,
179+
"code": code,
180+
"redirect_uri": f"{proxy_base_url}/callback",
181+
}
182+
183+
# Forward PKCE code_verifier if present
184+
if code_verifier:
185+
token_data["code_verifier"] = code_verifier
186+
187+
# Exchange code for real OAuth token
140188
async_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
141189
response = await async_client.post(
142190
mcp_server.token_url,
143191
headers={"Accept": "application/json"},
144-
data={
145-
"client_id": mcp_server.client_id,
146-
"client_secret": mcp_server.client_secret,
147-
"code": code,
148-
"redirect_uri": f"{proxy_base_url}/callback",
149-
},
192+
data=token_data,
150193
)
151194

152195
response.raise_for_status()
153-
github_token = response.json()["access_token"]
154-
155-
# Return to Claude in expected OAuth 2 format
196+
token_response = response.json()
197+
access_token = token_response["access_token"]
198+
199+
# Return to client in expected OAuth 2 format
200+
# Only include fields that have values
201+
result = {
202+
"access_token": access_token,
203+
"token_type": token_response.get("token_type", "Bearer"),
204+
"expires_in": token_response.get("expires_in", 3600),
205+
}
156206

157-
### return a virtual key in this response
207+
# Add optional fields only if they exist
208+
if "refresh_token" in token_response and token_response["refresh_token"]:
209+
result["refresh_token"] = token_response["refresh_token"]
210+
if "scope" in token_response and token_response["scope"]:
211+
result["scope"] = token_response["scope"]
158212

159-
return JSONResponse(
160-
{"access_token": github_token, "token_type": "Bearer", "expires_in": 3600}
161-
)
213+
return JSONResponse(result)
162214

163215

164216
@router.get("/callback")
165217
async def callback(code: str, state: str):
166218
try:
167-
# Decode the state hash to get base_url and original state
168-
base_url, original_state = decode_state_hash(state)
219+
# Decode the state hash to get base_url, original state, and PKCE params
220+
state_data = decode_state_hash(state)
221+
base_url = state_data["base_url"]
222+
original_state = state_data["original_state"]
169223

170-
# Exchange code for token with GitHub
224+
# Forward code and original state back to client
171225
params = {"code": code, "state": original_state}
172226

173-
# Forward token to Claude ephemeral endpoint
227+
# Forward to client's callback endpoint
174228
complete_returned_url = f"{base_url}?{urlencode(params)}"
175229
return RedirectResponse(url=complete_returned_url, status_code=302)
176230

0 commit comments

Comments
 (0)