From 89420a2cfc16504609dcd81a9a039579b86c2693 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 13 Nov 2023 12:29:01 +0100 Subject: [PATCH] Refactor the upstream link provider template logic Also adds tests for new account registration through an upstream oauth2 provider --- Cargo.lock | 5 +- crates/axum-utils/src/cookies.rs | 21 +- .../src/upstream_oauth2/provider.rs | 9 + crates/handlers/Cargo.toml | 1 + crates/handlers/src/test_utils.rs | 15 +- crates/handlers/src/upstream_oauth2/link.rs | 521 +++++++++++++----- crates/handlers/src/upstream_oauth2/mod.rs | 1 + .../handlers/src/upstream_oauth2/template.rs | 122 ++++ crates/templates/src/context.rs | 36 +- 9 files changed, 569 insertions(+), 162 deletions(-) create mode 100644 crates/handlers/src/upstream_oauth2/template.rs diff --git a/Cargo.lock b/Cargo.lock index 78487efc1..3f7d9bf95 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2946,6 +2946,7 @@ dependencies = [ "axum", "axum-extra", "axum-macros", + "base64ct", "bcrypt", "camino", "chrono", @@ -4957,9 +4958,9 @@ dependencies = [ [[package]] name = "self_cell" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c309e515543e67811222dbc9e3dd7e1056279b782e1dacffe4242b718734fb6" +checksum = "e388332cd64eb80cd595a00941baf513caffae8dce9cfd0467fc9c66397dade6" [[package]] name = "semver" diff --git a/crates/axum-utils/src/cookies.rs b/crates/axum-utils/src/cookies.rs index b761f20a0..bc9da1100 100644 --- a/crates/axum-utils/src/cookies.rs +++ b/crates/axum-utils/src/cookies.rs @@ -55,6 +55,22 @@ impl CookieManager { let key = Key::derive_from(key); Self::new(base_url, key) } + + #[must_use] + pub fn cookie_jar(&self) -> CookieJar { + let inner = PrivateCookieJar::new(self.key.clone()); + let options = self.options.clone(); + + CookieJar { inner, options } + } + + #[must_use] + pub fn cookie_jar_from_headers(&self, headers: &http::HeaderMap) -> CookieJar { + let inner = PrivateCookieJar::from_headers(headers, self.key.clone()); + let options = self.options.clone(); + + CookieJar { inner, options } + } } #[async_trait] @@ -67,10 +83,7 @@ where async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let cookie_manager = CookieManager::from_ref(state); - let inner = PrivateCookieJar::from_headers(&parts.headers, cookie_manager.key.clone()); - let options = cookie_manager.options.clone(); - - Ok(CookieJar { inner, options }) + Ok(cookie_manager.cookie_jar_from_headers(&parts.headers)) } } diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs index 3e00ee7e1..3c9056968 100644 --- a/crates/data-model/src/upstream_oauth2/provider.rs +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -130,4 +130,13 @@ impl ImportAction { pub fn is_required(&self) -> bool { matches!(self, Self::Require) } + + #[must_use] + pub fn should_import(&self, user_preference: bool) -> bool { + match self { + Self::Ignore => false, + Self::Suggest => user_preference, + Self::Force | Self::Require => true, + } + } } diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index db9d0b0fd..179169ad5 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -51,6 +51,7 @@ pbkdf2 = { version = "0.12.2", features = ["password-hash", "std", "simple", "pa zeroize = "1.6.0" # Various data types and utilities +base64ct = "1.6.0" camino.workspace = true chrono.workspace = true psl = "2.1.4" diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index cceb9727a..2356495c9 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -22,6 +22,7 @@ use axum::{ async_trait, body::{Bytes, HttpBody}, extract::{FromRef, FromRequestParts}, + response::{IntoResponse, IntoResponseParts}, }; use cookie_store::{CookieStore, RawCookie}; use futures_util::future::BoxFuture; @@ -31,7 +32,9 @@ use hyper::{ Request, Response, StatusCode, }; use mas_axum_utils::{ - cookies::CookieManager, http_client_factory::HttpClientFactory, ErrorWrapper, + cookies::{CookieJar, CookieManager}, + http_client_factory::HttpClientFactory, + ErrorWrapper, }; use mas_i18n::Translator; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; @@ -264,6 +267,11 @@ impl TestState { _ => panic!("Unexpected status code: {}", response.status()), } } + + /// Get an empty cookie jar + pub fn cookie_jar(&self) -> CookieJar { + self.cookie_manager.cookie_jar() + } } struct TestGraphQLState { @@ -631,6 +639,11 @@ impl CookieHelper { &url, ); } + + pub fn import(&self, res: impl IntoResponseParts) { + let response = (res, "").into_response(); + self.save_cookies(&response); + } } impl Layer for CookieHelper { diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 65fd83b6e..07e16a9a1 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -24,7 +24,7 @@ use mas_axum_utils::{ sentry::SentryEventID, FancyError, SessionInfoExt, }; -use mas_data_model::{UpstreamOAuthProviderImportAction, User}; +use mas_data_model::User; use mas_jose::jwt::Jwt; use mas_policy::Policy; use mas_router::UrlBuilder; @@ -44,9 +44,13 @@ use thiserror::Error; use tracing::warn; use ulid::Ulid; -use super::UpstreamSessionsCookie; +use super::{template::environment, UpstreamSessionsCookie}; use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction, PreferredLanguage}; +const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}"; +const DEFAULT_DISPLAYNAME_TEMPLATE: &str = "{{ user.name }}"; +const DEFAULT_EMAIL_TEMPLATE: &str = "{{ user.email }}"; + #[derive(Debug, Error)] pub(crate) enum RouteError { /// Couldn't find the link specified in the URL @@ -65,6 +69,10 @@ pub(crate) enum RouteError { #[error("Upstream provider not found")] ProviderNotFound, + /// Required attribute rendered to an empty string + #[error("Template {template:?} rendered to an empty string")] + RequiredAttributeEmpty { template: String }, + /// Required claim was missing in id_token #[error("Template {template:?} could not be rendered from the upstream provider's response for required claim")] RequiredAttributeRender { @@ -85,7 +93,7 @@ pub(crate) enum RouteError { InvalidFormAction, #[error(transparent)] - Internal(Box), + Internal(Box), } impl_from_error_for_route!(mas_templates::TemplateError); @@ -108,39 +116,38 @@ impl IntoResponse for RouteError { } } -/// Utility function to import a claim from the upstream provider's response, -/// based on the preference for that attribute. +/// Utility function to render an attribute template. /// /// # Parameters /// -/// * `name` - The name of the claim, for error reporting -/// * `value` - The value of the claim, if present -/// * `preference` - The preference for this claim -/// * `run` - A function to run if the claim is present. The first argument is -/// the value of the claim, and the second is whether the claim is forced to -/// be used. +/// * `environment` - The minijinja environment to use to render the template +/// * `template` - The template to use to render the claim +/// * `required` - Whether the attribute is required or not /// /// # Errors /// -/// Returns an error if the claim is required but missing. -fn import_claim( +/// Returns an error if the attribute is required but fails to render or is +/// empty +fn render_attribute_template( environment: &Environment, template: &str, - action: &UpstreamOAuthProviderImportAction, - mut run: impl FnMut(String, bool), -) -> Result<(), RouteError> { - // If this claim is ignored, we don't need to do anything. - if action.ignore() { - return Ok(()); - } - + required: bool, +) -> Result, RouteError> { match environment.render_str(template, ()) { - Ok(value) if value.is_empty() => { /* Do nothing on empty strings */ } + Ok(value) if value.is_empty() => { + if required { + return Err(RouteError::RequiredAttributeEmpty { + template: template.to_owned(), + }); + } - Ok(value) => run(value, action.is_forced()), + Ok(None) + } + + Ok(value) => Ok(Some(value)), Err(source) => { - if action.is_required() { + if required { return Err(RouteError::RequiredAttributeRender { template: template.to_owned(), source, @@ -148,10 +155,9 @@ fn import_claim( } tracing::warn!(error = &source as &dyn std::error::Error, %template, "Error while rendering template"); + Ok(None) } } - - Ok(()) } #[derive(Deserialize, Serialize)] @@ -327,105 +333,120 @@ pub(crate) async fn get( .map(|id_token| id_token.into_parts().1) .unwrap_or_default(); - let mut ctx = UpstreamRegister::default(); + let ctx = UpstreamRegister::default(); let env = { - let mut e = Environment::new(); + let mut e = environment(); e.add_global("user", payload); e }; - import_claim( - &env, - provider + let ctx = if provider.claims_imports.displayname.ignore() { + ctx + } else { + let template = provider .claims_imports .displayname .template .as_deref() - .unwrap_or("{{ user.name }}"), - &provider.claims_imports.displayname, - |value, force| { - ctx.set_display_name(value, force); - }, - )?; - - import_claim( - &env, - provider + .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE); + + match render_attribute_template( + &env, + template, + provider.claims_imports.displayname.is_required(), + )? { + Some(value) => ctx + .with_display_name(value, provider.claims_imports.displayname.is_forced()), + None => ctx, + } + }; + + let ctx = if provider.claims_imports.email.ignore() { + ctx + } else { + let template = provider .claims_imports .email .template .as_deref() - .unwrap_or("{{ user.email }}"), - &provider.claims_imports.email, - |value, force| { - ctx.set_email(value, force); - }, - )?; - - let mut forced_localpart = None; - import_claim( - &env, - provider + .unwrap_or(DEFAULT_EMAIL_TEMPLATE); + + match render_attribute_template( + &env, + template, + provider.claims_imports.email.is_required(), + )? { + Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()), + None => ctx, + } + }; + + let ctx = if provider.claims_imports.localpart.ignore() { + ctx + } else { + let template = provider .claims_imports .localpart .template .as_deref() - .unwrap_or("{{ user.preferred_username }}"), - &provider.claims_imports.localpart, - |value, force| { - if force { - // We want to run the policy check on the username if it is forced - forced_localpart = Some(value.clone()); - } - - ctx.set_localpart(value, force); - }, - )?; - - // Run the policy check and check for existing users - if let Some(localpart) = forced_localpart { - let maybe_existing_user = repo.user().find_by_username(&localpart).await?; - if let Some(existing_user) = maybe_existing_user { - // The mapper returned a username which already exists, but isn't linked to - // this upstream user. - warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username"); - - // TODO: translate - let ctx = ErrorContext::new() - .with_code("User exists") - .with_description(format!( - r#"Upstream account provider returned {localpart:?} as username, + .unwrap_or(DEFAULT_LOCALPART_TEMPLATE); + + match render_attribute_template( + &env, + template, + provider.claims_imports.localpart.is_required(), + )? { + Some(localpart) => { + // We could run policy & existing user checks when the user submits the + // form, but this lead to poor UX. This is why we do + // it ahead of time here. + let maybe_existing_user = repo.user().find_by_username(&localpart).await?; + if let Some(existing_user) = maybe_existing_user { + // The mapper returned a username which already exists, but isn't linked + // to this upstream user. + warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username"); + + // TODO: translate + let ctx = ErrorContext::new() + .with_code("User exists") + .with_description(format!( + r#"Upstream account provider returned {localpart:?} as username, which is not linked to that upstream account"# - )) - .with_language(&locale); - - return Ok(( - cookie_jar, - Html(templates.render_error(&ctx)?).into_response(), - )); - } - - let res = policy - .evaluate_upstream_oauth_register(&localpart, None) - .await?; - - if !res.valid() { - // TODO: translate - let ctx = ErrorContext::new() - .with_code("Policy error") - .with_description(format!( - r#"Upstream account provider returned {localpart:?} as username, + )) + .with_language(&locale); + + return Ok(( + cookie_jar, + Html(templates.render_error(&ctx)?).into_response(), + )); + } + + let res = policy + .evaluate_upstream_oauth_register(&localpart, None) + .await?; + + if !res.valid() { + // TODO: translate + let ctx = ErrorContext::new() + .with_code("Policy error") + .with_description(format!( + r#"Upstream account provider returned {localpart:?} as username, which does not pass the policy check: {res}"# - )) - .with_language(&locale); + )) + .with_language(&locale); + + return Ok(( + cookie_jar, + Html(templates.render_error(&ctx)?).into_response(), + )); + } - return Ok(( - cookie_jar, - Html(templates.render_error(&ctx)?).into_response(), - )); + ctx.with_localpart(localpart, provider.claims_imports.localpart.is_forced()) + } + None => ctx, } - } + }; let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale); @@ -496,6 +517,8 @@ pub(crate) async fn post( let session = match (maybe_user_session, link.user_id, form) { (Some(session), None, FormData::Link) => { + // The user is already logged in, the link is not linked to any user, and the + // user asked to link their account. repo.upstream_oauth_link() .associate_to_user(&link, &session.user) .await?; @@ -512,6 +535,11 @@ pub(crate) async fn post( import_display_name, }, ) => { + // The user got the form to register a new account, and is not logged in. + // Depending on the claims_imports, we've let the user choose their username, + // choose whether they want to import the email and display name, or + // not. + // Those fields are Some("on") if the checkbox is checked let import_email = import_email.is_some(); let import_display_name = import_display_name.is_some(); @@ -531,6 +559,7 @@ pub(crate) async fn post( .map(|id_token| id_token.into_parts().1) .unwrap_or_default(); + // Is the email verified according to the upstream provider? let provider_email_verified = payload .get_item(&minijinja::Value::from("email_verified")) .map(|v| v.is_true()) @@ -538,77 +567,91 @@ pub(crate) async fn post( // Let's try to import the claims from the ID token let env = { - let mut e = Environment::new(); + let mut e = environment(); e.add_global("user", payload); e }; // Create a template context in case we need to re-render because of an error - let mut ctx = UpstreamRegister::default(); - - let mut name = None; - import_claim( - &env, - provider + let ctx = UpstreamRegister::default(); + + let display_name = if provider + .claims_imports + .displayname + .should_import(import_display_name) + { + let template = provider .claims_imports .displayname .template .as_deref() - .unwrap_or("{{ user.name }}"), - &provider.claims_imports.displayname, - |value, force| { - // Import the display name if it is either forced or the user has requested it - if force || import_display_name { - name = Some(value.clone()); - } + .unwrap_or(DEFAULT_DISPLAYNAME_TEMPLATE); + + render_attribute_template( + &env, + template, + provider.claims_imports.displayname.is_required(), + )? + } else { + None + }; - ctx.set_display_name(value, force); - }, - )?; + let ctx = if let Some(ref display_name) = display_name { + ctx.with_display_name( + display_name.clone(), + provider.claims_imports.email.is_forced(), + ) + } else { + ctx + }; - let mut email = None; - import_claim( - &env, - provider + let email = if provider.claims_imports.email.should_import(import_email) { + let template = provider .claims_imports .email .template .as_deref() - .unwrap_or("{{ user.email }}"), - &provider.claims_imports.email, - |value, force| { - // Import the email if it is either forced or the user has requested it - if force || import_email { - email = Some(value.clone()); - } + .unwrap_or(DEFAULT_EMAIL_TEMPLATE); + + render_attribute_template( + &env, + template, + provider.claims_imports.email.is_required(), + )? + } else { + None + }; - ctx.set_email(value, force); - }, - )?; + let ctx = if let Some(ref email) = email { + ctx.with_email(email.clone(), provider.claims_imports.email.is_forced()) + } else { + ctx + }; - let mut username = username; - import_claim( - &env, - provider + let forced_username = if provider.claims_imports.localpart.is_forced() { + let template = provider .claims_imports .localpart .template .as_deref() - .unwrap_or("{{ user.preferred_username }}"), - &provider.claims_imports.localpart, - |value, force| { - // If the username is forced, override whatever was in the form - if force { - username = Some(value.clone()); - } - - ctx.set_localpart(value, force); - }, - )?; + .unwrap_or(DEFAULT_LOCALPART_TEMPLATE); + + render_attribute_template( + &env, + template, + provider.claims_imports.email.is_required(), + )? + } else { + None + }; - let username = username.filter(|s| !s.is_empty()); + // If there is no forced username, we can use the one the user entered + let username = forced_username + .or(username) + .filter(|username| !username.is_empty()); let Some(username) = username else { + // We're missing a username, let's re-render the form with an error let form_state = form_state.with_error_on_field( mas_templates::UpstreamRegisterFormField::Username, FieldError::Required, @@ -625,11 +668,16 @@ pub(crate) async fn post( .into_response()); }; + let ctx = ctx.with_localpart( + username.clone(), + provider.claims_imports.localpart.is_forced(), + ); + // Check if there is an existing user let existing_user = repo.user().find_by_username(&username).await?; if let Some(_existing_user) = existing_user { // If there is an existing user, we can't create a new one - // with the same username + // with the same username, show an error let form_state = form_state.with_error_on_field( mas_templates::UpstreamRegisterFormField::Username, @@ -687,7 +735,7 @@ pub(crate) async fn post( let mut job = ProvisionUserJob::new(&user); // If we have a display name, set it during provisioning - if let Some(name) = name { + if let Some(name) = display_name { job = job.set_display_name(name); } @@ -745,3 +793,172 @@ pub(crate) async fn post( Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response()) } + +#[cfg(test)] +mod tests { + use hyper::{header::CONTENT_TYPE, Request, StatusCode}; + use mas_data_model::{ + UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderImportPreference, + }; + use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; + use mas_jose::jwt::{JsonWebSignatureHeader, Jwt}; + use mas_router::Route; + use oauth2_types::scope::{Scope, OPENID}; + use sqlx::PgPool; + + use super::UpstreamSessionsCookie; + use crate::test_utils::{ + init_tracing, CookieHelper, RequestBuilderExt, ResponseExt, TestState, + }; + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_register(pool: PgPool) { + init_tracing(); + let state = TestState::from_pool(pool).await.unwrap(); + let mut rng = state.rng(); + let cookies = CookieHelper::new(); + + let claims_imports = UpstreamOAuthProviderClaimsImports { + localpart: UpstreamOAuthProviderImportPreference { + action: mas_data_model::UpstreamOAuthProviderImportAction::Force, + template: None, + }, + email: UpstreamOAuthProviderImportPreference { + action: mas_data_model::UpstreamOAuthProviderImportAction::Force, + template: None, + }, + ..UpstreamOAuthProviderClaimsImports::default() + }; + + let id_token = serde_json::json!({ + "preferred_username": "john", + "email": "john@example.com", + "email_verified": true, + }); + + // Grab a key to sign the id_token + // We could generate a key on the fly, but because we have one available here, + // why not use it? + let key = state + .key_store + .signing_key_for_algorithm(&JsonWebSignatureAlg::Rs256) + .unwrap(); + + let signer = key + .params() + .signing_key_for_alg(&JsonWebSignatureAlg::Rs256) + .unwrap(); + let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Rs256); + let id_token = Jwt::sign_with_rng(&mut rng, header, id_token, &signer).unwrap(); + + // Provision a provider and a link + let mut repo = state.repository().await.unwrap(); + let provider = repo + .upstream_oauth_provider() + .add( + &mut rng, + &state.clock, + "https://example.com/".to_owned(), + Scope::from_iter([OPENID]), + OAuthClientAuthenticationMethod::None, + None, + "client".to_owned(), + None, + claims_imports, + ) + .await + .unwrap(); + + let session = repo + .upstream_oauth_session() + .add( + &mut rng, + &state.clock, + &provider, + "state".to_owned(), + None, + "nonce".to_owned(), + ) + .await + .unwrap(); + + let link = repo + .upstream_oauth_link() + .add(&mut rng, &state.clock, &provider, "subject".to_owned()) + .await + .unwrap(); + + let session = repo + .upstream_oauth_session() + .complete_with_link(&state.clock, session, &link, Some(id_token.into_string())) + .await + .unwrap(); + + repo.save().await.unwrap(); + + let cookie_jar = state.cookie_jar(); + let upstream_sessions = UpstreamSessionsCookie::default() + .add(session.id, provider.id, "state".to_owned(), None) + .add_link_to_session(session.id, link.id) + .unwrap(); + let cookie_jar = upstream_sessions.save(cookie_jar, &state.clock); + cookies.import(cookie_jar); + + let request = Request::get(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).empty(); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::OK); + response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8"); + + // Extract the CSRF token from the response body + let csrf_token = response + .body() + .split("name=\"csrf\" value=\"") + .nth(1) + .unwrap() + .split('\"') + .next() + .unwrap(); + + let request = Request::post(&*mas_router::UpstreamOAuth2Link::new(link.id).path()).form( + serde_json::json!({ + "csrf": csrf_token, + "action": "register", + "import_email": "on", + }), + ); + let request = cookies.with_cookies(request); + let response = state.request(request).await; + cookies.save_cookies(&response); + response.assert_status(StatusCode::SEE_OTHER); + + // Check that we have a registered user, with the email imported + let mut repo = state.repository().await.unwrap(); + let user = repo + .user() + .find_by_username("john") + .await + .unwrap() + .expect("user exists"); + + let link = repo + .upstream_oauth_link() + .find_by_subject(&provider, "subject") + .await + .unwrap() + .expect("link exists"); + + assert_eq!(link.user_id, Some(user.id)); + + let email = repo + .user_email() + .get_primary(&user) + .await + .unwrap() + .expect("email exists"); + + assert_eq!(email.email, "john@example.com"); + assert!(email.confirmed_at.is_some()); + } +} diff --git a/crates/handlers/src/upstream_oauth2/mod.rs b/crates/handlers/src/upstream_oauth2/mod.rs index e9974f714..36dfd0959 100644 --- a/crates/handlers/src/upstream_oauth2/mod.rs +++ b/crates/handlers/src/upstream_oauth2/mod.rs @@ -26,6 +26,7 @@ pub(crate) mod cache; pub(crate) mod callback; mod cookie; pub(crate) mod link; +mod template; use self::cookie::UpstreamSessions as UpstreamSessionsCookie; diff --git a/crates/handlers/src/upstream_oauth2/template.rs b/crates/handlers/src/upstream_oauth2/template.rs new file mode 100644 index 000000000..586a82e0d --- /dev/null +++ b/crates/handlers/src/upstream_oauth2/template.rs @@ -0,0 +1,122 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{collections::HashMap, sync::Arc}; + +use base64ct::{Base64, Encoding}; +use minijinja::{Environment, Error, ErrorKind, Value}; + +fn split(value: &str, separator: Option<&str>) -> Vec { + value + .split(separator.unwrap_or(" ")) + .map(ToOwned::to_owned) + .collect::>() +} + +fn b64decode(value: &str) -> Result { + let bytes = Base64::decode_vec(value).map_err(|e| { + Error::new( + ErrorKind::InvalidOperation, + "Failed to decode base64 string", + ) + .with_source(e) + })?; + + // It is not obvious, but the cleanest way to get a Value stored as raw bytes is + // to wrap it in an Arc, because Value implements From>> + Ok(Value::from(Arc::new(bytes))) +} + +fn b64encode(bytes: &[u8]) -> String { + Base64::encode_string(bytes) +} + +/// Decode a Tag-Length-Value encoded byte array into a map of tag to value. +fn tlvdecode(bytes: &[u8]) -> Result, Error> { + let mut iter = bytes.iter().copied(); + let mut ret = HashMap::new(); + loop { + // TODO: this assumes the tag and the length are both single bytes, which is not + // always the case with protobufs. We should properly decode varints + // here. + let Some(tag) = iter.next() else { + break; + }; + + let len = iter + .next() + .ok_or_else(|| Error::new(ErrorKind::InvalidOperation, "Invalid ILV encoding"))?; + + let mut bytes = Vec::with_capacity(len.into()); + for _ in 0..len { + bytes.push( + iter.next().ok_or_else(|| { + Error::new(ErrorKind::InvalidOperation, "Invalid ILV encoding") + })?, + ); + } + + ret.insert(tag, Value::from(Arc::new(bytes))); + } + + Ok(ret) +} + +fn string(value: &Value) -> String { + value.to_string() +} + +pub fn environment() -> Environment<'static> { + let mut env = Environment::new(); + + env.add_filter("split", split); + env.add_filter("b64decode", b64decode); + env.add_filter("b64encode", b64encode); + env.add_filter("tlvdecode", tlvdecode); + env.add_filter("string", string); + + env +} + +#[cfg(test)] +mod tests { + use super::environment; + + #[test] + fn test_split() { + let env = environment(); + let res = env + .render_str(r#"{{ 'foo, bar' | split(', ') | join(" | ") }}"#, ()) + .unwrap(); + assert_eq!(res, "foo | bar"); + } + + #[test] + fn test_ilvdecode() { + let env = environment(); + let res = env + .render_str( + r#" + {%- set tlv = 'Cg0wLTM4NS0yODA4OS0wEgRtb2Nr' | b64decode | tlvdecode -%} + {%- if tlv[18]|string != 'mock' -%} + {{ "FAIL"/0 }} + {%- endif -%} + {{- tlv[10]|string -}} + "#, + (), + ) + .unwrap(); + assert_eq!(res, "0-385-28089-0"); + } +} diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 228064cea..4d861c661 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -954,24 +954,54 @@ impl UpstreamRegister { Self::default() } - /// Set the suggested localpart + /// Set the imported localpart pub fn set_localpart(&mut self, localpart: String, force: bool) { self.imported_localpart = Some(localpart); self.force_localpart = force; } - /// Set the suggested display name + /// Set the imported localpart + #[must_use] + pub fn with_localpart(self, localpart: String, force: bool) -> Self { + Self { + imported_localpart: Some(localpart), + force_localpart: force, + ..self + } + } + + /// Set the imported display name pub fn set_display_name(&mut self, display_name: String, force: bool) { self.imported_display_name = Some(display_name); self.force_display_name = force; } - /// Set the suggested email + /// Set the imported display name + #[must_use] + pub fn with_display_name(self, display_name: String, force: bool) -> Self { + Self { + imported_display_name: Some(display_name), + force_display_name: force, + ..self + } + } + + /// Set the imported email pub fn set_email(&mut self, email: String, force: bool) { self.imported_email = Some(email); self.force_email = force; } + /// Set the imported email + #[must_use] + pub fn with_email(self, email: String, force: bool) -> Self { + Self { + imported_email: Some(email), + force_email: force, + ..self + } + } + /// Set the form state pub fn set_form_state(&mut self, form_state: FormState) { self.form_state = form_state;