Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 112 additions & 1 deletion backend/app/gateway/csrf_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
State-changing operations require CSRF protection.
"""

import os
import secrets
from collections.abc import Callable
from urllib.parse import urlsplit

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
Expand All @@ -19,7 +21,7 @@

def is_secure_request(request: Request) -> bool:
"""Detect whether the original client request was made over HTTPS."""
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
return _request_scheme(request) == "https"


def generate_csrf_token() -> str:
Expand Down Expand Up @@ -61,6 +63,109 @@ def is_auth_endpoint(request: Request) -> bool:
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS


def _host_with_optional_port(hostname: str, port: int | None, scheme: str) -> str:
"""Return normalized host[:port], omitting default ports."""
host = hostname.lower()
if ":" in host and not host.startswith("["):
host = f"[{host}]"

if port is None or (scheme == "http" and port == 80) or (scheme == "https" and port == 443):
return host
return f"{host}:{port}"


def _normalize_origin(origin: str) -> str | None:
"""Return a normalized scheme://host[:port] origin, or None for invalid input."""
try:
parsed = urlsplit(origin.strip())
port = parsed.port
except ValueError:
return None

scheme = parsed.scheme.lower()
if scheme not in {"http", "https"} or not parsed.hostname:
return None

# Browser Origin is only scheme/host/port. Reject URL-shaped or credentialed values.
if parsed.username or parsed.password or parsed.path or parsed.query or parsed.fragment:
return None

return f"{scheme}://{_host_with_optional_port(parsed.hostname, port, scheme)}"


def _configured_cors_origins() -> set[str]:
"""Return explicit configured browser origins that may call auth routes."""
origins = set()
for raw_origin in os.environ.get("GATEWAY_CORS_ORIGINS", "").split(","):
origin = raw_origin.strip()
if not origin or origin == "*":
continue
normalized = _normalize_origin(origin)
if normalized:
origins.add(normalized)
return origins


def _first_header_value(value: str | None) -> str | None:
"""Return the first value from a comma-separated proxy header."""
if not value:
return None
first = value.split(",", 1)[0].strip()
return first or None


def _forwarded_param(request: Request, name: str) -> str | None:
"""Extract a parameter from the first RFC 7239 Forwarded header entry."""
forwarded = _first_header_value(request.headers.get("forwarded"))
if not forwarded:
return None

for part in forwarded.split(";"):
key, sep, value = part.strip().partition("=")
if sep and key.lower() == name:
return value.strip().strip('"') or None
return None


def _request_scheme(request: Request) -> str:
"""Resolve the original request scheme from trusted proxy headers."""
scheme = _forwarded_param(request, "proto") or _first_header_value(request.headers.get("x-forwarded-proto")) or request.url.scheme
return scheme.lower()


def _request_origin(request: Request) -> str | None:
"""Build the origin for the URL the browser is targeting."""
scheme = _request_scheme(request)
host = _forwarded_param(request, "host") or _first_header_value(request.headers.get("x-forwarded-host")) or request.headers.get("host") or request.url.netloc

forwarded_port = _first_header_value(request.headers.get("x-forwarded-port"))
if forwarded_port and ":" not in host.rsplit("]", 1)[-1]:
host = f"{host}:{forwarded_port}"

Comment thread
Hinotoi-agent marked this conversation as resolved.
return _normalize_origin(f"{scheme}://{host}")


def is_allowed_auth_origin(request: Request) -> bool:
"""Allow auth POSTs only from the same origin or explicit configured origins.

Login/register/initialize are exempt from the double-submit token because
first-time browser clients do not have a CSRF token yet. They still create
a session cookie, so browser requests with a hostile Origin header must be
rejected to prevent login CSRF / session fixation. Requests without Origin
are allowed for non-browser clients such as curl and mobile integrations.
"""
origin = request.headers.get("origin")
if not origin:
return True

normalized_origin = _normalize_origin(origin)
if normalized_origin is None:
return False

request_origin = _request_origin(request)
return normalized_origin in _configured_cors_origins() or (request_origin is not None and normalized_origin == request_origin)


class CSRFMiddleware(BaseHTTPMiddleware):
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""

Expand All @@ -70,6 +175,12 @@ def __init__(self, app: ASGIApp) -> None:
async def dispatch(self, request: Request, call_next: Callable) -> Response:
_is_auth = is_auth_endpoint(request)

if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request):
return JSONResponse(
status_code=403,
content={"detail": "Cross-site auth request denied."},
)

if should_check_csrf(request) and not _is_auth:
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
header_token = request.headers.get(CSRF_HEADER_NAME)
Expand Down
219 changes: 219 additions & 0 deletions backend/tests/test_csrf_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
"""Tests for CSRF middleware."""

