Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 40 additions & 12 deletions src/app/routes/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import secrets
import hashlib
import base64
from typing import Dict
from typing import Dict, Optional
import logging

from flask import Blueprint, jsonify, request, url_for, redirect, render_template
Expand All @@ -31,6 +31,27 @@
TOKEN_RESOURCES: Dict[str, str] = {}


def _resolve_client(client_id: Optional[str]) -> Optional[OAuthClient]:
"""Resolve an OAuth client by id or fall back to the single registered one."""

if client_id:
return OAuthClient.query.filter_by(client_id=client_id).first()

# When no client_id is provided, fall back to the sole registered client.
clients = OAuthClient.query.limit(2).all()
if len(clients) == 1:
logger.info(
"OAuth: defaulted to sole client", extra={"client_id": clients[0].client_id}
)
return clients[0]

logger.info(
"OAuth: unable to resolve client without id",
extra={"registered_clients": len(clients)},
)
return None


def canonical_mcp_resource() -> str:
"""Return the canonical MCP server URI (no trailing slash).

Expand Down Expand Up @@ -229,17 +250,19 @@ def issue_token():
if not data:
data = request.get_json(silent=True) or {}

client_id = data.get('client_id')
client_id_param = data.get('client_id')
client_secret = data.get('client_secret')
client = OAuthClient.query.filter_by(client_id=client_id).first()
client = _resolve_client(client_id_param)
if not client or (client_secret and client.client_secret != client_secret):
logger.info("OAuth: invalid_client at token endpoint", extra={
"client_id": client_id,
"client_id": client_id_param,
"has_secret": bool(client_secret),
"remote_addr": request.remote_addr,
})
return _json_error(401, 'invalid_client', 'Client authentication failed')

client_id = client.client_id

grant_type = data.get('grant_type', 'client_credentials')
resource = data.get('resource')
if not resource:
Expand All @@ -260,7 +283,7 @@ def issue_token():
if grant_type == 'authorization_code':
code = data.get('code')
code_verifier = data.get('code_verifier')
redirect_uri = data.get('redirect_uri')
redirect_uri = data.get('redirect_uri') or client.redirect_uri
if not code or not code_verifier or not redirect_uri:
logger.info("OAuth: authorization_code missing fields", extra={
"client_id": client_id,
Expand Down Expand Up @@ -339,20 +362,25 @@ def issue_token():
@login_required
def authorize():
"""Display a consent page and issue an authorization code."""
client_id = request.args.get('client_id')
client_id_param = request.args.get('client_id')
client = _resolve_client(client_id_param)
redirect_uri = request.args.get('redirect_uri')
code_challenge = request.values.get('code_challenge')
state = request.values.get('state')
scope = request.values.get('scope')
resource = request.values.get('resource')
client = OAuthClient.query.filter_by(client_id=client_id).first()
if client and not redirect_uri:
redirect_uri = client.redirect_uri

if not client or (client.redirect_uri and client.redirect_uri != redirect_uri):
logger.info("OAuth: authorize invalid_client or redirect mismatch", extra={
"client_id": client_id,
"client_id": client_id_param,
"redirect_uri": redirect_uri,
})
return jsonify({'error': 'invalid_client'}), 400

client_id = client.client_id
code_challenge = request.values.get('code_challenge')
state = request.values.get('state')
scope = request.values.get('scope')
resource = request.values.get('resource')

if request.method == 'POST' and request.form.get('confirm') == 'yes':
code = issue_auth_code(
client_id=client_id,
Expand Down
10 changes: 10 additions & 0 deletions src/tests/test_asgi_mcp_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@


def test_parent_asgi_app_uses_mcp_lifespan(monkeypatch):
monkeypatch.setenv("SECRET_KEY", "testing")
monkeypatch.setenv("RECAPTCHA_PUBLIC_KEY", "testing")
monkeypatch.setenv("RECAPTCHA_PRIVATE_KEY", "testing")
monkeypatch.setenv("CELERY_BROKER_URL", "memory://")
monkeypatch.setenv("CELERY_RESULT_BACKEND", "cache+memory://")
import src.config.env as env
monkeypatch.setattr(env, "SECRET_KEY", "testing", raising=False)
monkeypatch.setattr(env, "RECAPTCHA_PUBLIC_KEY", "testing", raising=False)
monkeypatch.setattr(env, "RECAPTCHA_PRIVATE_KEY", "testing", raising=False)

# Create a dummy ASGI app with a recognizable lifespan callable
class DummyASGI:
async def __call__(self, scope, receive, send):
Expand Down
20 changes: 14 additions & 6 deletions src/tests/test_asgi_sse_cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,19 @@
from starlette.routing import Route, Mount
from starlette.middleware.cors import CORSMiddleware
from starlette.testclient import TestClient
import pytest


@pytest.mark.asyncio
async def test_tasks_events_sse_includes_cors_header_direct():
import asyncio


def test_tasks_events_sse_includes_cors_header_direct(monkeypatch):
monkeypatch.setenv("SECRET_KEY", "testing")
monkeypatch.setenv("RECAPTCHA_PUBLIC_KEY", "testing")
monkeypatch.setenv("RECAPTCHA_PRIVATE_KEY", "testing")
monkeypatch.setenv("CELERY_BROKER_URL", "memory://")
monkeypatch.setenv("CELERY_RESULT_BACKEND", "cache+memory://")
import src.config.env as env
monkeypatch.setattr(env, "SECRET_KEY", "testing", raising=False)
monkeypatch.setattr(env, "RECAPTCHA_PUBLIC_KEY", "testing", raising=False)
monkeypatch.setattr(env, "RECAPTCHA_PRIVATE_KEY", "testing", raising=False)
# Call the SSE handler directly to validate response headers
from src.asgi import sse_task_events
from starlette.requests import Request
Expand All @@ -18,7 +26,7 @@ async def test_tasks_events_sse_includes_cors_header_direct():
"headers": [],
}
req = Request(scope)
resp = await sse_task_events(req)
resp = asyncio.run(sse_task_events(req))
assert resp.headers.get("Access-Control-Allow-Origin") == "*"


Expand Down
5 changes: 1 addition & 4 deletions tests/test_oauth_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def test_oauth_registration_and_access(app, client):
resp = client.post(
'/token',
data={
'client_id': client_id,
'grant_type': 'client_credentials',
'ttl': 3600,
},
Expand Down Expand Up @@ -102,7 +101,6 @@ def test_authorization_code_flow(app, client):
'/authorize',
query_string={
'response_type': 'code',
'client_id': 'cid',
'redirect_uri': 'https://example.com/cb',
'code_challenge': code_challenge,
'state': 'abc',
Expand All @@ -114,7 +112,7 @@ def test_authorization_code_flow(app, client):
token = soup.find('input', {'name': 'csrf_token'})['value']

resp = client.post(
'/authorize?client_id=cid&redirect_uri=https://example.com/cb&state=abc',
'/authorize?redirect_uri=https://example.com/cb&state=abc',
data={'confirm': 'yes', 'code_challenge': code_challenge, 'csrf_token': token},
follow_redirects=False,
)
Expand All @@ -131,7 +129,6 @@ def test_authorization_code_flow(app, client):
'grant_type': 'authorization_code',
'code': code,
'code_verifier': code_verifier,
'client_id': 'cid',
'client_secret': 'secret',
'redirect_uri': 'https://example.com/cb',
},
Expand Down