diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 66876a25b..1dd7b5ef0 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -383,6 +383,11 @@ where get(self::views::register::steps::verify_email::get) .post(self::views::register::steps::verify_email::post), ) + .route( + mas_router::RegisterDisplayName::route(), + get(self::views::register::steps::display_name::get) + .post(self::views::register::steps::display_name::post), + ) .route( mas_router::RegisterFinish::route(), get(self::views::register::steps::finish::get), diff --git a/crates/handlers/src/views/register/steps/display_name.rs b/crates/handlers/src/views/register/steps/display_name.rs new file mode 100644 index 000000000..1af314ecf --- /dev/null +++ b/crates/handlers/src/views/register/steps/display_name.rs @@ -0,0 +1,182 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +use anyhow::Context as _; +use axum::{ + extract::{Path, State}, + response::{Html, IntoResponse, Response}, + Form, +}; +use mas_axum_utils::{ + cookies::CookieJar, + csrf::{CsrfExt as _, ProtectedForm}, + FancyError, +}; +use mas_router::{PostAuthAction, UrlBuilder}; +use mas_storage::{BoxClock, BoxRepository, BoxRng}; +use mas_templates::{ + FieldError, RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField, + TemplateContext as _, Templates, ToFormState, +}; +use serde::{Deserialize, Serialize}; +use ulid::Ulid; + +use crate::{views::shared::OptionalPostAuthAction, PreferredLanguage}; + +#[derive(Deserialize, Default)] +#[serde(rename_all = "snake_case")] +enum FormAction { + #[default] + Set, + Skip, +} + +#[derive(Deserialize, Serialize)] +pub(crate) struct DisplayNameForm { + #[serde(skip_serializing, default)] + action: FormAction, + #[serde(default)] + display_name: String, +} + +impl ToFormState for DisplayNameForm { + type Field = mas_templates::RegisterStepsDisplayNameFormField; +} + +#[tracing::instrument( + name = "handlers.views.register.steps.display_name.get", + fields(user_registration.id = %id), + skip_all, + err, +)] +pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, + PreferredLanguage(locale): PreferredLanguage, + State(templates): State, + State(url_builder): State, + mut repo: BoxRepository, + Path(id): Path, + cookie_jar: CookieJar, +) -> Result { + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + + let registration = repo + .user_registration() + .lookup(id) + .await? + .context("Could not find user registration")?; + + // If the registration is completed, we can go to the registration destination + // XXX: this might not be the right thing to do? Maybe an error page would be + // better? + if registration.completed_at.is_some() { + let post_auth_action: Option = registration + .post_auth_action + .map(serde_json::from_value) + .transpose()?; + + return Ok(( + cookie_jar, + OptionalPostAuthAction::from(post_auth_action) + .go_next(&url_builder) + .into_response(), + ) + .into_response()); + } + + let ctx = RegisterStepsDisplayNameContext::new() + .with_csrf(csrf_token.form_value()) + .with_language(locale); + + let content = templates.render_register_steps_display_name(&ctx)?; + + Ok((cookie_jar, Html(content)).into_response()) +} + +#[tracing::instrument( + name = "handlers.views.register.steps.display_name.post", + fields(user_registration.id = %id), + skip_all, + err, +)] +pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, + PreferredLanguage(locale): PreferredLanguage, + State(templates): State, + State(url_builder): State, + mut repo: BoxRepository, + Path(id): Path, + cookie_jar: CookieJar, + Form(form): Form>, +) -> Result { + let registration = repo + .user_registration() + .lookup(id) + .await? + .context("Could not find user registration")?; + + // If the registration is completed, we can go to the registration destination + // XXX: this might not be the right thing to do? Maybe an error page would be + // better? + if registration.completed_at.is_some() { + let post_auth_action: Option = registration + .post_auth_action + .map(serde_json::from_value) + .transpose()?; + + return Ok(( + cookie_jar, + OptionalPostAuthAction::from(post_auth_action) + .go_next(&url_builder) + .into_response(), + ) + .into_response()); + } + + let form = cookie_jar.verify_form(&clock, form)?; + + let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + + let display_name = match form.action { + FormAction::Set => { + let display_name = form.display_name.trim(); + + if display_name.is_empty() || display_name.len() > 255 { + let ctx = RegisterStepsDisplayNameContext::new() + .with_form_state(form.to_form_state().with_error_on_field( + RegisterStepsDisplayNameFormField::DisplayName, + FieldError::Invalid, + )) + .with_csrf(csrf_token.form_value()) + .with_language(locale); + + return Ok(( + cookie_jar, + Html(templates.render_register_steps_display_name(&ctx)?), + ) + .into_response()); + } + + display_name.to_owned() + } + FormAction::Skip => { + // If the user chose to skip, we do the same as Synapse and use the localpart as + // default display name + registration.username.clone() + } + }; + + let registration = repo + .user_registration() + .set_display_name(registration, display_name) + .await?; + + repo.save().await?; + + let destination = mas_router::RegisterFinish::new(registration.id); + return Ok((cookie_jar, url_builder.redirect(&destination)).into_response()); +} diff --git a/crates/handlers/src/views/register/steps/finish.rs b/crates/handlers/src/views/register/steps/finish.rs index 55db0a6c5..2c4679a28 100644 --- a/crates/handlers/src/views/register/steps/finish.rs +++ b/crates/handlers/src/views/register/steps/finish.rs @@ -102,6 +102,14 @@ pub(crate) async fn get( ))); } + // Check that the display name is set + if registration.display_name.is_none() { + return Ok(( + cookie_jar, + url_builder.redirect(&mas_router::RegisterDisplayName::new(registration.id)), + )); + } + // Everuthing is good, let's complete the registration let registration = repo .user_registration() diff --git a/crates/handlers/src/views/register/steps/mod.rs b/crates/handlers/src/views/register/steps/mod.rs index 4d479c352..1b090abb9 100644 --- a/crates/handlers/src/views/register/steps/mod.rs +++ b/crates/handlers/src/views/register/steps/mod.rs @@ -3,5 +3,6 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +pub(crate) mod display_name; pub(crate) mod finish; pub(crate) mod verify_email; diff --git a/crates/router/src/endpoints.rs b/crates/router/src/endpoints.rs index 54b1a5cd1..fa707a66b 100644 --- a/crates/router/src/endpoints.rs +++ b/crates/router/src/endpoints.rs @@ -444,6 +444,30 @@ impl From> for PasswordRegister { } } +/// `GET|POST /register/steps/:id/display-name` +#[derive(Debug, Clone)] +pub struct RegisterDisplayName { + id: Ulid, +} + +impl RegisterDisplayName { + #[must_use] + pub fn new(id: Ulid) -> Self { + Self { id } + } +} + +impl Route for RegisterDisplayName { + type Query = (); + fn route() -> &'static str { + "/register/steps/:id/display-name" + } + + fn path(&self) -> std::borrow::Cow<'static, str> { + format!("/register/steps/{}/display-name", self.id).into() + } +} + /// `GET|POST /register/steps/:id/verify-email` #[derive(Debug, Clone)] pub struct RegisterVerifyEmail { diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index fea88e58d..f5e652ed4 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -1000,6 +1000,57 @@ impl TemplateContext for RegisterStepsVerifyEmailContext { } } +/// Fields for the display name form +#[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum RegisterStepsDisplayNameFormField { + /// The display name + DisplayName, +} + +impl FormField for RegisterStepsDisplayNameFormField { + fn keep(&self) -> bool { + match self { + Self::DisplayName => true, + } + } +} + +/// Context used by the `display_name.html` template +#[derive(Serialize, Default)] +pub struct RegisterStepsDisplayNameContext { + form: FormState, +} + +impl RegisterStepsDisplayNameContext { + /// Constructs a context for the display name page + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Set the form state + #[must_use] + pub fn with_form_state( + mut self, + form_state: FormState, + ) -> Self { + self.form = form_state; + self + } +} + +impl TemplateContext for RegisterStepsDisplayNameContext { + fn sample(_now: chrono::DateTime, _rng: &mut impl Rng) -> Vec + where + Self: Sized, + { + vec![Self { + form: FormState::default(), + }] + } +} + /// Fields of the account recovery start form #[derive(Serialize, Deserialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] #[serde(rename_all = "snake_case")] diff --git a/crates/templates/src/lib.rs b/crates/templates/src/lib.rs index 1ae43ed24..07e67cc63 100644 --- a/crates/templates/src/lib.rs +++ b/crates/templates/src/lib.rs @@ -41,6 +41,7 @@ pub use self::{ PostAuthContextInner, ReauthContext, ReauthFormField, RecoveryExpiredContext, RecoveryFinishContext, RecoveryFinishFormField, RecoveryProgressContext, RecoveryStartContext, RecoveryStartFormField, RegisterContext, RegisterFormField, + RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField, RegisterStepsVerifyEmailContext, RegisterStepsVerifyEmailFormField, SiteBranding, SiteConfigExt, SiteFeatures, TemplateContext, UpstreamExistingLinkContext, UpstreamRegister, UpstreamRegisterFormField, UpstreamSuggestLink, WithCaptcha, WithCsrf, @@ -335,6 +336,9 @@ register_templates! { /// Render the email verification page pub fn render_register_steps_verify_email(WithLanguage>) { "pages/register/steps/verify_email.html" } + /// Render the display name page + pub fn render_register_steps_display_name(WithLanguage>) { "pages/register/steps/display_name.html" } + /// Render the client consent page pub fn render_consent(WithLanguage>>) { "pages/consent.html" } @@ -428,6 +432,7 @@ impl Templates { check::render_register(self, now, rng)?; check::render_password_register(self, now, rng)?; check::render_register_steps_verify_email(self, now, rng)?; + check::render_register_steps_display_name(self, now, rng)?; check::render_consent(self, now, rng)?; check::render_policy_violation(self, now, rng)?; check::render_sso_login(self, now, rng)?; diff --git a/templates/components/button.html b/templates/components/button.html index e956f1c41..3b037f2f5 100644 --- a/templates/components/button.html +++ b/templates/components/button.html @@ -29,6 +29,7 @@ class="", value="", disabled=False, + kind="primary", size="lg", autocomplete=False, autocorrect=False, @@ -39,7 +40,7 @@ type="{{ type }}" {% if disabled %}disabled{% endif %} class="cpd-button {{ class }}" - data-kind="primary" + data-kind="{{ kind }}" data-size="{{ size }}" {% if autocapitalize %}autocapitilize="{{ autocapitilize }}"{% endif %} {% if autocomplete %}autocomplete="{{ autocomplete }}"{% endif %} diff --git a/templates/pages/register/steps/display_name.html b/templates/pages/register/steps/display_name.html new file mode 100644 index 000000000..b4e774ca5 --- /dev/null +++ b/templates/pages/register/steps/display_name.html @@ -0,0 +1,52 @@ +{# +Copyright 2025 New Vector Ltd. + +SPDX-License-Identifier: AGPL-3.0-only +Please see LICENSE in the repository root for full details. +-#} + +{% extends "base.html" %} + +{% block content %} +
+
+ {{ icon.visibility_on() }} +
+
+

{{ _("mas.choose_display_name.headline") }}

+

{{ _("mas.choose_display_name.description") }}

+
+
+ +
+
+ {% if form.errors is not empty %} + {% for error in form.errors %} +
+ {{ errors.form_error_message(error=error) }} +
+ {% endfor %} + {% endif %} + + + + + {% call(f) field.field(label=_("common.display_name"), name="display_name", form_state=form, class="mb-4") %} + + {% endcall %} + + {{ button.button(text=_("action.continue")) }} +
+ +
+ + + {{ button.button(text=_("action.skip"), kind="tertiary") }} +
+
+{% endblock content %} diff --git a/templates/pages/register/steps/verify_email.html b/templates/pages/register/steps/verify_email.html index 491721423..39b1c9abb 100644 --- a/templates/pages/register/steps/verify_email.html +++ b/templates/pages/register/steps/verify_email.html @@ -33,7 +33,6 @@

{{ _("mas.verify_email.headline") }}

{% call(f) field.field(label=_("mas.verify_email.6_digit_code"), name="code", form_state=form, class="mb-4 self-center") %}
%(client_name)s at %(redirect_uri)s wants to access your account.", "@client_wants_access": {