from fastapi import FastAPI
from starlette.testclient import TestClient

from app.gateway.csrf_middleware import CSRFMiddleware


def _make_app() -> FastAPI:
app = FastAPI()
app.add_middleware(CSRFMiddleware)

@app.post("/api/v1/auth/login/local")
async def login_local():
return {"ok": True}

@app.post("/api/v1/auth/register")
async def register():
return {"ok": True}

@app.post("/api/threads/abc/runs/stream")
async def protected_mutation():
return {"ok": True}

return app


def test_auth_post_rejects_cross_origin_browser_request():
"""CSRF-exempt auth routes must not accept hostile browser origins.

Login/register endpoints intentionally skip the double-submit token because
first-time callers do not have a token yet. They still set an auth session,
so a hostile cross-site form POST must be rejected to avoid login CSRF /
session fixation.
"""
client = TestClient(_make_app(), base_url="https://deerflow.example")

response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://evil.example"},
)

assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."


def test_auth_post_allows_same_origin_browser_request():
client = TestClient(_make_app(), base_url="https://deerflow.example")

response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example"},
)

assert response.status_code == 200
assert response.cookies.get("csrf_token")


def test_auth_post_rejects_malformed_origin_with_path():
client = TestClient(_make_app(), base_url="https://deerflow.example")

response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example/path"},
)

assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
assert response.cookies.get("csrf_token") is None


def test_auth_post_rejects_malformed_origin_with_invalid_port():
client = TestClient(_make_app(), base_url="https://deerflow.example")

response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example:bad"},
)

assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
assert response.cookies.get("csrf_token") is None


def test_auth_post_allows_same_origin_default_port_equivalence():
client = TestClient(_make_app(), base_url="https://deerflow.example")

response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example:443"},
)

assert response.status_code == 200
assert response.cookies.get("csrf_token")


def test_auth_post_allows_forwarded_same_origin():
client = TestClient(_make_app(), base_url="http://internal:8000")

response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "https://deerflow.example",
"X-Forwarded-Proto": "https",
"X-Forwarded-Host": "deerflow.example, internal:8000",
},
)

assert response.status_code == 200
assert response.cookies.get("csrf_token")


def test_auth_post_allows_rfc_forwarded_same_origin():
client = TestClient(_make_app(), base_url="http://internal:8000")

response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "https://deerflow.example",
"Forwarded": "proto=https;host=deerflow.example",
},
)

assert response.status_code == 200
assert response.cookies.get("csrf_token")
Comment thread
Hinotoi-agent marked this conversation as resolved.
assert "secure" in response.headers["set-cookie"].lower()


def test_auth_post_allows_explicit_configured_origin(monkeypatch):
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "https://app.example")
client = TestClient(_make_app(), base_url="https://api.example")

response = client.post(
"/api/v1/auth/register",
headers={"Origin": "https://app.example"},
)

assert response.status_code == 200
assert response.cookies.get("csrf_token")


def test_auth_post_does_not_treat_wildcard_cors_as_allowed_origin(monkeypatch):
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "*")
client = TestClient(_make_app(), base_url="https://api.example")

response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://evil.example"},
)

assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."


def test_auth_post_sets_strict_samesite_csrf_cookie():
client = TestClient(_make_app(), base_url="https://deerflow.example")

response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example"},
)

assert response.status_code == 200
set_cookie = response.headers["set-cookie"].lower()
assert "csrf_token=" in set_cookie
assert "samesite=strict" in set_cookie
assert "secure" in set_cookie


def test_auth_post_without_origin_still_allows_non_browser_clients():
client = TestClient(_make_app(), base_url="https://deerflow.example")

response = client.post("/api/v1/auth/login/local")

assert response.status_code == 200
assert response.cookies.get("csrf_token")


def test_non_auth_mutation_still_requires_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")

response = client.post(
"/api/threads/abc/runs/stream",
headers={"Origin": "https://deerflow.example"},
)

assert response.status_code == 403
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."


def test_non_auth_mutation_allows_valid_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
client.cookies.set("csrf_token", "known-token")

response = client.post(
"/api/threads/abc/runs/stream",
headers={
"Origin": "https://deerflow.example",
"X-CSRF-Token": "known-token",
},
)

assert response.status_code == 200


def test_non_auth_mutation_rejects_mismatched_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
client.cookies.set("csrf_token", "cookie-token")

response = client.post(
"/api/threads/abc/runs/stream",
headers={
"Origin": "https://deerflow.example",
"X-CSRF-Token": "header-token",
},
)

assert response.status_code == 403
assert response.json()["detail"] == "CSRF token mismatch."
Loading