Skip to content

Commit 60da682

Browse files
committed
Improved error handling, generic types for provider
1 parent 56f694e commit 60da682

File tree

10 files changed

+490
-116
lines changed

10 files changed

+490
-116
lines changed

src/mcp/server/auth/handlers/authorize.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
)
1515
from mcp.server.auth.json_response import PydanticJSONResponse
1616
from mcp.server.auth.provider import (
17+
AuthorizationErrorCode,
1718
AuthorizationParams,
19+
AuthorizeError,
1820
OAuthServerProvider,
1921
construct_redirect_uri,
2022
)
@@ -49,20 +51,9 @@ class AuthorizationRequest(BaseModel):
4951
)
5052

5153

52-
AuthorizationErrorCode = Literal[
53-
"invalid_request",
54-
"unauthorized_client",
55-
"access_denied",
56-
"unsupported_response_type",
57-
"invalid_scope",
58-
"server_error",
59-
"temporarily_unavailable",
60-
]
61-
62-
6354
class AuthorizationErrorResponse(BaseModel):
6455
error: AuthorizationErrorCode
65-
error_description: str
56+
error_description: str | None
6657
error_uri: AnyUrl | None = None
6758
# must be set if provided in the request
6859
state: str | None = None
@@ -98,16 +89,14 @@ async def handle(self, request: Request) -> Response:
9889

9990
async def error_response(
10091
error: AuthorizationErrorCode,
101-
error_description: str,
92+
error_description: str | None,
10293
attempt_load_client: bool = True,
10394
):
10495
nonlocal client, redirect_uri, state
10596
if client is None and attempt_load_client:
10697
# make last-ditch attempt to load the client
10798
client_id = best_effort_extract_string("client_id", params)
108-
client = client_id and await self.provider.clients_store.get_client(
109-
client_id
110-
)
99+
client = client_id and await self.provider.get_client(client_id)
111100
if redirect_uri is None and client:
112101
# make last-ditch effort to load the redirect uri
113102
if params is not None and "redirect_uri" not in params:
@@ -171,7 +160,7 @@ async def error_response(
171160
)
172161

173162
# Get client information
174-
client = await self.provider.clients_store.get_client(
163+
client = await self.provider.get_client(
175164
auth_request.client_id,
176165
)
177166
if not client:
@@ -210,15 +199,22 @@ async def error_response(
210199
redirect_uri=redirect_uri,
211200
)
212201

213-
# Let the provider pick the next URI to redirect to
214-
return RedirectResponse(
215-
url=await self.provider.authorize(
216-
client,
217-
auth_params,
218-
),
219-
status_code=302,
220-
headers={"Cache-Control": "no-store"},
221-
)
202+
try:
203+
# Let the provider pick the next URI to redirect to
204+
return RedirectResponse(
205+
url=await self.provider.authorize(
206+
client,
207+
auth_params,
208+
),
209+
status_code=302,
210+
headers={"Cache-Control": "no-store"},
211+
)
212+
except AuthorizeError as e:
213+
# Handle authorization errors as defined in RFC 6749 Section 4.1.2.1
214+
return await error_response(
215+
error=e.error,
216+
error_description=e.error_description,
217+
)
222218

223219
except Exception as validation_error:
224220
# Catch-all for unexpected errors

src/mcp/server/auth/handlers/register.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import secrets
22
import time
33
from dataclasses import dataclass
4-
from typing import Literal
54
from uuid import uuid4
65

76
from pydantic import BaseModel, RootModel, ValidationError
@@ -10,7 +9,11 @@
109

1110
from mcp.server.auth.errors import stringify_pydantic_error
1211
from mcp.server.auth.json_response import PydanticJSONResponse
13-
from mcp.server.auth.provider import OAuthRegisteredClientsStore
12+
from mcp.server.auth.provider import (
13+
OAuthServerProvider,
14+
RegistrationError,
15+
RegistrationErrorCode,
16+
)
1417
from mcp.server.auth.settings import ClientRegistrationOptions
1518
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
1619

@@ -22,18 +25,13 @@ class RegistrationRequest(RootModel):
2225

2326

2427
class RegistrationErrorResponse(BaseModel):
25-
error: Literal[
26-
"invalid_redirect_uri",
27-
"invalid_client_metadata",
28-
"invalid_software_statement",
29-
"unapproved_software_statement",
30-
]
31-
error_description: str
28+
error: RegistrationErrorCode
29+
error_description: str | None
3230

3331

3432
@dataclass
3533
class RegistrationHandler:
36-
clients_store: OAuthRegisteredClientsStore
34+
provider: OAuthServerProvider
3735
options: ClientRegistrationOptions
3836

3937
async def handle(self, request: Request) -> Response:
@@ -116,8 +114,17 @@ async def handle(self, request: Request) -> Response:
116114
software_id=client_metadata.software_id,
117115
software_version=client_metadata.software_version,
118116
)
119-
# Register client
120-
await self.clients_store.register_client(client_info)
117+
try:
118+
# Register client
119+
await self.provider.register_client(client_info)
121120

