From d1b7fc5cd956cc5bfcd115ff8357dcc2d5909eba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 16 Jan 2024 11:16:03 +0100 Subject: [PATCH 01/11] setup mfa service endpoints --- src/grpc/desktop_client_mfa.rs | 29 +++++++++++++++++++++++++++++ src/grpc/mod.rs | 30 +++++++++++++++++++++++++++--- 2 files changed, 56 insertions(+), 3 deletions(-) create mode 100644 src/grpc/desktop_client_mfa.rs diff --git a/src/grpc/desktop_client_mfa.rs b/src/grpc/desktop_client_mfa.rs new file mode 100644 index 000000000..2f0e80b73 --- /dev/null +++ b/src/grpc/desktop_client_mfa.rs @@ -0,0 +1,29 @@ +use crate::db::DbPool; +use tonic::Status; + +use super::proto::{ClientMfaFinishRequest, ClientMfaFinishResponse, ClientMfaStartRequest}; + +pub(super) struct ClientMfaServer { + pool: DbPool, +} + +impl ClientMfaServer { + #[must_use] + pub fn new(pool: DbPool) -> Self { + Self { pool } + } + + pub async fn start_client_mfa_login( + &self, + request: ClientMfaStartRequest, + ) -> Result<(), Status> { + todo!() + } + + pub async fn finish_client_mfa_login( + &self, + request: ClientMfaFinishRequest, + ) -> Result { + todo!() + } +} diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index c4006a333..06f7d411d 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -29,6 +29,7 @@ use uuid::Uuid; use self::gateway::{gateway_service_server::GatewayServiceServer, GatewayServer}; use self::{ auth::{auth_service_server::AuthServiceServer, AuthServer}, + desktop_client_mfa::ClientMfaServer, enrollment::EnrollmentServer, password_reset::PasswordResetServer, proto::core_response, @@ -49,6 +50,7 @@ use crate::{ }; mod auth; +mod desktop_client_mfa; pub mod enrollment; #[cfg(feature = "wireguard")] pub(crate) mod gateway; @@ -62,8 +64,7 @@ pub(crate) mod proto { tonic::include_proto!("defguard.proxy"); } -use crate::grpc::proto::CoreError; -use proto::{core_request, proxy_client::ProxyClient, CoreResponse}; +use proto::{core_request, proxy_client::ProxyClient, CoreError, CoreResponse}; // Helper struct used to handle gateway state // gateways are grouped by network @@ -344,7 +345,8 @@ pub async fn run_grpc_bidi_stream( mail_tx.clone(), user_agent_parser, ); - let password_reset_server = PasswordResetServer::new(pool, mail_tx); + let password_reset_server = PasswordResetServer::new(pool.clone(), mail_tx); + let client_mfa_server = ClientMfaServer::new(pool); let endpoint = Endpoint::from_shared(config.proxy_url.as_deref().unwrap())?; let endpoint = endpoint.http2_keep_alive_interval(TEN_SECS); @@ -462,6 +464,28 @@ pub async fn run_grpc_bidi_stream( } } } + // rpc ClientMfaStart (ClientMfaStartRequest) returns (google.protobuf.Empty) + Some(core_request::Payload::ClientMfaStart(request)) => { + match client_mfa_server.start_client_mfa_login(request).await { + Ok(()) => Some(core_response::Payload::Empty(())), + Err(err) => { + error!("client mfa start error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } + // rpc ClientMfaFinish (ClientMfaFinishRequest) returns (ClientMfaFinishResponse) + Some(core_request::Payload::ClientMfaFinish(request)) => { + match client_mfa_server.finish_client_mfa_login(request).await { + Ok(response_payload) => { + Some(core_response::Payload::ClientMfaFinish(response_payload)) + } + Err(err) => { + error!("client mfa start error {err}"); + Some(core_response::Payload::CoreError(err.into())) + } + } + } // Reply without payload. None => None, }; From 2022e03d16aab4e3d156d3878ff8dcc97abc20fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 16 Jan 2024 12:41:02 +0100 Subject: [PATCH 02/11] add mfa start response --- src/grpc/desktop_client_mfa.rs | 6 ++++-- src/grpc/mod.rs | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/grpc/desktop_client_mfa.rs b/src/grpc/desktop_client_mfa.rs index 2f0e80b73..558c2bccf 100644 --- a/src/grpc/desktop_client_mfa.rs +++ b/src/grpc/desktop_client_mfa.rs @@ -1,7 +1,9 @@ use crate::db::DbPool; use tonic::Status; -use super::proto::{ClientMfaFinishRequest, ClientMfaFinishResponse, ClientMfaStartRequest}; +use super::proto::{ + ClientMfaFinishRequest, ClientMfaFinishResponse, ClientMfaStartRequest, ClientMfaStartResponse, +}; pub(super) struct ClientMfaServer { pool: DbPool, @@ -16,7 +18,7 @@ impl ClientMfaServer { pub async fn start_client_mfa_login( &self, request: ClientMfaStartRequest, - ) -> Result<(), Status> { + ) -> Result { todo!() } diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index 06f7d411d..7287d5580 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -464,10 +464,12 @@ pub async fn run_grpc_bidi_stream( } } } - // rpc ClientMfaStart (ClientMfaStartRequest) returns (google.protobuf.Empty) + // rpc ClientMfaStart (ClientMfaStartRequest) returns (ClientMfaStartResponse) Some(core_request::Payload::ClientMfaStart(request)) => { match client_mfa_server.start_client_mfa_login(request).await { - Ok(()) => Some(core_response::Payload::Empty(())), + Ok(response_payload) => { + Some(core_response::Payload::ClientMfaStart(response_payload)) + } Err(err) => { error!("client mfa start error {err}"); Some(core_response::Payload::CoreError(err.into())) From f1890dea7a2bf60e9a6a53213ca7eecfe2a356da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 16 Jan 2024 15:21:18 +0100 Subject: [PATCH 03/11] implement login start endpoint logic --- src/db/models/device.rs | 3 +- src/db/models/group.rs | 6 +- src/db/models/wireguard.rs | 33 +++-------- src/grpc/desktop_client_mfa.rs | 103 +++++++++++++++++++++++++++++++-- src/grpc/enrollment.rs | 22 +++---- src/grpc/mod.rs | 4 +- src/handlers/auth.rs | 2 +- src/handlers/mail.rs | 2 +- src/handlers/user.rs | 4 +- src/handlers/wireguard.rs | 30 ++-------- src/templates.rs | 4 +- 11 files changed, 137 insertions(+), 76 deletions(-) diff --git a/src/db/models/device.rs b/src/db/models/device.rs index 4cdcfa170..69f24228c 100644 --- a/src/db/models/device.rs +++ b/src/db/models/device.rs @@ -491,7 +491,6 @@ impl Device { pub async fn add_to_all_networks( &self, transaction: &mut PgConnection, - admin_group_name: &str, ) -> Result<(Vec, Vec), DeviceError> { info!("Adding device {} to all existing networks", self.name); let networks = WireguardNetwork::all(&mut *transaction).await?; @@ -528,7 +527,7 @@ impl Device { } if let Ok(wireguard_network_device) = network - .add_device_to_network(&mut *transaction, self, admin_group_name, None) + .add_device_to_network(&mut *transaction, self, None) .await { debug!( diff --git a/src/db/models/group.rs b/src/db/models/group.rs index 2ee5f91ff..a6c192309 100644 --- a/src/db/models/group.rs +++ b/src/db/models/group.rs @@ -2,6 +2,7 @@ use model_derive::Model; use sqlx::{query, query_as, query_scalar, Error as SqlxError, PgConnection, PgExecutor}; use crate::db::{models::error::ModelError, User, WireguardNetwork}; +use crate::SERVER_CONFIG; #[derive(Model)] pub struct Group { @@ -97,9 +98,12 @@ impl WireguardNetwork { pub async fn get_allowed_groups( &self, transaction: &mut PgConnection, - admin_group_name: &str, ) -> Result>, ModelError> { debug!("Returning a list of allowed groups for network {self}"); + let admin_group_name = &SERVER_CONFIG + .get() + .expect("defguard config not found") + .admin_groupname; // get allowed groups from DB let mut groups = self.fetch_allowed_groups(&mut *transaction).await?; diff --git a/src/db/models/wireguard.rs b/src/db/models/wireguard.rs index a8107b436..0f24d8262 100644 --- a/src/db/models/wireguard.rs +++ b/src/db/models/wireguard.rs @@ -302,11 +302,10 @@ impl WireguardNetwork { async fn get_allowed_devices( &self, transaction: &mut PgConnection, - admin_group_name: &str, ) -> Result, ModelError> { debug!("Fetching all allowed devices for network {}", self); let devices = match self - .get_allowed_groups(&mut *transaction, admin_group_name) + .get_allowed_groups(&mut *transaction) .await? { // devices need to be filtered by allowed group Some(allowed_groups) => { @@ -338,15 +337,12 @@ impl WireguardNetwork { pub async fn add_all_allowed_devices( &self, transaction: &mut PgConnection, - admin_group_name: &str, ) -> Result<(), ModelError> { info!( "Assigning IPs in network {} for all existing devices ", self ); - let devices = self - .get_allowed_devices(&mut *transaction, admin_group_name) - .await?; + let devices = self.get_allowed_devices(&mut *transaction).await?; for device in devices { device .assign_network_ip(&mut *transaction, self, None) @@ -360,13 +356,10 @@ impl WireguardNetwork { &self, transaction: &mut PgConnection, device: &Device, - admin_group_name: &str, reserved_ips: Option<&[IpAddr]>, ) -> Result { info!("Assigning IP in network {self} for {device}"); - let allowed_devices = self - .get_allowed_devices(&mut *transaction, admin_group_name) - .await?; + let allowed_devices = self.get_allowed_devices(&mut *transaction).await?; let allowed_device_ids: Vec = allowed_devices.iter().filter_map(|dev| dev.id).collect(); if allowed_device_ids.contains(&device.get_id()?) { @@ -389,14 +382,11 @@ impl WireguardNetwork { pub async fn sync_allowed_devices( &self, transaction: &mut PgConnection, - admin_group_name: &str, reserved_ips: Option<&[IpAddr]>, ) -> Result, WireguardNetworkError> { info!("Synchronizing IPs in network {self} for all allowed devices "); // list all allowed devices - let allowed_devices = self - .get_allowed_devices(&mut *transaction, admin_group_name) - .await?; + let allowed_devices = self.get_allowed_devices(&mut *transaction).await?; // convert to a map for easier processing let mut allowed_devices: HashMap = allowed_devices .into_iter() @@ -485,12 +475,9 @@ impl WireguardNetwork { &self, transaction: &mut PgConnection, imported_devices: Vec, - admin_group_name: &str, ) -> Result<(Vec, Vec), WireguardNetworkError> { let network_id = self.get_id()?; - let allowed_devices = self - .get_allowed_devices(&mut *transaction, admin_group_name) - .await?; + let allowed_devices = self.get_allowed_devices(&mut *transaction).await?; // convert to a map for easier processing let allowed_devices: HashMap = allowed_devices .into_iter() @@ -551,14 +538,11 @@ impl WireguardNetwork { &self, transaction: &mut PgConnection, mapped_devices: Vec, - admin_group_name: &str, ) -> Result, WireguardNetworkError> { info!("Mapping user devices for network {}", self); let network_id = self.get_id()?; // get allowed groups for network - let allowed_groups = self - .get_allowed_groups(&mut *transaction, admin_group_name) - .await?; + let allowed_groups = self.get_allowed_groups(&mut *transaction).await?; let mut events = Vec::new(); // use a helper hashmap to avoid repeated queries @@ -629,9 +613,8 @@ impl WireguardNetwork { } // assign IPs in other networks - let (mut all_network_info, _configs) = device - .add_to_all_networks(&mut *transaction, admin_group_name) - .await?; + let (mut all_network_info, _configs) = + device.add_to_all_networks(&mut *transaction).await?; network_info.append(&mut all_network_info); diff --git a/src/grpc/desktop_client_mfa.rs b/src/grpc/desktop_client_mfa.rs index 558c2bccf..df13c70c0 100644 --- a/src/grpc/desktop_client_mfa.rs +++ b/src/grpc/desktop_client_mfa.rs @@ -1,25 +1,120 @@ -use crate::db::DbPool; +use crate::db::{DbPool, Device, User, UserInfo, WireguardNetwork}; +use crate::handlers::mail::send_email_mfa_code_email; +use crate::mail::Mail; +use tokio::sync::mpsc::UnboundedSender; use tonic::Status; +use uuid::Uuid; use super::proto::{ ClientMfaFinishRequest, ClientMfaFinishResponse, ClientMfaStartRequest, ClientMfaStartResponse, + MfaMethod, }; pub(super) struct ClientMfaServer { pool: DbPool, + mail_tx: UnboundedSender, } impl ClientMfaServer { #[must_use] - pub fn new(pool: DbPool) -> Self { - Self { pool } + pub fn new(pool: DbPool, mail_tx: UnboundedSender) -> Self { + Self { pool, mail_tx } } pub async fn start_client_mfa_login( &self, request: ClientMfaStartRequest, ) -> Result { - todo!() + info!("Starting desktop client login: {request:?}"); + // fetch location + let Ok(Some(location)) = + WireguardNetwork::find_by_id(&self.pool, request.location_id).await + else { + error!("Failed to find location with ID {}", request.location_id); + return Err(Status::invalid_argument("location not found")); + }; + + // fetch device + let Ok(Some(device)) = Device::find_by_pubkey(&self.pool, &request.pubkey).await else { + error!("Failed to find device with pubkey {}", request.pubkey); + return Err(Status::invalid_argument("device not found")); + }; + + // fetch user + let Ok(Some(user)) = User::find_by_id(&self.pool, device.user_id).await else { + error!("Failed to find user with ID {}", device.user_id); + return Err(Status::invalid_argument("user not found")); + }; + let user_info = UserInfo::from_user(&self.pool, &user).await.map_err(|_| { + error!("Failed to fetch user info for {}", user.username); + Status::internal("unexpected error") + })?; + + // validate user is allowed to connect to a given location + let mut transaction = self.pool.begin().await.map_err(|_| { + error!("Failed to begin transaction"); + Status::internal("unexpected error") + })?; + let allowed_groups = location + .get_allowed_groups(&mut transaction) + .await + .map_err(|err| { + error!("Failed to fetch allowed groups for location {location}: {err:?}"); + Status::internal("unexpected error") + })?; + if let Some(groups) = allowed_groups { + // check if user belongs to one of allowed groups + if !groups + .iter() + .any(|allowed_group| user_info.groups.contains(allowed_group)) + { + error!( + "User {} not allowed to connect to location {location}", + user.username + ); + return Err(Status::unauthenticated("unauthorized")); + } + } + + // check if selected method is enabled + match MfaMethod::try_from(request.method) { + Ok(MfaMethod::Totp) => { + if !user.totp_enabled { + error!("TOTP not enabled for user {}", user.username); + return Err(Status::invalid_argument( + "selected MFA method not available", + )); + } + } + Ok(MfaMethod::Email) => { + if !user.email_mfa_enabled { + error!("Email MFA not enabled for user {}", user.username); + return Err(Status::invalid_argument( + "selected MFA method not available", + )); + } + // send email code + send_email_mfa_code_email(&user, &self.mail_tx, None).map_err(|err| { + error!( + "Failed to send email MFA code for user {}: {err:?}", + user.username + ); + Status::internal("unexpected error") + })?; + } + Err(err) => { + error!("Invalid MFA method selected: {err}"); + return Err(Status::invalid_argument("invalid MFA method selected")); + } + } + + // generate auth token + let token = Uuid::new_v4().into(); + + // store login session + todo!(); + + Ok(ClientMfaStartResponse { token }) } pub async fn finish_client_mfa_login( diff --git a/src/grpc/enrollment.rs b/src/grpc/enrollment.rs index e0a568b96..c81e8db86 100644 --- a/src/grpc/enrollment.rs +++ b/src/grpc/enrollment.rs @@ -283,7 +283,6 @@ impl EnrollmentServer { request: NewDevice, req_device_info: Option, ) -> Result { - let config = SERVER_CONFIG.get().expect("defguard config not found"); debug!("Adding new user device: {request:?}"); let enrollment = self.validate_session(request.token.as_deref()).await?; @@ -319,16 +318,17 @@ impl EnrollmentServer { Status::internal("unexpected error") })?; - let (network_info, configs) = device - .add_to_all_networks(&mut transaction, &config.admin_groupname) - .await - .map_err(|err| { - error!( - "Failed to add device {} to existing networks: {err}", - device.name - ); - Status::internal("unexpected error") - })?; + let (network_info, configs) = + device + .add_to_all_networks(&mut transaction) + .await + .map_err(|err| { + error!( + "Failed to add device {} to existing networks: {err}", + device.name + ); + Status::internal("unexpected error") + })?; self.send_wireguard_event(GatewayEvent::DeviceCreated(DeviceInfo { device: device.clone(), diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index 7287d5580..ea21f7dc5 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -345,8 +345,8 @@ pub async fn run_grpc_bidi_stream( mail_tx.clone(), user_agent_parser, ); - let password_reset_server = PasswordResetServer::new(pool.clone(), mail_tx); - let client_mfa_server = ClientMfaServer::new(pool); + let password_reset_server = PasswordResetServer::new(pool.clone(), mail_tx.clone()); + let client_mfa_server = ClientMfaServer::new(pool, mail_tx); let endpoint = Endpoint::from_shared(config.proxy_url.as_deref().unwrap())?; let endpoint = endpoint.http2_keep_alive_interval(TEN_SECS); diff --git a/src/handlers/auth.rs b/src/handlers/auth.rs index be0643ea5..fdeba9461 100644 --- a/src/handlers/auth.rs +++ b/src/handlers/auth.rs @@ -598,7 +598,7 @@ pub async fn request_email_mfa_code( if let Some(user) = User::find_by_id(&appstate.pool, session.user_id).await? { debug!("Sending email MFA code for user {}", user.username); if user.email_mfa_enabled { - send_email_mfa_code_email(&user, &appstate.mail_tx, &session)?; + send_email_mfa_code_email(&user, &appstate.mail_tx, Some(&session))?; info!("Sent email MFA code for user {}", user.username); Ok(ApiResponse::default()) } else { diff --git a/src/handlers/mail.rs b/src/handlers/mail.rs index 0a273a309..73c8e1cc9 100644 --- a/src/handlers/mail.rs +++ b/src/handlers/mail.rs @@ -385,7 +385,7 @@ pub fn send_email_mfa_activation_email( pub fn send_email_mfa_code_email( user: &User, mail_tx: &UnboundedSender, - session: &Session, + session: Option<&Session>, ) -> Result<(), TemplateError> { debug!("Sending email MFA code mail to {}", user.email); diff --git a/src/handlers/user.rs b/src/handlers/user.rs index c3b650632..e96301905 100644 --- a/src/handlers/user.rs +++ b/src/handlers/user.rs @@ -314,9 +314,7 @@ pub async fn modify_user( { let networks = WireguardNetwork::all(&mut *transaction).await?; for network in networks { - let gateway_events = network - .sync_allowed_devices(&mut transaction, &appstate.config.admin_groupname, None) - .await?; + let gateway_events = network.sync_allowed_devices(&mut transaction, None).await?; appstate.send_multiple_wireguard_events(gateway_events); } }; diff --git a/src/handlers/wireguard.rs b/src/handlers/wireguard.rs index 96cbc6b00..0a853df7e 100644 --- a/src/handlers/wireguard.rs +++ b/src/handlers/wireguard.rs @@ -114,9 +114,7 @@ pub async fn create_network( .await?; // generate IP addresses for existing devices - network - .add_all_allowed_devices(&mut transaction, &appstate.config.admin_groupname) - .await?; + network.add_all_allowed_devices(&mut transaction).await?; info!("Assigning IPs for existing devices in network {network}"); match &network.id { @@ -181,9 +179,7 @@ pub async fn modify_network( network .set_allowed_groups(&mut transaction, data.allowed_groups) .await?; - let _events = network - .sync_allowed_devices(&mut transaction, &appstate.config.admin_groupname, None) - .await?; + let _events = network.sync_allowed_devices(&mut transaction, None).await?; match &network.id { Some(network_id) => { @@ -377,22 +373,14 @@ pub async fn import_network( .map(|dev| dev.wireguard_ip) .collect(); let (devices, gateway_events) = network - .handle_imported_devices( - &mut transaction, - imported_devices, - &appstate.config.admin_groupname, - ) + .handle_imported_devices(&mut transaction, imported_devices) .await?; appstate.send_multiple_wireguard_events(gateway_events); // assign IPs for other existing devices info!("Assigning IPs in imported network for remaining existing devices"); let gateway_events = network - .sync_allowed_devices( - &mut transaction, - &appstate.config.admin_groupname, - Some(&reserved_ips), - ) + .sync_allowed_devices(&mut transaction, Some(&reserved_ips)) .await?; appstate.send_multiple_wireguard_events(gateway_events); @@ -434,11 +422,7 @@ pub async fn add_user_devices( // wrap loop in transaction to abort if a device is invalid let mut transaction = appstate.pool.begin().await?; let events = network - .handle_mapped_devices( - &mut transaction, - mapped_devices, - &appstate.config.admin_groupname, - ) + .handle_mapped_devices(&mut transaction, mapped_devices) .await?; appstate.send_multiple_wireguard_events(events); transaction.commit().await?; @@ -500,9 +484,7 @@ pub async fn add_device( device: Device, } - let (network_info, configs) = device - .add_to_all_networks(&mut transaction, &appstate.config.admin_groupname) - .await?; + let (network_info, configs) = device.add_to_all_networks(&mut transaction).await?; let mut network_ips: Vec = Vec::new(); for network_info_item in network_info.clone() { diff --git a/src/templates.rs b/src/templates.rs index 0f9cbb1d3..52a712563 100644 --- a/src/templates.rs +++ b/src/templates.rs @@ -261,8 +261,8 @@ pub fn email_mfa_activation_mail(code: u32, session: &Session) -> Result Result { - let (mut tera, mut context) = get_base_tera(None, Some(session), None, None)?; +pub fn email_mfa_code_mail(code: u32, session: Option<&Session>) -> Result { + let (mut tera, mut context) = get_base_tera(None, session, None, None)?; // zero-pad code to make sure it's always 6 digits long context.insert("code", &format!("{code:0>6}")); tera.add_raw_template("mail_email_mfa_code", MAIL_EMAIL_MFA_CODE)?; From 604c0b8bf37e430357327597e5a13f01a8bd5b3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 16 Jan 2024 15:28:32 +0100 Subject: [PATCH 04/11] store client login sessions --- src/grpc/desktop_client_mfa.rs | 14 ++++++++++---- src/grpc/mod.rs | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/grpc/desktop_client_mfa.rs b/src/grpc/desktop_client_mfa.rs index df13c70c0..e762389e4 100644 --- a/src/grpc/desktop_client_mfa.rs +++ b/src/grpc/desktop_client_mfa.rs @@ -1,6 +1,7 @@ use crate::db::{DbPool, Device, User, UserInfo, WireguardNetwork}; use crate::handlers::mail::send_email_mfa_code_email; use crate::mail::Mail; +use std::collections::HashMap; use tokio::sync::mpsc::UnboundedSender; use tonic::Status; use uuid::Uuid; @@ -13,16 +14,21 @@ use super::proto::{ pub(super) struct ClientMfaServer { pool: DbPool, mail_tx: UnboundedSender, + sessions: HashMap, } impl ClientMfaServer { #[must_use] pub fn new(pool: DbPool, mail_tx: UnboundedSender) -> Self { - Self { pool, mail_tx } + Self { + pool, + mail_tx, + sessions: HashMap::new(), + } } pub async fn start_client_mfa_login( - &self, + &mut self, request: ClientMfaStartRequest, ) -> Result { info!("Starting desktop client login: {request:?}"); @@ -109,10 +115,10 @@ impl ClientMfaServer { } // generate auth token - let token = Uuid::new_v4().into(); + let token = Uuid::new_v4().to_string(); // store login session - todo!(); + self.sessions.insert(token.clone(), device); Ok(ClientMfaStartResponse { token }) } diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index ea21f7dc5..2c037d6ec 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -346,7 +346,7 @@ pub async fn run_grpc_bidi_stream( user_agent_parser, ); let password_reset_server = PasswordResetServer::new(pool.clone(), mail_tx.clone()); - let client_mfa_server = ClientMfaServer::new(pool, mail_tx); + let mut client_mfa_server = ClientMfaServer::new(pool, mail_tx); let endpoint = Endpoint::from_shared(config.proxy_url.as_deref().unwrap())?; let endpoint = endpoint.http2_keep_alive_interval(TEN_SECS); From 789dc3373bfb43731139dd70a4756a297b73633a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 16 Jan 2024 16:51:49 +0100 Subject: [PATCH 05/11] implement login finish handler --- src/grpc/desktop_client_mfa.rs | 116 ++++++++++++++++++++++++++++++--- src/grpc/mod.rs | 2 +- 2 files changed, 108 insertions(+), 10 deletions(-) diff --git a/src/grpc/desktop_client_mfa.rs b/src/grpc/desktop_client_mfa.rs index e762389e4..3d360914d 100644 --- a/src/grpc/desktop_client_mfa.rs +++ b/src/grpc/desktop_client_mfa.rs @@ -1,8 +1,13 @@ -use crate::db::{DbPool, Device, User, UserInfo, WireguardNetwork}; -use crate::handlers::mail::send_email_mfa_code_email; -use crate::mail::Mail; +use crate::{ + db::{ + models::device::{DeviceInfo, DeviceNetworkInfo, WireguardNetworkDevice}, + DbPool, Device, GatewayEvent, User, UserInfo, WireguardNetwork, + }, + handlers::mail::send_email_mfa_code_email, + mail::Mail, +}; use std::collections::HashMap; -use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender}; use tonic::Status; use uuid::Uuid; @@ -11,18 +16,30 @@ use super::proto::{ MfaMethod, }; +struct ClientLoginSession { + pub location: WireguardNetwork, + pub device: Device, + pub user: User, +} + pub(super) struct ClientMfaServer { pool: DbPool, mail_tx: UnboundedSender, - sessions: HashMap, + wireguard_tx: Sender, + sessions: HashMap, } impl ClientMfaServer { #[must_use] - pub fn new(pool: DbPool, mail_tx: UnboundedSender) -> Self { + pub fn new( + pool: DbPool, + mail_tx: UnboundedSender, + wireguard_tx: Sender, + ) -> Self { Self { pool, mail_tx, + wireguard_tx, sessions: HashMap::new(), } } @@ -118,15 +135,96 @@ impl ClientMfaServer { let token = Uuid::new_v4().to_string(); // store login session - self.sessions.insert(token.clone(), device); + self.sessions.insert( + token.clone(), + ClientLoginSession { + location, + device, + user, + }, + ); Ok(ClientMfaStartResponse { token }) } pub async fn finish_client_mfa_login( - &self, + &mut self, request: ClientMfaFinishRequest, ) -> Result { - todo!() + info!("Finishing desktop client login: {request:?}"); + // fetch login session + let Some(session) = self.sessions.remove(&request.token) else { + error!("Client login session not found"); + return Err(Status::invalid_argument("login session not found")); + }; + let device = session.device; + let location = session.location; + let user = session.user; + + // validate email code + if !user.verify_email_mfa_code(request.code) { + error!("Provided email code is not valid"); + return Err(Status::unauthenticated("unauthorized")); + }; + + // begin transaction + let mut transaction = self.pool.begin().await.map_err(|_| { + error!("Failed to begin transaction"); + Status::internal("unexpected error") + })?; + + // fetch device config for the location + let Ok(Some(mut network_device)) = WireguardNetworkDevice::find( + &mut *transaction, + device.id.expect("Missing device ID"), + location.id.expect("Missing location ID"), + ) + .await + else { + error!("Failed to fetch network config for device {device} and location {location}"); + return Err(Status::internal("unexpected error")); + }; + + // generate PSK + let key = WireguardNetwork::genkey(); + network_device.preshared_key = Some(key.public.clone()); + + // authorize device for given location + network_device.is_authorized = true; + + // save updated network config + network_device + .update(&mut *transaction) + .await + .map_err(|err| { + error!("Failed to update device network config {network_device:?}: {err:?}"); + Status::internal("unexpected error") + })?; + + // send gateway event + debug!("Sending `peer_create` message to gateway"); + let device_info = DeviceInfo { + device, + network_info: vec![DeviceNetworkInfo { + network_id: location.id.expect("Missing location ID"), + device_wireguard_ip: network_device.wireguard_ip, + preshared_key: network_device.preshared_key, + }], + }; + let event = GatewayEvent::DeviceCreated(device_info); + self.wireguard_tx.send(event).map_err(|err| { + error!("Error sending WireGuard event: {err}"); + Status::internal("unexpected error") + })?; + + // commit transaction + transaction.commit().await.map_err(|_| { + error!("Failed to commit transaction"); + Status::internal("unexpected error") + })?; + + Ok(ClientMfaFinishResponse { + preshared_key: key.public, + }) } } diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index 2c037d6ec..607bdca95 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -346,7 +346,7 @@ pub async fn run_grpc_bidi_stream( user_agent_parser, ); let password_reset_server = PasswordResetServer::new(pool.clone(), mail_tx.clone()); - let mut client_mfa_server = ClientMfaServer::new(pool, mail_tx); + let mut client_mfa_server = ClientMfaServer::new(pool, mail_tx, wireguard_tx); let endpoint = Endpoint::from_shared(config.proxy_url.as_deref().unwrap())?; let endpoint = endpoint.http2_keep_alive_interval(TEN_SECS); From c0a0cb4b851d8dfe6f1f4cf2f41e3318806f22a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 16 Jan 2024 20:42:53 +0100 Subject: [PATCH 06/11] handle devices which never connected --- ...aa5342fbe4d059cb073400544f1ce454d059ac13b24b32cbdbf4.json} | 4 ++-- src/wireguard_peer_disconnect.rs | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) rename .sqlx/{query-580741c18880eb98a7073dbb8e1cd907893fedd70f7d752521d515230397f3ee.json => query-92c38d7487d3aa5342fbe4d059cb073400544f1ce454d059ac13b24b32cbdbf4.json} (82%) diff --git a/.sqlx/query-580741c18880eb98a7073dbb8e1cd907893fedd70f7d752521d515230397f3ee.json b/.sqlx/query-92c38d7487d3aa5342fbe4d059cb073400544f1ce454d059ac13b24b32cbdbf4.json similarity index 82% rename from .sqlx/query-580741c18880eb98a7073dbb8e1cd907893fedd70f7d752521d515230397f3ee.json rename to .sqlx/query-92c38d7487d3aa5342fbe4d059cb073400544f1ce454d059ac13b24b32cbdbf4.json index 8323d4d07..afa7cceb4 100644 --- a/.sqlx/query-580741c18880eb98a7073dbb8e1cd907893fedd70f7d752521d515230397f3ee.json +++ b/.sqlx/query-92c38d7487d3aa5342fbe4d059cb073400544f1ce454d059ac13b24b32cbdbf4.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "WITH stats AS ( SELECT DISTINCT ON (device_id) device_id, endpoint, latest_handshake FROM wireguard_peer_stats WHERE network = $1 ORDER BY device_id, collected_at DESC ) SELECT d.id as \"id?\", d.name, d.wireguard_pubkey, d.user_id, d.created FROM device d JOIN wireguard_network_device wnd ON wnd.device_id = d.id LEFT JOIN stats on d.id = stats.device_id WHERE wnd.wireguard_network_id = $1 AND wnd.is_authorized = true AND (NOW() - stats.latest_handshake) > $2 * interval '1 second'", + "query": "WITH stats AS ( SELECT DISTINCT ON (device_id) device_id, endpoint, latest_handshake FROM wireguard_peer_stats WHERE network = $1 ORDER BY device_id, collected_at DESC ) SELECT d.id as \"id?\", d.name, d.wireguard_pubkey, d.user_id, d.created FROM device d JOIN wireguard_network_device wnd ON wnd.device_id = d.id LEFT JOIN stats on d.id = stats.device_id WHERE wnd.wireguard_network_id = $1 AND wnd.is_authorized = true AND (stats.latest_handshake IS NULL OR (NOW() - stats.latest_handshake) > $2 * interval '1 second')", "describe": { "columns": [ { @@ -43,5 +43,5 @@ false ] }, - "hash": "580741c18880eb98a7073dbb8e1cd907893fedd70f7d752521d515230397f3ee" + "hash": "92c38d7487d3aa5342fbe4d059cb073400544f1ce454d059ac13b24b32cbdbf4" } diff --git a/src/wireguard_peer_disconnect.rs b/src/wireguard_peer_disconnect.rs index ea025cb8f..0e1a09c44 100644 --- a/src/wireguard_peer_disconnect.rs +++ b/src/wireguard_peer_disconnect.rs @@ -70,7 +70,8 @@ pub async fn run_periodic_peer_disconnect( FROM device d \ JOIN wireguard_network_device wnd ON wnd.device_id = d.id \ LEFT JOIN stats on d.id = stats.device_id \ - WHERE wnd.wireguard_network_id = $1 AND wnd.is_authorized = true AND (NOW() - stats.latest_handshake) > $2 * interval '1 second'", + WHERE wnd.wireguard_network_id = $1 AND wnd.is_authorized = true AND \ + (stats.latest_handshake IS NULL OR (NOW() - stats.latest_handshake) > $2 * interval '1 second')", location_id, location.peer_disconnect_threshold as f64 ) From 5f56f128cccb8c63f3eacc1d37c6705ae68b8500 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 16 Jan 2024 20:43:36 +0100 Subject: [PATCH 07/11] fix tests --- tests/wireguard_network_import.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/wireguard_network_import.rs b/tests/wireguard_network_import.rs index 9ddbcc91b..1ac5c5f3a 100644 --- a/tests/wireguard_network_import.rs +++ b/tests/wireguard_network_import.rs @@ -69,7 +69,7 @@ async fn test_config_import() { ); device_1.save(&mut *transaction).await.unwrap(); device_1 - .add_to_all_networks(&mut transaction, &client_state.config.admin_groupname) + .add_to_all_networks(&mut transaction) .await .unwrap(); @@ -80,7 +80,7 @@ async fn test_config_import() { ); device_2.save(&mut *transaction).await.unwrap(); device_2 - .add_to_all_networks(&mut transaction, &client_state.config.admin_groupname) + .add_to_all_networks(&mut transaction) .await .unwrap(); From a9ed7c4ab1b78d5da81c1521da1820fdd54927f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 17 Jan 2024 10:18:14 +0100 Subject: [PATCH 08/11] review fixes --- src/db/models/group.rs | 6 ++++-- src/grpc/desktop_client_mfa.rs | 14 ++++++++------ src/grpc/mod.rs | 4 ++-- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/db/models/group.rs b/src/db/models/group.rs index a6c192309..2df043620 100644 --- a/src/db/models/group.rs +++ b/src/db/models/group.rs @@ -1,8 +1,10 @@ use model_derive::Model; use sqlx::{query, query_as, query_scalar, Error as SqlxError, PgConnection, PgExecutor}; -use crate::db::{models::error::ModelError, User, WireguardNetwork}; -use crate::SERVER_CONFIG; +use crate::{ + db::{models::error::ModelError, User, WireguardNetwork}, + SERVER_CONFIG, +}; #[derive(Model)] pub struct Group { diff --git a/src/grpc/desktop_client_mfa.rs b/src/grpc/desktop_client_mfa.rs index 3d360914d..828c0f282 100644 --- a/src/grpc/desktop_client_mfa.rs +++ b/src/grpc/desktop_client_mfa.rs @@ -17,9 +17,9 @@ use super::proto::{ }; struct ClientLoginSession { - pub location: WireguardNetwork, - pub device: Device, - pub user: User, + location: WireguardNetwork, + device: Device, + user: User, } pub(super) struct ClientMfaServer { @@ -157,9 +157,11 @@ impl ClientMfaServer { error!("Client login session not found"); return Err(Status::invalid_argument("login session not found")); }; - let device = session.device; - let location = session.location; - let user = session.user; + let ClientLoginSession { + device, + location, + user, + } = session; // validate email code if !user.verify_email_mfa_code(request.code) { diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index 607bdca95..43a2cf589 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -471,7 +471,7 @@ pub async fn run_grpc_bidi_stream( Some(core_response::Payload::ClientMfaStart(response_payload)) } Err(err) => { - error!("client mfa start error {err}"); + error!("client MFA start error {err}"); Some(core_response::Payload::CoreError(err.into())) } } @@ -483,7 +483,7 @@ pub async fn run_grpc_bidi_stream( Some(core_response::Payload::ClientMfaFinish(response_payload)) } Err(err) => { - error!("client mfa start error {err}"); + error!("client MFA start error {err}"); Some(core_response::Payload::CoreError(err.into())) } } From 21481291ac1cac06174a3891f0670b9c9087bfeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 17 Jan 2024 11:21:04 +0100 Subject: [PATCH 09/11] use a JWT for auth --- src/auth/mod.rs | 2 ++ src/grpc/desktop_client_mfa.rs | 44 +++++++++++++++++++++++++++------- 2 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 107a26799..6fc07679f 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -37,6 +37,7 @@ pub enum ClaimsType { Auth, Gateway, YubiBridge, + DesktopClient, } /// Standard claims: https://www.iana.org/assignments/jwt/jwt.xhtml @@ -85,6 +86,7 @@ impl Claims { ClaimsType::Auth => AUTH_SECRET_ENV, ClaimsType::Gateway => GATEWAY_SECRET_ENV, ClaimsType::YubiBridge => YUBIBRIDGE_SECRET_ENV, + ClaimsType::DesktopClient => AUTH_SECRET_ENV, }; env::var(env_var).unwrap_or_default() } diff --git a/src/grpc/desktop_client_mfa.rs b/src/grpc/desktop_client_mfa.rs index 828c0f282..60a1a3b14 100644 --- a/src/grpc/desktop_client_mfa.rs +++ b/src/grpc/desktop_client_mfa.rs @@ -1,4 +1,9 @@ +use super::proto::{ + ClientMfaFinishRequest, ClientMfaFinishResponse, ClientMfaStartRequest, ClientMfaStartResponse, + MfaMethod, +}; use crate::{ + auth::{Claims, ClaimsType}, db::{ models::device::{DeviceInfo, DeviceNetworkInfo, WireguardNetworkDevice}, DbPool, Device, GatewayEvent, User, UserInfo, WireguardNetwork, @@ -9,12 +14,8 @@ use crate::{ use std::collections::HashMap; use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender}; use tonic::Status; -use uuid::Uuid; -use super::proto::{ - ClientMfaFinishRequest, ClientMfaFinishResponse, ClientMfaStartRequest, ClientMfaStartResponse, - MfaMethod, -}; +const SESSION_TIMEOUT: u64 = 60 * 5; // 10 minutes struct ClientLoginSession { location: WireguardNetwork, @@ -43,6 +44,28 @@ impl ClientMfaServer { sessions: HashMap::new(), } } + fn generate_token(&self, pubkey: &str) -> Result { + Claims::new( + ClaimsType::DesktopClient, + String::new(), + pubkey.into(), + SESSION_TIMEOUT, + ) + .to_jwt() + .map_err(|err| { + error!("Failed to generate JWT token: {err:?}"); + Status::internal("unexpected error") + }) + } + + /// Validate JWT and extract client pubkey + fn parse_token(&self, token: &str) -> Result { + let claims = Claims::from_jwt(ClaimsType::DesktopClient, token).map_err(|err| { + error!("Failed to parse JWT token: {err:?}"); + Status::invalid_argument("invalid token") + })?; + Ok(claims.client_id) + } pub async fn start_client_mfa_login( &mut self, @@ -132,18 +155,18 @@ impl ClientMfaServer { } // generate auth token - let token = Uuid::new_v4().to_string(); + let token = self.generate_token(&request.pubkey)?; // store login session self.sessions.insert( - token.clone(), + request.pubkey, ClientLoginSession { location, device, user, }, ); - + Ok(ClientMfaStartResponse { token }) } @@ -152,8 +175,11 @@ impl ClientMfaServer { request: ClientMfaFinishRequest, ) -> Result { info!("Finishing desktop client login: {request:?}"); + // get pubkey from token + let pubkey = self.parse_token(&request.token)?; + // fetch login session - let Some(session) = self.sessions.remove(&request.token) else { + let Some(session) = self.sessions.remove(&pubkey) else { error!("Client login session not found"); return Err(Status::invalid_argument("login session not found")); }; From 4948cbef5ef320166147ee91ac04fc771bbdadef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 17 Jan 2024 11:42:25 +0100 Subject: [PATCH 10/11] validate TOTP code --- src/grpc/desktop_client_mfa.rs | 39 ++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/src/grpc/desktop_client_mfa.rs b/src/grpc/desktop_client_mfa.rs index 60a1a3b14..93556e41b 100644 --- a/src/grpc/desktop_client_mfa.rs +++ b/src/grpc/desktop_client_mfa.rs @@ -18,6 +18,7 @@ use tonic::Status; const SESSION_TIMEOUT: u64 = 60 * 5; // 10 minutes struct ClientLoginSession { + method: MfaMethod, location: WireguardNetwork, device: Device, user: User, @@ -123,8 +124,12 @@ impl ClientMfaServer { } // check if selected method is enabled - match MfaMethod::try_from(request.method) { - Ok(MfaMethod::Totp) => { + let method = MfaMethod::try_from(request.method).map_err(|err| { + error!("Invalid MFA method selected: {err}"); + Status::invalid_argument("invalid MFA method selected") + })?; + match method { + MfaMethod::Totp => { if !user.totp_enabled { error!("TOTP not enabled for user {}", user.username); return Err(Status::invalid_argument( @@ -132,7 +137,7 @@ impl ClientMfaServer { )); } } - Ok(MfaMethod::Email) => { + MfaMethod::Email => { if !user.email_mfa_enabled { error!("Email MFA not enabled for user {}", user.username); return Err(Status::invalid_argument( @@ -148,11 +153,7 @@ impl ClientMfaServer { Status::internal("unexpected error") })?; } - Err(err) => { - error!("Invalid MFA method selected: {err}"); - return Err(Status::invalid_argument("invalid MFA method selected")); - } - } + }; // generate auth token let token = self.generate_token(&request.pubkey)?; @@ -161,12 +162,13 @@ impl ClientMfaServer { self.sessions.insert( request.pubkey, ClientLoginSession { + method, location, device, user, }, ); - + Ok(ClientMfaStartResponse { token }) } @@ -184,15 +186,26 @@ impl ClientMfaServer { return Err(Status::invalid_argument("login session not found")); }; let ClientLoginSession { + method, device, location, user, } = session; - // validate email code - if !user.verify_email_mfa_code(request.code) { - error!("Provided email code is not valid"); - return Err(Status::unauthenticated("unauthorized")); + // validate code + match method { + MfaMethod::Totp => { + if !user.verify_totp_code(request.code) { + error!("Provided TOTP code is not valid"); + return Err(Status::unauthenticated("unauthorized")); + } + } + MfaMethod::Email => { + if !user.verify_email_mfa_code(request.code) { + error!("Provided email code is not valid"); + return Err(Status::unauthenticated("unauthorized")); + } + } }; // begin transaction From 6b7382c1fd3793d87447d5311521fc64b8a778f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 17 Jan 2024 11:51:02 +0100 Subject: [PATCH 11/11] update protos --- proto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/proto b/proto index 3294ebd57..a5776f47e 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 3294ebd5748419ca604afbd6869f305ef7879e1c +Subproject commit a5776f47e0a9ffb7e408401662f3e5a4ced205d9