|
2 | 2 | from dataclasses import dataclass |
3 | 3 | from typing import Any, Literal |
4 | 4 |
|
| 5 | +import httpx |
5 | 6 | from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError |
6 | 7 | from starlette.datastructures import FormData, QueryParams |
7 | 8 | from starlette.requests import Request |
|
16 | 17 | OAuthAuthorizationServerProvider, |
17 | 18 | construct_redirect_uri, |
18 | 19 | ) |
19 | | -from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError |
| 20 | +from mcp.shared.auth import ( |
| 21 | + InvalidRedirectUriError, |
| 22 | + InvalidScopeError, |
| 23 | + OAuthClientInformationFull, |
| 24 | +) |
20 | 25 |
|
21 | 26 | logger = logging.getLogger(__name__) |
22 | 27 |
|
@@ -166,6 +171,29 @@ async def error_response( |
166 | 171 | client = await self.provider.get_client( |
167 | 172 | auth_request.client_id, |
168 | 173 | ) |
| 174 | + if not client: |
| 175 | + # Check if `client_id` is a valid URL for Metadata Document |
| 176 | + if auth_request.client_id.startswith("https://"): |
| 177 | + try: |
| 178 | + async with httpx.AsyncClient() as http_client: |
| 179 | + response = await http_client.get(auth_request.client_id) |
| 180 | + response.raise_for_status() |
| 181 | + metadata = response.json() |
| 182 | + |
| 183 | + if metadata.get("client_id") != auth_request.client_id: |
| 184 | + return await error_response( |
| 185 | + error="invalid_request", |
| 186 | + error_description=f"Client ID '{auth_request.client_id}' \ |
| 187 | + not found in metadata document", |
| 188 | + ) |
| 189 | + |
| 190 | + client = OAuthClientInformationFull(**metadata) |
| 191 | + |
| 192 | + except Exception as e: |
| 193 | + return await error_response( |
| 194 | + error="invalid_request", |
| 195 | + error_description=f"Failed to fetch client metadata from {auth_request.client_id}: {e}", |
| 196 | + ) |
169 | 197 | if not client: |
170 | 198 | # For client_id validation errors, return direct error (no redirect) |
171 | 199 | return await error_response( |
|
0 commit comments