|
| 1 | +""" |
| 2 | +OAuth 2.0 Client Implementation |
| 3 | +
|
| 4 | +This module provides a complete OAuth 2.0 client implementation supporting: |
| 5 | +- Authorization Code Flow with PKCE |
| 6 | +- Dynamic Client Registration |
| 7 | +- Token Refresh |
| 8 | +- OAuth Server Metadata Discovery |
| 9 | +""" |
| 10 | + |
| 11 | +from typing import Protocol, cast |
| 12 | +from urllib.parse import urlencode, urljoin |
| 13 | + |
| 14 | +import httpx |
| 15 | +from pkce import generate_pkce_pair # type: ignore |
| 16 | + |
| 17 | +from mcp.shared.auth import ( |
| 18 | + OAuthClientInformation, |
| 19 | + OAuthClientInformationFull, |
| 20 | + OAuthClientMetadata, |
| 21 | + OAuthMetadata, |
| 22 | + OAuthToken, |
| 23 | +) |
| 24 | +from mcp.types import LATEST_PROTOCOL_VERSION |
| 25 | + |
| 26 | + |
| 27 | +class OAuthClientProvider(Protocol): |
| 28 | + """Protocol defining the interface for OAuth client providers.""" |
| 29 | + |
| 30 | + def get_redirect_url(self) -> str: |
| 31 | + """Get the URL where the user agent will be redirected after authorization.""" |
| 32 | + ... |
| 33 | + |
| 34 | + def get_client_metadata(self) -> OAuthClientMetadata: |
| 35 | + """Get the metadata for this OAuth client.""" |
| 36 | + ... |
| 37 | + |
| 38 | + def get_client_information(self) -> OAuthClientInformation | None: |
| 39 | + """Get the client information as registered with the server. |
| 40 | +
|
| 41 | + Returns None if the client is not registered. |
| 42 | + """ |
| 43 | + ... |
| 44 | + |
| 45 | + def save_client_information(self, client_information: OAuthClientInformationFull): |
| 46 | + """Optional Function to save the client information received from the server. |
| 47 | +
|
| 48 | + If implemented, this provider will support dynamic client registration. |
| 49 | + """ |
| 50 | + ... |
| 51 | + |
| 52 | + def get_token(self) -> OAuthToken | None: |
| 53 | + """Get any existing OAuth tokens for the current session.""" |
| 54 | + ... |
| 55 | + |
| 56 | + def save_token(self, token: OAuthToken): |
| 57 | + """Save the new OAuth token after successful authorization.""" |
| 58 | + ... |
| 59 | + |
| 60 | + def redirect_to_authorization(self, authorization_url: str): |
| 61 | + """Redirect the user agent to begin the authorization flow.""" |
| 62 | + ... |
| 63 | + |
| 64 | + def get_code_verifier(self) -> str: |
| 65 | + """Get the PKCE code verifier for the current session.""" |
| 66 | + ... |
| 67 | + |
| 68 | + def save_code_verifier(self, pkce_code_verifier: str): |
| 69 | + """Save the PKCE code verifier before redirecting to authorization.""" |
| 70 | + ... |
| 71 | + |
| 72 | + |
| 73 | +class OAuthAuthorization: |
| 74 | + """Main class for handling OAuth 2.0 authorization flows. |
| 75 | +
|
| 76 | + This class implements the OAuth 2.0 Authorization Code Flow with PKCE, |
| 77 | + supporting dynamic client registration and token refresh. |
| 78 | + """ |
| 79 | + |
| 80 | + def __init__(self, provider: OAuthClientProvider, server_url: str): |
| 81 | + """Initialize the OAuth authorization handler. |
| 82 | +
|
| 83 | + Args: |
| 84 | + provider: The OAuth client provider implementation |
| 85 | + server_url: The base URL of the OAuth server |
| 86 | + """ |
| 87 | + self.provider = provider |
| 88 | + self.server_url = server_url |
| 89 | + |
| 90 | + async def authorize( |
| 91 | + self, authorization_code: str | None = None |
| 92 | + ) -> OAuthToken | None: |
| 93 | + """Main authorization method that handles the complete OAuth flow. |
| 94 | +
|
| 95 | + This method will: |
| 96 | + 1. Check for existing valid tokens |
| 97 | + 2. Refresh tokens if expired |
| 98 | + 3. Exchange authorization codes for tokens |
| 99 | + 4. Start new authorization flows if needed |
| 100 | +
|
| 101 | + Args: |
| 102 | + authorization_code: Optional authorization code from the server |
| 103 | +
|
| 104 | + Returns: |
| 105 | + OAuthToken if authorization is successful, None if redirect is needed |
| 106 | + """ |
| 107 | + token = self.provider.get_token() |
| 108 | + if token is not None: |
| 109 | + # Returned token is still valid so return the token |
| 110 | + if token.expires_in is None or token.expires_in > 0: |
| 111 | + return token |
| 112 | + elif token.refresh_token is not None: |
| 113 | + # Refresh token |
| 114 | + refreshed_token = await self.refresh_authorization(token.refresh_token) |
| 115 | + self.provider.save_token(refreshed_token) |
| 116 | + return refreshed_token |
| 117 | + |
| 118 | + # If we have authorization code, exchange it for an access token |
| 119 | + if authorization_code: |
| 120 | + token = await self.exchange_authorization(authorization_code) |
| 121 | + self.provider.save_token(token) |
| 122 | + return token |
| 123 | + |
| 124 | + # If no authorization code, build authorization url and redirect |
| 125 | + authorization_url, code_verifier = await self.start_authorization() |
| 126 | + self.provider.save_code_verifier(code_verifier) |
| 127 | + self.provider.redirect_to_authorization(authorization_url) |
| 128 | + return None |
| 129 | + |
| 130 | + async def start_authorization(self) -> tuple[str, str]: |
| 131 | + """Start a new authorization flow by generating PKCE values and |
| 132 | + building the authorization URL. |
| 133 | +
|
| 134 | + Returns: |
| 135 | + Tuple containing: |
| 136 | + - The complete authorization URL to redirect the user to |
| 137 | + - The PKCE code verifier to be used later in token exchange |
| 138 | + """ |
| 139 | + metadata = await self.discover_oauth_metadata() |
| 140 | + client_info = await self.get_client_information() |
| 141 | + |
| 142 | + response_type = "code" |
| 143 | + code_challenge_method = "S256" |
| 144 | + |
| 145 | + if metadata is not None: |
| 146 | + if ( |
| 147 | + metadata.response_types_supported |
| 148 | + and response_type not in metadata.response_types_supported |
| 149 | + ): |
| 150 | + raise ValueError( |
| 151 | + f"Incompatible auth server: {response_type} response type " |
| 152 | + "is not supported" |
| 153 | + ) |
| 154 | + if metadata.code_challenge_methods_supported is None or ( |
| 155 | + code_challenge_method not in metadata.code_challenge_methods_supported |
| 156 | + ): |
| 157 | + raise ValueError( |
| 158 | + f"Incompatible auth server: {code_challenge_method} code " |
| 159 | + "challenge method is not supported" |
| 160 | + ) |
| 161 | + authorization_url = str(metadata.authorization_endpoint) |
| 162 | + else: |
| 163 | + authorization_url = urljoin(self.server_url, "/authorize") |
| 164 | + |
| 165 | + code_verifier, code_challenge = cast(tuple[str, str], generate_pkce_pair()) |
| 166 | + params: dict[str, str] = { |
| 167 | + "response_type": response_type, |
| 168 | + "client_id": client_info.client_id, |
| 169 | + "redirect_uri": self.provider.get_redirect_url(), |
| 170 | + "code_challenge": code_challenge, |
| 171 | + "code_challenge_method": code_challenge_method, |
| 172 | + } |
| 173 | + query_string = urlencode(params) |
| 174 | + return (f"{authorization_url}?{query_string}", code_verifier) |
| 175 | + |
| 176 | + async def discover_oauth_metadata(self) -> OAuthMetadata | None: |
| 177 | + """Discover OAuth server metadata using the well-known endpoint. |
| 178 | +
|
| 179 | + Implements RFC 8414 OAuth 2.0 Authorization Server Metadata. |
| 180 | +
|
| 181 | + Returns: |
| 182 | + OAuthMetadata if discovery is successful, None if endpoint returns 404 |
| 183 | + """ |
| 184 | + url = urljoin(self.server_url, "/.well-known/openid-configuration") |
| 185 | + async with httpx.AsyncClient() as client: |
| 186 | + resp = await client.get( |
| 187 | + url, headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION} |
| 188 | + ) |
| 189 | + if resp.status_code == 404: |
| 190 | + return None |
| 191 | + elif resp.status_code != 200: |
| 192 | + raise ValueError( |
| 193 | + f"Failed to discover OAuth metadata: HTTP {resp.status_code} " |
| 194 | + f"{resp.text}" |
| 195 | + ) |
| 196 | + return OAuthMetadata(**resp.json()) |
| 197 | + |
| 198 | + async def register_client( |
| 199 | + self, |
| 200 | + metadata: OAuthMetadata | None, |
| 201 | + client_metadata: OAuthClientMetadata, |
| 202 | + ) -> OAuthClientInformationFull: |
| 203 | + """Register the client with the OAuth server. |
| 204 | +
|
| 205 | + Implements OAuth 2.0 Dynamic Client Registration (RFC 7591). |
| 206 | +
|
| 207 | + Args: |
| 208 | + metadata: Optional OAuth server metadata |
| 209 | + client_metadata: The client's metadata to register |
| 210 | +
|
| 211 | + Returns: |
| 212 | + Full client information from the server |
| 213 | + """ |
| 214 | + url = ( |
| 215 | + str(metadata.registration_endpoint) |
| 216 | + if metadata |
| 217 | + else urljoin(self.server_url, "/register") |
| 218 | + ) |
| 219 | + |
| 220 | + async with httpx.AsyncClient() as client: |
| 221 | + resp = await client.post( |
| 222 | + url, |
| 223 | + headers={"Content-Type": "application/x-www-form-urlencoded"}, |
| 224 | + json=client_metadata.model_dump(), |
| 225 | + ) |
| 226 | + if resp.status_code != 200: |
| 227 | + raise ValueError( |
| 228 | + f"Dynamic client registration failed: HTTP {resp.status_code} " |
| 229 | + f"{resp.text}" |
| 230 | + ) |
| 231 | + return OAuthClientInformationFull(**resp.json()) |
| 232 | + |
| 233 | + async def get_client_information(self) -> OAuthClientInformation: |
| 234 | + """Tries to get the client information from the provider. |
| 235 | +
|
| 236 | + If unable to retrieve the client information, this attempts |
| 237 | + dynamic registration, saves the client with the provider |
| 238 | + and returns the information. |
| 239 | +
|
| 240 | + Returns: |
| 241 | + Client information |
| 242 | + """ |
| 243 | + client_info = self.provider.get_client_information() |
| 244 | + |
| 245 | + if client_info is None: |
| 246 | + if not hasattr(self.provider, "save_client_information"): |
| 247 | + raise ValueError( |
| 248 | + "Save Client Information is not supported by this provider, " |
| 249 | + "therefore we cannot dynamically register the OAuth Client" |
| 250 | + ) |
| 251 | + |
| 252 | + client_info = await self.register_client( |
| 253 | + metadata=None, client_metadata=self.provider.get_client_metadata() |
| 254 | + ) |
| 255 | + self.provider.save_client_information(client_info) |
| 256 | + return OAuthClientInformation(**client_info.model_dump()) |
| 257 | + |
| 258 | + return client_info |
| 259 | + |
| 260 | + async def exchange_authorization(self, authorization_code: str) -> OAuthToken: |
| 261 | + """Exchange an authorization code for an access token. |
| 262 | +
|
| 263 | + Args: |
| 264 | + authorization_code: The authorization code from the server |
| 265 | +
|
| 266 | + Returns: |
| 267 | + New OAuth token |
| 268 | + """ |
| 269 | + code_verifier = self.provider.get_code_verifier() |
| 270 | + redirect_url = self.provider.get_redirect_url() |
| 271 | + |
| 272 | + return await self._fetch_token( |
| 273 | + grant_type="authorization_code", |
| 274 | + extra_params={ |
| 275 | + "code": authorization_code, |
| 276 | + "code_verifier": code_verifier, |
| 277 | + "redirect_uri": redirect_url, |
| 278 | + }, |
| 279 | + ) |
| 280 | + |
| 281 | + async def refresh_authorization(self, refresh_token: str) -> OAuthToken: |
| 282 | + """Exchange a refresh token for a new access token. |
| 283 | +
|
| 284 | + Args: |
| 285 | + refresh_token: The refresh token to use |
| 286 | +
|
| 287 | + Returns: |
| 288 | + New OAuth token |
| 289 | + """ |
| 290 | + return await self._fetch_token( |
| 291 | + grant_type="refresh_token", |
| 292 | + extra_params={ |
| 293 | + "refresh_token": refresh_token, |
| 294 | + }, |
| 295 | + ) |
| 296 | + |
| 297 | + async def _fetch_token( |
| 298 | + self, |
| 299 | + grant_type: str, |
| 300 | + extra_params: dict[str, str], |
| 301 | + ) -> OAuthToken: |
| 302 | + """Internal method to fetch tokens from the server. |
| 303 | +
|
| 304 | + Handles both authorization code exchange and token refresh. |
| 305 | +
|
| 306 | + Args: |
| 307 | + grant_type: The OAuth grant type to use |
| 308 | + extra_params: Additional parameters for the token request |
| 309 | +
|
| 310 | + Returns: |
| 311 | + New OAuth token |
| 312 | + """ |
| 313 | + metadata = await self.discover_oauth_metadata() |
| 314 | + if metadata is not None: |
| 315 | + token_url = str(metadata.token_endpoint) |
| 316 | + if ( |
| 317 | + metadata.grant_types_supported |
| 318 | + and grant_type not in metadata.grant_types_supported |
| 319 | + ): |
| 320 | + raise ValueError( |
| 321 | + f"Incompatible auth server: {grant_type} not supported" |
| 322 | + ) |
| 323 | + else: |
| 324 | + token_url = urljoin(self.server_url, "/token") |
| 325 | + |
| 326 | + client_info = await self.get_client_information() |
| 327 | + params: dict[str, str] = { |
| 328 | + "grant_type": grant_type, |
| 329 | + "client_id": client_info.client_id, |
| 330 | + **extra_params, |
| 331 | + } |
| 332 | + if client_info.client_secret: |
| 333 | + params["client_secret"] = client_info.client_secret |
| 334 | + |
| 335 | + async with httpx.AsyncClient() as client: |
| 336 | + resp = await client.post( |
| 337 | + token_url, |
| 338 | + headers={"Content-Type": "application/x-www-form-urlencoded"}, |
| 339 | + json=params, |
| 340 | + ) |
| 341 | + if resp.status_code != 200: |
| 342 | + raise ValueError( |
| 343 | + f"Token request failed for {grant_type}: " |
| 344 | + f"HTTP {resp.status_code} {resp.text}" |
| 345 | + ) |
| 346 | + return OAuthToken(**resp.json()) |
0 commit comments