Skip to content

Commit 78d93eb

Browse files
committed
Add OAuth Provider to Python Sdk
1 parent ed25167 commit 78d93eb

File tree

7 files changed

+1497
-674
lines changed

7 files changed

+1497
-674
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"sse-starlette>=1.6.1",
3232
"pydantic-settings>=2.5.2",
3333
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
34+
"pkce>=1.0.3"
3435
]
3536

3637
[project.optional-dependencies]

src/mcp/client/auth.py

Lines changed: 346 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
"""
2+
OAuth 2.0 Client Implementation
3+
4+
This module provides a complete OAuth 2.0 client implementation supporting:
5+
- Authorization Code Flow with PKCE
6+
- Dynamic Client Registration
7+
- Token Refresh
8+
- OAuth Server Metadata Discovery
9+
"""
10+
11+
from typing import Protocol, cast
12+
from urllib.parse import urlencode, urljoin
13+
14+
import httpx
15+
from pkce import generate_pkce_pair # type: ignore
16+
17+
from mcp.shared.auth import (
18+
OAuthClientInformation,
19+
OAuthClientInformationFull,
20+
OAuthClientMetadata,
21+
OAuthMetadata,
22+
OAuthToken,
23+
)
24+
from mcp.types import LATEST_PROTOCOL_VERSION
25+
26+
27+
class OAuthClientProvider(Protocol):
28+
"""Protocol defining the interface for OAuth client providers."""
29+
30+
def get_redirect_url(self) -> str:
31+
"""Get the URL where the user agent will be redirected after authorization."""
32+
...
33+
34+
def get_client_metadata(self) -> OAuthClientMetadata:
35+
"""Get the metadata for this OAuth client."""
36+
...
37+
38+
def get_client_information(self) -> OAuthClientInformation | None:
39+
"""Get the client information as registered with the server.
40+
41+
Returns None if the client is not registered.
42+
"""
43+
...
44+
45+
def save_client_information(self, client_information: OAuthClientInformationFull):
46+
"""Optional Function to save the client information received from the server.
47+
48+
If implemented, this provider will support dynamic client registration.
49+
"""
50+
...
51+
52+
def get_token(self) -> OAuthToken | None:
53+
"""Get any existing OAuth tokens for the current session."""
54+
...
55+
56+
def save_token(self, token: OAuthToken):
57+
"""Save the new OAuth token after successful authorization."""
58+
...
59+
60+
def redirect_to_authorization(self, authorization_url: str):
61+
"""Redirect the user agent to begin the authorization flow."""
62+
...
63+
64+
def get_code_verifier(self) -> str:
65+
"""Get the PKCE code verifier for the current session."""
66+
...
67+
68+
def save_code_verifier(self, pkce_code_verifier: str):
69+
"""Save the PKCE code verifier before redirecting to authorization."""
70+
...
71+
72+
73+
class OAuthAuthorization:
74+
"""Main class for handling OAuth 2.0 authorization flows.
75+
76+
This class implements the OAuth 2.0 Authorization Code Flow with PKCE,
77+
supporting dynamic client registration and token refresh.
78+
"""
79+
80+
def __init__(self, provider: OAuthClientProvider, server_url: str):
81+
"""Initialize the OAuth authorization handler.
82+
83+
Args:
84+
provider: The OAuth client provider implementation
85+
server_url: The base URL of the OAuth server
86+
"""
87+
self.provider = provider
88+
self.server_url = server_url
89+
90+
async def authorize(
91+
self, authorization_code: str | None = None
92+
) -> OAuthToken | None:
93+
"""Main authorization method that handles the complete OAuth flow.
94+
95+
This method will:
96+
1. Check for existing valid tokens
97+
2. Refresh tokens if expired
98+
3. Exchange authorization codes for tokens
99+
4. Start new authorization flows if needed
100+
101+
Args:
102+
authorization_code: Optional authorization code from the server
103+
104+
Returns:
105+
OAuthToken if authorization is successful, None if redirect is needed
106+
"""
107+
token = self.provider.get_token()
108+
if token is not None:
109+
# Returned token is still valid so return the token
110+
if token.expires_in is None or token.expires_in > 0:
111+
return token
112+
elif token.refresh_token is not None:
113+
# Refresh token
114+
refreshed_token = await self.refresh_authorization(token.refresh_token)
115+
self.provider.save_token(refreshed_token)
116+
return refreshed_token
117+
118+
# If we have authorization code, exchange it for an access token
119+
if authorization_code:
120+
token = await self.exchange_authorization(authorization_code)
121+
self.provider.save_token(token)
122+
return token
123+
124+
# If no authorization code, build authorization url and redirect
125+
authorization_url, code_verifier = await self.start_authorization()
126+
self.provider.save_code_verifier(code_verifier)
127+
self.provider.redirect_to_authorization(authorization_url)
128+
return None
129+
130+
async def start_authorization(self) -> tuple[str, str]:
131+
"""Start a new authorization flow by generating PKCE values and
132+
building the authorization URL.
133+
134+
Returns:
135+
Tuple containing:
136+
- The complete authorization URL to redirect the user to
137+
- The PKCE code verifier to be used later in token exchange
138+
"""
139+
metadata = await self.discover_oauth_metadata()
140+
client_info = await self.get_client_information()
141+
142+
response_type = "code"
143+
code_challenge_method = "S256"
144+
145+
if metadata is not None:
146+
if (
147+
metadata.response_types_supported
148+
and response_type not in metadata.response_types_supported
149+
):
150+
raise ValueError(
151+
f"Incompatible auth server: {response_type} response type "
152+
"is not supported"
153+
)
154+
if metadata.code_challenge_methods_supported is None or (
155+
code_challenge_method not in metadata.code_challenge_methods_supported
156+
):
157+
raise ValueError(
158+
f"Incompatible auth server: {code_challenge_method} code "
159+
"challenge method is not supported"
160+
)
161+
authorization_url = str(metadata.authorization_endpoint)
162+
else:
163+
authorization_url = urljoin(self.server_url, "/authorize")
164+
165+
code_verifier, code_challenge = cast(tuple[str, str], generate_pkce_pair())
166+
params: dict[str, str] = {
167+
"response_type": response_type,
168+
"client_id": client_info.client_id,
169+
"redirect_uri": self.provider.get_redirect_url(),
170+
"code_challenge": code_challenge,
171+
"code_challenge_method": code_challenge_method,
172+
}
173+
query_string = urlencode(params)
174+
return (f"{authorization_url}?{query_string}", code_verifier)
175+
176+
async def discover_oauth_metadata(self) -> OAuthMetadata | None:
177+
"""Discover OAuth server metadata using the well-known endpoint.
178+
179+
Implements RFC 8414 OAuth 2.0 Authorization Server Metadata.
180+
181+
Returns:
182+
OAuthMetadata if discovery is successful, None if endpoint returns 404
183+
"""
184+
url = urljoin(self.server_url, "/.well-known/openid-configuration")
185+
async with httpx.AsyncClient() as client:
186+
resp = await client.get(
187+
url, headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
188+
)
189+
if resp.status_code == 404:
190+
return None
191+
elif resp.status_code != 200:
192+
raise ValueError(
193+
f"Failed to discover OAuth metadata: HTTP {resp.status_code} "
194+
f"{resp.text}"
195+
)
196+
return OAuthMetadata(**resp.json())
197+
198+
async def register_client(
199+
self,
200+
metadata: OAuthMetadata | None,
201+
client_metadata: OAuthClientMetadata,
202+
) -> OAuthClientInformationFull:
203+
"""Register the client with the OAuth server.
204+
205+
Implements OAuth 2.0 Dynamic Client Registration (RFC 7591).
206+
207+
Args:
208+
metadata: Optional OAuth server metadata
209+
client_metadata: The client's metadata to register
210+
211+
Returns:
212+
Full client information from the server
213+
"""
214+
url = (
215+
str(metadata.registration_endpoint)
216+
if metadata
217+
else urljoin(self.server_url, "/register")
218+
)
219+
220+
async with httpx.AsyncClient() as client:
221+
resp = await client.post(
222+
url,
223+
headers={"Content-Type": "application/x-www-form-urlencoded"},
224+
json=client_metadata.model_dump(),
225+
)
226+
if resp.status_code != 200:
227+
raise ValueError(
228+
f"Dynamic client registration failed: HTTP {resp.status_code} "
229+
f"{resp.text}"
230+
)
231+
return OAuthClientInformationFull(**resp.json())
232+
233+
async def get_client_information(self) -> OAuthClientInformation:
234+
"""Tries to get the client information from the provider.
235+
236+
If unable to retrieve the client information, this attempts
237+
dynamic registration, saves the client with the provider
238+
and returns the information.
239+
240+
Returns:
241+
Client information
242+
"""
243+
client_info = self.provider.get_client_information()
244+
245+
if client_info is None:
246+
if not hasattr(self.provider, "save_client_information"):
247+
raise ValueError(
248+
"Save Client Information is not supported by this provider, "
249+
"therefore we cannot dynamically register the OAuth Client"
250+
)
251+
252+
client_info = await self.register_client(
253+
metadata=None, client_metadata=self.provider.get_client_metadata()
254+
)
255+
self.provider.save_client_information(client_info)
256+
return OAuthClientInformation(**client_info.model_dump())
257+
258+
return client_info
259+
260+
async def exchange_authorization(self, authorization_code: str) -> OAuthToken:
261+
"""Exchange an authorization code for an access token.
262+
263+
Args:
264+
authorization_code: The authorization code from the server
265+
266+
Returns:
267+
New OAuth token
268+
"""
269+
code_verifier = self.provider.get_code_verifier()
270+
redirect_url = self.provider.get_redirect_url()
271+
272+
return await self._fetch_token(
273+
grant_type="authorization_code",
274+
extra_params={
275+
"code": authorization_code,
276+
"code_verifier": code_verifier,
277+
"redirect_uri": redirect_url,
278+
},
279+
)
280+
281+
async def refresh_authorization(self, refresh_token: str) -> OAuthToken:
282+
"""Exchange a refresh token for a new access token.
283+
284+
Args:
285+
refresh_token: The refresh token to use
286+
287+
Returns:
288+
New OAuth token
289+
"""
290+
return await self._fetch_token(
291+
grant_type="refresh_token",
292+
extra_params={
293+
"refresh_token": refresh_token,
294+
},
295+
)
296+
297+
async def _fetch_token(
298+
self,
299+
grant_type: str,
300+
extra_params: dict[str, str],
301+
) -> OAuthToken:
302+
"""Internal method to fetch tokens from the server.
303+
304+
Handles both authorization code exchange and token refresh.
305+
306+
Args:
307+
grant_type: The OAuth grant type to use
308+
extra_params: Additional parameters for the token request
309+
310+
Returns:
311+
New OAuth token
312+
"""
313+
metadata = await self.discover_oauth_metadata()
314+
if metadata is not None:
315+
token_url = str(metadata.token_endpoint)
316+
if (
317+
metadata.grant_types_supported
318+
and grant_type not in metadata.grant_types_supported
319+
):
320+
raise ValueError(
321+
f"Incompatible auth server: {grant_type} not supported"
322+
)
323+
else:
324+
token_url = urljoin(self.server_url, "/token")
325+
326+
client_info = await self.get_client_information()
327+
params: dict[str, str] = {
328+
"grant_type": grant_type,
329+
"client_id": client_info.client_id,
330+
**extra_params,
331+
}
332+
if client_info.client_secret:
333+
params["client_secret"] = client_info.client_secret
334+
335+
async with httpx.AsyncClient() as client:
336+
resp = await client.post(
337+
token_url,
338+
headers={"Content-Type": "application/x-www-form-urlencoded"},
339+
json=params,
340+
)
341+
if resp.status_code != 200:
342+
raise ValueError(
343+
f"Token request failed for {grant_type}: "
344+
f"HTTP {resp.status_code} {resp.text}"
345+
)
346+
return OAuthToken(**resp.json())

0 commit comments

Comments
 (0)