Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable CSRF protection globally by default #9334

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions h/accounts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
USERNAME_PATTERN,
)
from h.schemas import validators
from h.schemas.base import CSRFSchema
from h.schemas.forms.accounts.util import PASSWORD_MIN_LENGTH
from h.util.user import format_userid

Expand Down Expand Up @@ -153,7 +152,7 @@ def _privacy_accepted_message():
return privacy_msg


class RegisterSchema(CSRFSchema):
class RegisterSchema(colander.Schema):
username = colander.SchemaNode(
colander.String(),
validator=colander.All(
Expand Down Expand Up @@ -197,13 +196,12 @@ class RegisterSchema(CSRFSchema):
)


class EmailChangeSchema(CSRFSchema):
class EmailChangeSchema(colander.Schema):
email = email_node(title=_("Email address"))
# No validators: all validation is done on the email field
password = password_node(title=_("Confirm password"), hide_until_form_active=True)

def validator(self, node, value):
super().validator(node, value)
Copy link
Contributor Author

@seanh seanh Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're no longer inheriting from a CSRFSchema superclass with a validate() method, so several of these super() calls had to be removed. It'll no longer be possible to accidentally disable CSRF protection by forgetting to call super() here.

exc = colander.Invalid(node)
request = node.bindings["request"]
svc = request.find_service(name="user_password")
Expand All @@ -216,7 +214,7 @@ def validator(self, node, value):
raise exc


class PasswordChangeSchema(CSRFSchema):
class PasswordChangeSchema(colander.Schema):
password = password_node(title=_("Current password"), inactive_label=_("Password"))
new_password = new_password_node(
title=_("New password"), hide_until_form_active=True
Expand All @@ -231,7 +229,6 @@ class PasswordChangeSchema(CSRFSchema):
)

def validator(self, node, value): # pragma: no cover
super().validator(node, value)
exc = colander.Invalid(node)
request = node.bindings["request"]
svc = request.find_service(name="user_password")
Expand All @@ -247,12 +244,10 @@ def validator(self, node, value): # pragma: no cover
raise exc


class DeleteAccountSchema(CSRFSchema):
class DeleteAccountSchema(colander.Schema):
password = password_node(title=_("Confirm password"))

def validator(self, node, value):
super().validator(node, value)

request = node.bindings["request"]
svc = request.find_service(name="user_password")

Expand All @@ -262,7 +257,7 @@ def validator(self, node, value):
raise exc


class NotificationsSchema(CSRFSchema):
class NotificationsSchema(colander.Schema):
types = (("reply", _("Email me when someone replies to one of my annotations.")),)

notifications = colander.SchemaNode(
Expand Down
2 changes: 2 additions & 0 deletions h/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def create_app(_global_config, **settings): # pragma: no cover


def includeme(config): # pragma: no cover
config.set_default_csrf_options(require_csrf=True)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what we're doing instead of always subclassing a CSRFSchema with a validate(): we just ask Pyramid to do automatic CSRF verification on all requests with unsafe HTTP methods, this is opt-out rather than opt-in and our views and schemas don't need to be concerned with it at all.


config.scan("h.subscribers")

config.add_tween("h.tweens.conditional_http_tween_factory", under=EXCVIEW)
Expand Down
9 changes: 8 additions & 1 deletion h/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
form templates in preference to the defaults.
"""

from functools import partial

import deform
import pyramid_jinja2
from markupsafe import Markup
from pyramid import httpexceptions
from pyramid.csrf import get_csrf_token
from pyramid.path import AssetResolver

from h import i18n
Expand Down Expand Up @@ -46,6 +49,10 @@ def __call__(self, template_name, **kwargs):
context = self._system.copy()
context.update(kwargs)

context.setdefault(
"get_csrf_token", partial(get_csrf_token, context["request"])
)
Comment on lines +52 to +54
Copy link
Contributor Author

@seanh seanh Feb 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding Pyramid's get_csrf_token() method to the Jinja2 environment that's used to render our custom Deform templates, so that the base template for forms can automatically add a CSRF token.

TODO: There are lots of templates throughout the codebase that do their own get_csrf_token() calls, it should be possible to remove these now. Actually they may always have been unnecessary? Not sure.


return Markup(template.render(context))


Expand Down Expand Up @@ -75,7 +82,7 @@ def create_form(request, *args, **kwargs):
default) will use the renderer configured in the :py:mod:`h.form` module.
"""
env = request.registry[ENVIRONMENT_KEY]
renderer = Jinja2Renderer(env, {"feature": request.feature})
renderer = Jinja2Renderer(env, {"feature": request.feature, "request": request})
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing request to the renderer so it can add it to the template context above.

kwargs.setdefault("renderer", renderer)

return deform.Form(*args, **kwargs)
Expand Down
6 changes: 2 additions & 4 deletions h/schemas/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

from h import i18n
from h.models.auth_client import GrantType
from h.schemas.base import CSRFSchema, enum_type
from h.schemas.base import enum_type

_ = i18n.TranslationString
GrantTypeSchemaType = enum_type(GrantType)


class CreateAuthClientSchema(CSRFSchema):
class CreateAuthClientSchema(colander.Schema):
name = colander.SchemaNode(
colander.String(),
title=_("Name"),
Expand Down Expand Up @@ -57,8 +57,6 @@ class CreateAuthClientSchema(CSRFSchema):
)

def validator(self, node, value):
super().validator(node, value)

grant_type = value.get("grant_type")
redirect_url = value.get("redirect_url")

Expand Down
28 changes: 0 additions & 28 deletions h/schemas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,14 @@
import copy

import colander
import deform
import jsonschema
from pyramid import httpexceptions
from pyramid.csrf import check_csrf_token, get_csrf_token


@colander.deferred
def deferred_csrf_token(_node, kwargs):
request = kwargs.get("request")
return get_csrf_token(request)


class ValidationError(httpexceptions.HTTPBadRequest):
pass


class CSRFSchema(colander.Schema):
"""
A CSRFSchema backward-compatible with the one from the hem module.

Unlike hem, this doesn't require that the csrf_token appear in the
serialized appstruct.
"""

csrf_token = colander.SchemaNode(
colander.String(),
widget=deform.widget.HiddenWidget(),
default=deferred_csrf_token,
missing=None,
)

def validator(self, node, _value):
request = node.bindings["request"]
check_csrf_token(request)


class JSONSchema:
"""
Validate data according to a JSON Schema.
Expand Down
3 changes: 1 addition & 2 deletions h/schemas/forms/accounts/edit_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from h.accounts import util
from h.models.user import DISPLAY_NAME_MAX_LENGTH
from h.schemas import validators
from h.schemas.base import CSRFSchema

_ = i18n.TranslationString

Expand All @@ -24,7 +23,7 @@ def validate_orcid(node, cstruct):
raise colander.Invalid(node, str(exc)) # noqa: B904


class EditProfileSchema(CSRFSchema):
class EditProfileSchema(colander.Schema):
display_name = colander.SchemaNode(
colander.String(),
missing=None,
Expand Down
5 changes: 1 addition & 4 deletions h/schemas/forms/accounts/forgot_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@

from h import i18n, models
from h.schemas import validators
from h.schemas.base import CSRFSchema

_ = i18n.TranslationString


class ForgotPasswordSchema(CSRFSchema):
class ForgotPasswordSchema(colander.Schema):
email = colander.SchemaNode(
colander.String(),
validator=colander.All(validators.Email()),
Expand All @@ -17,8 +16,6 @@ class ForgotPasswordSchema(CSRFSchema):
)

def validator(self, node, value):
super().validator(node, value)

request = node.bindings["request"]
email = value.get("email")
user = models.User.get_by_email(request.db, email, request.default_authority)
Expand Down
5 changes: 1 addition & 4 deletions h/schemas/forms/accounts/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import deform

from h import i18n
from h.schemas.base import CSRFSchema
from h.services.user import UserNotActivated

_ = i18n.TranslationString
Expand All @@ -25,7 +24,7 @@ def _deferred_password_widget(_node, kwargs):
)


class LoginSchema(CSRFSchema):
class LoginSchema(colander.Schema):
username = colander.SchemaNode(
colander.String(),
title=_("Username / email"),
Expand All @@ -36,8 +35,6 @@ class LoginSchema(CSRFSchema):
)

def validator(self, node, value):
super().validator(node, value)

request = node.bindings["request"]
username = value.get("username")
password = value.get("password")
Expand Down
3 changes: 1 addition & 2 deletions h/schemas/forms/accounts/reset_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from itsdangerous import BadData, SignatureExpired

from h import i18n, models
from h.schemas.base import CSRFSchema
from h.schemas.forms.accounts import util

_ = i18n.TranslationString
Expand Down Expand Up @@ -60,7 +59,7 @@ def deserialize(self, node, cstruct):
return user


class ResetPasswordSchema(CSRFSchema):
class ResetPasswordSchema(colander.Schema):
# N.B. this is the field into which the user puts their reset code, but we
# call it `user` because when validated, it will return a `User` object.
user = colander.SchemaNode(
Expand Down
4 changes: 1 addition & 3 deletions h/schemas/forms/admin/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
GROUP_NAME_MIN_LENGTH,
)
from h.schemas import validators
from h.schemas.base import CSRFSchema
from h.util import group_scope

_ = i18n.TranslationString
Expand Down Expand Up @@ -137,7 +136,7 @@ def group_organization_select_widget(_node, kwargs):
return SelectWidget(values=list(zip(org_pubids, org_labels, strict=False)))


class AdminGroupSchema(CSRFSchema):
class AdminGroupSchema(colander.Schema):
def __init__(self, *args):
super().__init__(*args)

Expand Down Expand Up @@ -219,5 +218,4 @@ def __init__(self, *args):
)

def validator(self, node, value):
super().validator(node, value)
username_validator(node, value)
3 changes: 1 addition & 2 deletions h/schemas/forms/admin/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import h.i18n
from h.models.organization import Organization
from h.schemas import validators
from h.schemas.base import CSRFSchema

_ = h.i18n.TranslationString

Expand Down Expand Up @@ -36,7 +35,7 @@ def validate_logo(node, value):
raise colander.Invalid(node, _("Logo does not start with <svg> tag"))


class OrganizationSchema(CSRFSchema):
class OrganizationSchema(colander.Schema):
authority = colander.SchemaNode(colander.String(), title=_("Authority"))

name = colander.SchemaNode(
Expand Down
1 change: 1 addition & 0 deletions h/templates/deform/form.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
class="form {{ field.css_class or '' }}
{%- if field.use_inline_editing %} js-form {% endif %}">
<input type="hidden" name="__formid__" value="{{ field.formid }}" />
<input type="hidden" name="csrf_token" value="{{ get_csrf_token() }}">
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The base Deform template for all forms now includes the CSRF token so this doesn't have to be done by each individual template.

Forms are rendered by having Deform serialize Colander templates to HTML using these custom Deform templates. When all our Colander templates were subclasses of a CSRFSchema class with a csrf_token field this causes Deform to include a CSRF token field in the serialized HTML (but only if the schema being serialized remembered to subclass CSRFSchema). This replaces that, just automatically putting the CSRF token in every form regardless of schema.


<div class="form__backdrop" data-ref="formBackdrop"></div>

Expand Down
1 change: 1 addition & 0 deletions h/views/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def post(self):
request_param="response_mode=web_message",
is_authenticated=True,
renderer="h:templates/oauth/authorize_web_message.html.jinja2",
require_csrf=False,
Copy link
Contributor Author

@seanh seanh Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FIXME: This is needed to get some functests passing. This line should be removed and the failing functests should be fixed to use a CSRF token. (This reveals a bug in the code: this view wasn't requiring CSRF when it should have been.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment here to indicate that this is being added only for the tests, and should not be required in actual usage.

)
def post_web_message(self):
"""
Expand Down
1 change: 1 addition & 0 deletions h/views/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def add_api_view( # noqa: PLR0913
`route_name` must be specified.
:param dict **settings: Arguments to pass on to ``config.add_view``
"""
settings.setdefault("require_csrf", False)
settings.setdefault("renderer", "json")
settings.setdefault("decorator", (cors_policy, version_media_type_header(subtype)))

Expand Down
5 changes: 1 addition & 4 deletions tests/functional/h/views/admin/permissions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ def test_accessible_by_staff(self, app, url, accessible):

assert res.status_code == 200 if accessible else 404

GROUP_PAGES = (
("POST", "/admin/groups/delete/{pubid}", 302),
("GET", "/admin/groups/{pubid}", 200),
)
Comment on lines -38 to -41
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FIXME: The POST to /admin/groups/delete is now failing because the test doesn't include a CSRF token. This actually reveals a bug in the code: that endpoint wasn't doing CSRF verification. Now it is. Rather than deleting the test case this functest should be fixed to use a CSRF token.

GROUP_PAGES = (("GET", "/admin/groups/{pubid}", 200),)

@pytest.mark.usefixtures("with_logged_in_admin")
@pytest.mark.parametrize("method,url_template,success_code", GROUP_PAGES)
Expand Down
17 changes: 1 addition & 16 deletions tests/unit/h/accounts/schemas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import colander
import pytest
from pyramid.exceptions import BadCSRFToken

from h.accounts import schemas
from h.services.user_password import UserPasswordService
Expand Down Expand Up @@ -156,9 +155,7 @@ def test_it_validates_with_valid_payload(

result = schema.deserialize(valid_params)

assert result == dict(
valid_params, privacy_accepted=True, comms_opt_in=None, csrf_token=None
)
Comment on lines -159 to -161
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test was expecting the Colander schema to include csrf_token: None in the deserialized data if there was no csrf_token in the input. I've changed it to omit csrf_token from the output in this case.

assert result == dict(valid_params, privacy_accepted=True, comms_opt_in=None)

@pytest.fixture
def valid_params(self):
Expand Down Expand Up @@ -194,18 +191,6 @@ def test_it_is_valid_if_email_same_as_users_existing_email(

schema.deserialize({"email": user.email, "password": "flibble"})

def test_it_is_invalid_if_csrf_token_missing(self, pyramid_request, schema):
del pyramid_request.headers["X-CSRF-Token"]

with pytest.raises(BadCSRFToken):
schema.deserialize({"email": "[email protected]", "password": "flibble"})

def test_it_is_invalid_if_csrf_token_wrong(self, pyramid_request, schema):
pyramid_request.headers["X-CSRF-Token"] = "WRONG"

with pytest.raises(BadCSRFToken):
schema.deserialize({"email": "[email protected]", "password": "flibble"})
Comment on lines -197 to -207
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Colander schemas no longer do CSRF verification. Several unittests like this can now be removed. (In most cases developers had forgotten to add CSRF test cases for their schemas anyway.)


def test_it_is_invalid_if_password_wrong(self, schema, user_password_service):
user_password_service.check_password.return_value = False

Expand Down
Loading
Loading