11import json
2- from typing import Optional , Tuple
2+ from typing import Optional
33from urllib .parse import urlencode , urlparse , urlunparse
44
55from fastapi import APIRouter , Form , HTTPException , Request
1919)
2020
2121
22- def encode_state_with_base_url (base_url : str , original_state : str ) -> str :
22+ def encode_state_with_base_url (
23+ base_url : str ,
24+ original_state : str ,
25+ code_challenge : Optional [str ] = None ,
26+ code_challenge_method : Optional [str ] = None ,
27+ client_redirect_uri : Optional [str ] = None ,
28+ ) -> str :
2329 """
24- Encode the base_url and original state using encryption.
30+ Encode the base_url, original state, and PKCE parameters using encryption.
2531
2632 Args:
2733 base_url: The base URL to encode
2834 original_state: The original state parameter
35+ code_challenge: PKCE code challenge from client
36+ code_challenge_method: PKCE code challenge method from client
37+ client_redirect_uri: Original redirect_uri from client
2938
3039 Returns:
31- An encrypted string that encodes both values
40+ An encrypted string that encodes all values
3241 """
33- state_data = {"base_url" : base_url , "original_state" : original_state }
42+ state_data = {
43+ "base_url" : base_url ,
44+ "original_state" : original_state ,
45+ "code_challenge" : code_challenge ,
46+ "code_challenge_method" : code_challenge_method ,
47+ "client_redirect_uri" : client_redirect_uri ,
48+ }
3449 state_json = json .dumps (state_data , sort_keys = True )
3550 encrypted_state = encrypt_value_helper (state_json )
3651 return encrypted_state
3752
3853
39- def decode_state_hash (encrypted_state : str ) -> Tuple [ str , str ] :
54+ def decode_state_hash (encrypted_state : str ) -> dict :
4055 """
41- Decode an encrypted state to retrieve the base_url and original state .
56+ Decode an encrypted state to retrieve all OAuth session data .
4257
4358 Args:
4459 encrypted_state: The encrypted string to decode
4560
4661 Returns:
47- A tuple of ( base_url, original_state)
62+ A dict containing base_url, original_state, and optional PKCE parameters
4863
4964 Raises:
5065 Exception: If decryption fails or data is malformed
@@ -54,7 +69,7 @@ def decode_state_hash(encrypted_state: str) -> Tuple[str, str]:
5469 raise ValueError ("Failed to decrypt state parameter" )
5570
5671 state_data = json .loads (decrypted_json )
57- return state_data [ "base_url" ], state_data [ "original_state" ]
72+ return state_data
5873
5974
6075@router .get ("/{mcp_server_name}/authorize" )
@@ -65,8 +80,12 @@ async def authorize(
6580 redirect_uri : str ,
6681 state : str = "" ,
6782 mcp_server_name : Optional [str ] = None ,
83+ code_challenge : Optional [str ] = None ,
84+ code_challenge_method : Optional [str ] = None ,
85+ response_type : Optional [str ] = None ,
86+ scope : Optional [str ] = None ,
6887):
69- # Redirect to real GitHub OAuth
88+ # Redirect to real OAuth provider with PKCE support
7089 from litellm .proxy ._experimental .mcp_server .mcp_server_manager import (
7190 global_mcp_server_manager ,
7291 )
@@ -90,15 +109,30 @@ async def authorize(
90109 base_url = urlunparse (parsed ._replace (query = "" ))
91110 request_base_url = str (request .base_url ).rstrip ("/" )
92111
93- # Encode the base_url and original state in a unique hash
94- encoded_state = encode_state_with_base_url (base_url , state )
112+ # Encode the base_url, original state, PKCE params, and client redirect_uri in encrypted state
113+ encoded_state = encode_state_with_base_url (
114+ base_url = base_url ,
115+ original_state = state ,
116+ code_challenge = code_challenge ,
117+ code_challenge_method = code_challenge_method ,
118+ client_redirect_uri = redirect_uri ,
119+ )
95120
121+ # Build params for upstream OAuth provider
96122 params = {
97123 "client_id" : mcp_server .client_id ,
98124 "redirect_uri" : f"{ request_base_url } /callback" ,
99- "scope" : " " .join (mcp_server .scopes ),
125+ "scope" : scope or " " .join (mcp_server .scopes ),
100126 "state" : encoded_state ,
127+ "response_type" : response_type or "code" ,
101128 }
129+
130+ # Forward PKCE parameters if present
131+ if code_challenge :
132+ params ["code_challenge" ] = code_challenge
133+ if code_challenge_method :
134+ params ["code_challenge_method" ] = code_challenge_method
135+
102136 return RedirectResponse (f"{ mcp_server .authorization_url } ?{ urlencode (params )} " )
103137
104138
@@ -110,15 +144,16 @@ async def token_endpoint(
110144 redirect_uri : str = Form (None ),
111145 client_id : str = Form (...),
112146 client_secret : str = Form (...),
147+ code_verifier : str = Form (None ),
113148):
114149 """
115- Accept the authorization code from Claude and exchange it for GitHub token.
116- Forward the GitHub token back to Claude in standard OAuth format .
150+ Accept the authorization code from client and exchange it for OAuth token.
151+ Supports PKCE flow by forwarding code_verifier to upstream provider .
117152
118- 1. Call the token endpoint
119- 2. Store the user's PAT in the db - and generate a LiteLLM virtual key
120- 2 . Return the token
121- 3 . Return a virtual key in this response
153+ 1. Call the token endpoint with PKCE parameters
154+ 2. Store the user's token in the db - and generate a LiteLLM virtual key
155+ 3 . Return the token
156+ 4 . Return a virtual key in this response
122157 """
123158 from litellm .proxy ._experimental .mcp_server .mcp_server_manager import (
124159 global_mcp_server_manager ,
@@ -136,41 +171,60 @@ async def token_endpoint(
136171
137172 proxy_base_url = str (request .base_url ).rstrip ("/" )
138173
139- # Exchange code for real GitHub token
174+ # Build token request data
175+ token_data = {
176+ "grant_type" : "authorization_code" ,
177+ "client_id" : mcp_server .client_id ,
178+ "client_secret" : mcp_server .client_secret ,
179+ "code" : code ,
180+ "redirect_uri" : f"{ proxy_base_url } /callback" ,
181+ }
182+
183+ # Forward PKCE code_verifier if present
184+ if code_verifier :
185+ token_data ["code_verifier" ] = code_verifier
186+
187+ # Exchange code for real OAuth token
140188 async_client = get_async_httpx_client (llm_provider = httpxSpecialProvider .Oauth2Check )
141189 response = await async_client .post (
142190 mcp_server .token_url ,
143191 headers = {"Accept" : "application/json" },
144- data = {
145- "client_id" : mcp_server .client_id ,
146- "client_secret" : mcp_server .client_secret ,
147- "code" : code ,
148- "redirect_uri" : f"{ proxy_base_url } /callback" ,
149- },
192+ data = token_data ,
150193 )
151194
152195 response .raise_for_status ()
153- github_token = response .json ()["access_token" ]
154-
155- # Return to Claude in expected OAuth 2 format
196+ token_response = response .json ()
197+ access_token = token_response ["access_token" ]
198+
199+ # Return to client in expected OAuth 2 format
200+ # Only include fields that have values
201+ result = {
202+ "access_token" : access_token ,
203+ "token_type" : token_response .get ("token_type" , "Bearer" ),
204+ "expires_in" : token_response .get ("expires_in" , 3600 ),
205+ }
156206
157- ### return a virtual key in this response
207+ # Add optional fields only if they exist
208+ if "refresh_token" in token_response and token_response ["refresh_token" ]:
209+ result ["refresh_token" ] = token_response ["refresh_token" ]
210+ if "scope" in token_response and token_response ["scope" ]:
211+ result ["scope" ] = token_response ["scope" ]
158212
159- return JSONResponse (
160- {"access_token" : github_token , "token_type" : "Bearer" , "expires_in" : 3600 }
161- )
213+ return JSONResponse (result )
162214
163215
164216@router .get ("/callback" )
165217async def callback (code : str , state : str ):
166218 try :
167- # Decode the state hash to get base_url and original state
168- base_url , original_state = decode_state_hash (state )
219+ # Decode the state hash to get base_url, original state, and PKCE params
220+ state_data = decode_state_hash (state )
221+ base_url = state_data ["base_url" ]
222+ original_state = state_data ["original_state" ]
169223
170- # Exchange code for token with GitHub
224+ # Forward code and original state back to client
171225 params = {"code" : code , "state" : original_state }
172226
173- # Forward token to Claude ephemeral endpoint
227+ # Forward to client's callback endpoint
174228 complete_returned_url = f"{ base_url } ?{ urlencode (params )} "
175229 return RedirectResponse (url = complete_returned_url , status_code = 302 )
176230
0 commit comments