Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Commit

Permalink
Check for existing users ahead of time on upstream OAuth2 registration
Browse files Browse the repository at this point in the history
  • Loading branch information
sandhose committed Nov 13, 2023
1 parent 8a1329d commit 9c94e11
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 109 deletions.
180 changes: 146 additions & 34 deletions crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use axum::{
extract::{Path, State},
response::{Html, IntoResponse},
response::{Html, IntoResponse, Response},
Form, TypedHeader,
};
use hyper::StatusCode;
Expand All @@ -35,12 +35,13 @@ use mas_storage::{
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
};
use mas_templates::{
ErrorContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamSuggestLink,
ErrorContext, FieldError, FormError, TemplateContext, Templates, ToFormState,
UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
};
use minijinja::Environment;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::warn;
use ulid::Ulid;

use super::UpstreamSessionsCookie;
Expand Down Expand Up @@ -83,14 +84,6 @@ pub(crate) enum RouteError {
#[error("Invalid form action")]
InvalidFormAction,

#[error("Missing username")]
MissingUsername,

#[error("Policy violation: {violations:?}")]
PolicyViolation {
violations: Vec<mas_policy::Violation>,
},

#[error(transparent)]
Internal(Box<dyn std::error::Error>),
}
Expand All @@ -107,16 +100,6 @@ impl IntoResponse for RouteError {
let event_id = sentry::capture_error(&self);
let response = match self {
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
Self::PolicyViolation { violations } => {
let details = violations.iter().map(|v| v.msg.clone()).collect::<Vec<_>>();
let details = details.join("\n");
let ctx = ErrorContext::new()
.with_description(
"Account registration denied because of policy violation".to_owned(),
)
.with_details(details);
FancyError::new(ctx).into_response()
}
Self::Internal(e) => FancyError::from(e).into_response(),
e => FancyError::from(e).into_response(),
};
Expand Down Expand Up @@ -171,7 +154,7 @@ fn import_claim(
Ok(())
}

#[derive(Deserialize)]
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "lowercase", tag = "action")]
pub(crate) enum FormData {
Register {
Expand All @@ -185,6 +168,10 @@ pub(crate) enum FormData {
Link,
}

impl ToFormState for FormData {
type Field = mas_templates::UpstreamRegisterFormField;
}

#[tracing::instrument(
name = "handlers.upstream_oauth2.link.get",
fields(upstream_oauth_link.id = %link_id),
Expand All @@ -195,6 +182,7 @@ pub(crate) async fn get(
mut rng: BoxRng,
clock: BoxClock,
mut repo: BoxRepository,
mut policy: Policy,
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
Expand Down Expand Up @@ -339,7 +327,7 @@ pub(crate) async fn get(
.map(|id_token| id_token.into_parts().1)
.unwrap_or_default();

let mut ctx = UpstreamRegister::new(&link);
let mut ctx = UpstreamRegister::default();

let env = {
let mut e = Environment::new();
Expand Down Expand Up @@ -375,6 +363,7 @@ pub(crate) async fn get(
},
)?;

let mut forced_localpart = None;
import_claim(
&env,
provider
Expand All @@ -385,10 +374,59 @@ pub(crate) async fn get(
.unwrap_or("{{ user.preferred_username }}"),
&provider.claims_imports.localpart,
|value, force| {
if force {
// We want to run the policy check on the username if it is forced
forced_localpart = Some(value.clone());
}

ctx.set_localpart(value, force);
},
)?;

// Run the policy check and check for existing users
if let Some(localpart) = forced_localpart {
let maybe_existing_user = repo.user().find_by_username(&localpart).await?;
if let Some(existing_user) = maybe_existing_user {
// The mapper returned a username which already exists, but isn't linked to
// this upstream user.
warn!(username = %localpart, user_id = %existing_user.id, "Localpart template returned an existing username");

// TODO: translate
let ctx = ErrorContext::new()
.with_code("User exists")
.with_description(format!(
r#"Upstream account provider returned {localpart:?} as username,
which is not linked to that upstream account"#
))
.with_language(&locale);

return Ok((
cookie_jar,
Html(templates.render_error(&ctx)?).into_response(),
));
}

let res = policy
.evaluate_upstream_oauth_register(&localpart, None)
.await?;

if !res.valid() {
// TODO: translate
let ctx = ErrorContext::new()
.with_code("Policy error")
.with_description(format!(
r#"Upstream account provider returned {localpart:?} as username,
which does not pass the policy check: {res}"#
))
.with_language(&locale);

return Ok((
cookie_jar,
Html(templates.render_error(&ctx)?).into_response(),
));
}
}

let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);

Html(templates.render_upstream_oauth2_do_register(&ctx)?).into_response()
Expand All @@ -411,10 +449,12 @@ pub(crate) async fn post(
cookie_jar: CookieJar,
user_agent: Option<TypedHeader<headers::UserAgent>>,
mut policy: Policy,
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
Path(link_id): Path<Ulid>,
Form(form): Form<ProtectedForm<FormData>>,
) -> Result<impl IntoResponse, RouteError> {
) -> Result<Response, RouteError> {
let user_agent = user_agent.map(|ua| ua.as_str().to_owned());
let form = cookie_jar.verify_form(&clock, form)?;

Expand Down Expand Up @@ -449,8 +489,10 @@ pub(crate) async fn post(
return Err(RouteError::SessionConsumed);
}

let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (user_session_info, cookie_jar) = cookie_jar.session_info();
let maybe_user_session = user_session_info.load_session(&mut repo).await?;
let form_state = form.to_form_state();

let session = match (maybe_user_session, link.user_id, form) {
(Some(session), None, FormData::Link) => {
Expand Down Expand Up @@ -495,13 +537,15 @@ pub(crate) async fn post(
.unwrap_or(false);

// Let's try to import the claims from the ID token

let env = {
let mut e = Environment::new();
e.add_global("user", payload);
e
};

// Create a template context in case we need to re-render because of an error
let mut ctx = UpstreamRegister::default();

let mut name = None;
import_claim(
&env,
Expand All @@ -515,8 +559,10 @@ pub(crate) async fn post(
|value, force| {
// Import the display name if it is either forced or the user has requested it
if force || import_display_name {
name = Some(value);
name = Some(value.clone());
}

ctx.set_display_name(value, force);
},
)?;

Expand All @@ -533,8 +579,10 @@ pub(crate) async fn post(
|value, force| {
// Import the email if it is either forced or the user has requested it
if force || import_email {
email = Some(value);
email = Some(value.clone());
}

ctx.set_email(value, force);
},
)?;

Expand All @@ -551,21 +599,85 @@ pub(crate) async fn post(
|value, force| {
// If the username is forced, override whatever was in the form
if force {
username = Some(value);
username = Some(value.clone());
}

ctx.set_localpart(value, force);
},
)?;

let username = username.ok_or(RouteError::MissingUsername)?;
let username = username.filter(|s| !s.is_empty());

let Some(username) = username else {
let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Required,
);

let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
};

// Check if there is an existing user
let existing_user = repo.user().find_by_username(&username).await?;
if let Some(_existing_user) = existing_user {
// If there is an existing user, we can't create a new one
// with the same username

let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Exists,
);

let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
}

// Policy check
let res = policy
.evaluate_upstream_oauth_register(&username, email.as_deref())
.await?;
if !res.valid() {
return Err(RouteError::PolicyViolation {
violations: res.violations,
});
let form_state =
res.violations
.into_iter()
.fold(form_state, |form_state, violation| {
match violation.field.as_deref() {
Some("username") => form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Policy {
message: violation.msg,
},
),
_ => form_state.with_error_on_form(FormError::Policy {
message: violation.msg,
}),
}
});

let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
}

// Now we can create the user
Expand Down Expand Up @@ -631,5 +743,5 @@ pub(crate) async fn post(

repo.save().await?;

Ok((cookie_jar, post_auth_action.go_next(&url_builder)))
Ok((cookie_jar, post_auth_action.go_next(&url_builder)).into_response())
}
Loading

0 comments on commit 9c94e11

Please sign in to comment.