122-
# Return client information
123-
return PydanticJSONResponse(content=client_info, status_code=201)
121+
# Return client information
122+
return PydanticJSONResponse(content=client_info, status_code=201)
123+
except RegistrationError as e:
124+
# Handle registration errors as defined in RFC 7591 Section 3.2.2
125+
return PydanticJSONResponse(
126+
content=RegistrationErrorResponse(
127+
error=e.error, error_description=e.error_description
128+
),
129+
status_code=400,
130+
)

src/mcp/server/auth/handlers/revoke.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ async def handle(self, request: Request) -> Response:
8181
if token is not None:
8282
break
8383

84+
# if token is not found, just return HTTP 200 per the RFC
8485
if token and token.client_id == client.client_id:
85-
# Revoke token
86+
# Revoke token; provider is not meant to be able to do validation
87+
# at this point that would result in an error
8688
await self.provider.revoke_token(token)
8789

8890
# Return successful empty response

src/mcp/server/auth/handlers/token.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
AuthenticationError,
1717
ClientAuthenticator,
1818
)
19-
from mcp.server.auth.provider import OAuthServerProvider
19+
from mcp.server.auth.provider import OAuthServerProvider, TokenError, TokenErrorCode
2020
from mcp.shared.auth import OAuthToken
2121

2222

@@ -56,14 +56,7 @@ class TokenErrorResponse(BaseModel):
5656
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
5757
"""
5858

59-
error: Literal[
60-
"invalid_request",
61-
"invalid_client",
62-
"invalid_grant",
63-
"unauthorized_client",
64-
"unsupported_grant_type",
65-
"invalid_scope",
66-
]
59+
error: TokenErrorCode
6760
error_description: str | None = None
6861
error_uri: AnyHttpUrl | None = None
6962

@@ -184,10 +177,18 @@ async def handle(self, request: Request):
184177
)
185178
)
186179

187-
# Exchange authorization code for tokens
188-
tokens = await self.provider.exchange_authorization_code(
189-
client_info, auth_code
190-
)
180+
try:
181+
# Exchange authorization code for tokens
182+
tokens = await self.provider.exchange_authorization_code(
183+
client_info, auth_code
184+
)
185+
except TokenError as e:
186+
return self.response(
187+
TokenErrorResponse(
188+
error=e.error,
189+
error_description=e.error_description,
190+
)
191+
)
191192

192193
case RefreshTokenRequest():
193194
refresh_token = await self.provider.load_refresh_token(
@@ -233,9 +234,17 @@ async def handle(self, request: Request):
233234
)
234235
)
235236

236-
# Exchange refresh token for new tokens
237-
tokens = await self.provider.exchange_refresh_token(
238-
client_info, refresh_token, scopes
239-
)
237+
try:
238+
# Exchange refresh token for new tokens
239+
tokens = await self.provider.exchange_refresh_token(
240+
client_info, refresh_token, scopes
241+
)
242+
except TokenError as e:
243+
return self.response(
244+
TokenErrorResponse(
245+
error=e.error,
246+
error_description=e.error_description,
247+
)
248+
)
240249

241250
return self.response(TokenSuccessResponse(root=tokens))

src/mcp/server/auth/middleware/client_auth.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import time
22

3-
from mcp.server.auth.provider import OAuthRegisteredClientsStore
3+
from mcp.server.auth.provider import OAuthServerProvider
44
from mcp.shared.auth import OAuthClientInformationFull
55

66

@@ -20,20 +20,20 @@ class ClientAuthenticator:
2020
logic is skipped.
2121
"""
2222

23-
def __init__(self, clients_store: OAuthRegisteredClientsStore):
23+
def __init__(self, provider: OAuthServerProvider):
2424
"""
2525
Initialize the dependency.
2626
2727
Args:
28-
clients_store: Store to look up client information
28+
provider: Provider to look up client information
2929
"""
30-
self.clients_store = clients_store
30+
self.provider = provider
3131

3232
async def authenticate(
3333
self, client_id: str, client_secret: str | None
3434
) -> OAuthClientInformationFull:
3535
# Look up client information
36-
client = await self.clients_store.get_client(client_id)
36+
client = await self.provider.get_client(client_id)
3737
if not client:
3838
raise AuthenticationError("Invalid client_id")
3939

0 commit comments

Comments
 (0)