From 0d5805817a441866497164c26473ac4874516d10 Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Tue, 12 Nov 2024 13:43:20 +0100 Subject: [PATCH 1/8] Fix ipv6 input and database adminid constraints (#851) * fix building docs * add libssl-dev dependency * Update README.md * add pkg-config dependency * add curl dependency * Update README.md * Update README.md * Update README.md * add on delete cascade to adminid in the token table * allow ipv6 input --------- Co-authored-by: Robert Olejnik --- .github/workflows/docs.yml | 2 +- README.md | 37 ++++++++------- .../20241108110157_add_on_delete.down.sql | 2 + .../20241108110157_add_on_delete.up.sql | 2 + web/package.json | 1 + web/pnpm-lock.yaml | 9 ++++ web/src/i18n/en/index.ts | 1 + web/src/i18n/i18n-types.ts | 8 ++++ web/src/i18n/pl/index.ts | 3 +- .../NetworkEditForm/NetworkEditForm.tsx | 47 ++++++++++++++----- .../SmtpSettingsForm/SmtpSettingsForm.tsx | 2 +- .../WizardNetworkConfiguration.tsx | 4 +- web/src/shared/patterns.ts | 3 -- web/src/shared/validators.ts | 47 +++++++++++++------ 14 files changed, 116 insertions(+), 52 deletions(-) create mode 100644 migrations/20241108110157_add_on_delete.down.sql create mode 100644 migrations/20241108110157_add_on_delete.up.sql diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 74f8db866..9f8796cb9 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -19,7 +19,7 @@ jobs: container: rust:1-slim steps: - name: Install packages - run: apt-get update && apt install -y git protobuf-compiler libssl-dev + run: apt-get update && apt install -y git protobuf-compiler libssl-dev pkg-config curl - name: Checkout uses: actions/checkout@v4 diff --git a/README.md b/README.md index 61220a89f..81354e337 100644 --- a/README.md +++ b/README.md @@ -6,18 +6,13 @@ GitHub commits since latest release

-[Website](https://defguard.net) | [Getting Started](https://docs.defguard.net/#what-is-defguard) | [Features](https://github.com/defguard/defguard#features) | [Roadmap](https://github.com/orgs/defguard/projects/5) | [Support ❤](https://github.com/defguard/defguard#support-) +[Website](https://defguard.net) | [Getting Started](https://docs.defguard.net/#what-is-defguard) | [Features](https://github.com/defguard/defguard#features) | [Roadmap](https://github.com/orgs/defguard/projects/5) | [Support ❤](https://github.com/defguard/defguard#support) -## Enterprise features are here! - -🛑 We encourge to test the [pre-release](https://docs.defguard.net/admin-and-features/setting-up-your-instance/pre-production-and-development-releases) of the new **Open Source Open Core** & **Enterprise features** (like external OpenID (Google/Microsoft/Custom), real time client sync and more!) published! 🛑 - -All currently available enterprise features are in [enterprise documentation section](https://docs.defguard.net/enterprise/all-enteprise-features) as well as information about [enterprise license](https://docs.defguard.net/enterprise/license). -### Unique value proposition +### Comprehensive Access Control -- **Comprehensive [WireGuard® 2FA/MFA](https://docs.defguard.net/admin-and-features/wireguard/multi-factor-authentication-mfa-2fa/architecture)** - not 2FA to "access application" like most solutions +- **[WireGuard® VPN with 2FA/MFA](https://docs.defguard.net/admin-and-features/wireguard/multi-factor-authentication-mfa-2fa/architecture)** - not 2FA to "access application" like most solutions - The only solution with [automatic and real-time synchronization](https://docs.defguard.net/enterprise/automatic-real-time-desktop-client-configuration) for users' desktop client settings (including all VPNs/locations). - Control users [ability to manage devices and VPN options](https://docs.defguard.net/enterprise/behavior-customization) - [Integrated SSO based on OpenID Connect](https://docs.defguard.net/admin-and-features/openid-connect): @@ -31,7 +26,9 @@ All currently available enterprise features are in [enterprise documentation sec - Built on WireGuard® protocol which is faster than IPSec, and significantly faster than OpenVPN - Built with Rust for speed and security -See below [full list of features](https://github.com/defguard/defguard#features) +See: +- [full list of features](https://github.com/defguard/defguard#features) +- [enterprise only features](https://docs.defguard.net/enterprise/all-enteprise-features) #### Video introduction @@ -65,6 +62,8 @@ Better quality video can [be viewed here](https://github.com/DefGuard/docs/raw/d [Desktop client](https://github.com/DefGuard/client): - **2FA / Multi-Factor Authentication** with TOTP or email based tokens & WireGuard PSK +- [automatic and real-time synchronization](https://docs.defguard.net/enterprise/automatic-real-time-desktop-client-configuration) for users' desktop client settings (including all VPNs/locations). +- Control users [ability to manage devices and VPN options](https://docs.defguard.net/enterprise/behavior-customization) - Defguard instances as well as **any WireGuard tunnel** - just import your tunnels - one client for all WireGuard connections - Secure and remote user enrollment - setting up password, automatically configuring the client for all VPN Locations/Networks - Onboarding - displaying custom onboarding messages, with templates, links ... @@ -117,7 +116,7 @@ The story and motivation behind defguard [can be found here: https://teonite.com ## Features -* [WireGuard®](https://www.wireguard.com/) VPN server with: +* Remote Access: [WireGuard® VPN](https://www.wireguard.com/) server with: - [Multi-Factor Authentication](https://docs.defguard.net/help/desktop-client/multi-factor-authentication-mfa-2fa) with TOTP/Email & Pre-Shared Session Keys - multiple VPN Locations (networks/sites) - with defined access (all users or only Admin group) - multiple [Gateways](https://github.com/DefGuard/gateway) for each VPN Location (**high availability/failover**) - supported on a cluster of routers/firewalls for Linux, FreeBSD/PFSense/OPNSense @@ -129,18 +128,20 @@ The story and motivation behind defguard [can be found here: https://teonite.com - kernel (Linux, FreeBSD/OPNSense/PFSense) & userspace WireGuard® support with [our Rust library](https://github.com/defguard/wireguard-rs) - dashboard and statistics overview of connected users/devices for admins - *defguard is not an official WireGuard® project, and WireGuard is a registered trademark of Jason A. Donenfeld.* -* Integrated SSO: [OpenID Connect provider](https://openid.net/developers/how-connect-works/) - with **unique features**: - - Secure remote (over the internet) [user enrollment](https://docs.defguard.net/help/remote-user-enrollment) - - User [onboarding after enrollment](https://docs.defguard.net/help/remote-user-enrollment/user-onboarding-after-enrollment) - - LDAP (tested on [OpenLDAP](https://www.openldap.org/)) synchronization - - [forward auth](https://docs.defguard.net/features/forward-auth) for reverse proxies (tested with Traefik and Caddy) - - nice UI to manage users - - Users **self-service** (besides typical data management, users can revoke access to granted apps, MFA, WireGuard®, etc.) +* Identity & Account Management: + - SSO based on OpenID Connect](https://openid.net/developers/how-connect-works/) + - Extenal SSO: [external OpenID provider support](https://docs.defguard.net/enterprise/external-openid-providers) - [Multi-Factor/2FA](https://en.wikipedia.org/wiki/Multi-factor_authentication) Authentication: - [Time-based One-Time Password Algorithm](https://en.wikipedia.org/wiki/Time-based_one-time_password) (TOTP - e.g. Google Authenticator) - WebAuthn / FIDO2 - for hardware key authentication support (eg. YubiKey, FaceID, TouchID, ...) - Email based TOTP -* Extenal SSO: [external OpenID provider support](https://docs.defguard.net/enterprise/external-openid-providers) + - LDAP (tested on [OpenLDAP](https://www.openldap.org/)) synchronization + - [forward auth](https://docs.defguard.net/features/forward-auth) for reverse proxies (tested with Traefik and Caddy) + - nice UI to manage users + - Users **self-service** (besides typical data management, users can revoke access to granted apps, MFA, WireGuard®, etc.) +* Account Lifecycle Management: + - Secure remote (over the Internet) [user enrollment](https://docs.defguard.net/help/remote-user-enrollment) - on public web / Desktop Client + - User [onboarding after enrollment](https://docs.defguard.net/help/remote-user-enrollment/user-onboarding-after-enrollment) * SSH & GPG public key management in user profile - with [SSH keys authentication for servers](https://docs.defguard.net/admin-and-features/ssh-authentication) * [Yubikey hardware keys](https://www.yubico.com/) provisioning for users by *one click* * [Email/SMTP support](https://docs.defguard.net/help/setting-up-smtp-for-email-notifications) for notifications, remote enrollment and onboarding diff --git a/migrations/20241108110157_add_on_delete.down.sql b/migrations/20241108110157_add_on_delete.down.sql new file mode 100644 index 000000000..f7680821b --- /dev/null +++ b/migrations/20241108110157_add_on_delete.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE token DROP CONSTRAINT enrollment_admin_id_fkey; +ALTER TABLE token ADD CONSTRAINT enrollment_admin_id_fkey FOREIGN KEY(admin_id) REFERENCES "user"(id); diff --git a/migrations/20241108110157_add_on_delete.up.sql b/migrations/20241108110157_add_on_delete.up.sql new file mode 100644 index 000000000..3ade0c10c --- /dev/null +++ b/migrations/20241108110157_add_on_delete.up.sql @@ -0,0 +1,2 @@ +ALTER TABLE token DROP CONSTRAINT enrollment_admin_id_fkey; +ALTER TABLE token ADD CONSTRAINT enrollment_admin_id_fkey FOREIGN KEY(admin_id) REFERENCES "user"(id) ON DELETE CASCADE; diff --git a/web/package.json b/web/package.json index c7ed2738c..544013ef8 100644 --- a/web/package.json +++ b/web/package.json @@ -71,6 +71,7 @@ "get-text-width": "^1.0.3", "hex-rgb": "^5.0.0", "html-react-parser": "^5.1.1", + "ipaddr.js": "^2.2.0", "itertools": "^2.2.3", "lodash-es": "^4.17.21", "numbro": "^2.4.0", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index ba7fe4b1a..468706420 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -101,6 +101,9 @@ importers: html-react-parser: specifier: ^5.1.1 version: 5.1.1(react@18.2.0) + ipaddr.js: + specifier: ^2.2.0 + version: 2.2.0 itertools: specifier: ^2.2.3 version: 2.2.3 @@ -3268,6 +3271,10 @@ packages: resolution: {integrity: sha512-5Hh7Y1wQbvY5ooGgPbDaL5iYLAPzMTUrjMulskHLH6wnv/A+1q5rgEaiuqEjB+oxGXIVZs1FF+R/KPN3ZSQYYg==} engines: {node: '>=12'} + ipaddr.js@2.2.0: + resolution: {integrity: sha512-Ag3wB2o37wslZS19hZqorUnrnzSkpOVy+IiiDEiTqNubEYpYuHWIf6K4psgN2ZWKExS4xhVCrRVfb/wfW8fWJA==} + engines: {node: '>= 10'} + is-alphabetical@2.0.1: resolution: {integrity: sha512-FWyyY60MeTNyeSRpkM2Iry0G9hpr7/9kD40mD/cGQEuilcZYS4okz8SN2Q6rLCJ8gbCt6fN+rC+6tMGS99LaxQ==} @@ -8814,6 +8821,8 @@ snapshots: internmap@2.0.3: {} + ipaddr.js@2.2.0: {} + is-alphabetical@2.0.1: {} is-alphanumerical@2.0.1: diff --git a/web/src/i18n/en/index.ts b/web/src/i18n/en/index.ts index a94620912..ed010dc60 100644 --- a/web/src/i18n/en/index.ts +++ b/web/src/i18n/en/index.ts @@ -856,6 +856,7 @@ const en: BaseTranslation = { portMax: 'Maximum port is 65535.', endpoint: 'Enter a valid endpoint.', address: 'Enter a valid address.', + addressNetmask: 'Enter a valid address with a netmask.', validPort: 'Enter a valid port.', validCode: 'Code should have 6 digits.', allowedIps: 'Only valid IP or domain is allowed.', diff --git a/web/src/i18n/i18n-types.ts b/web/src/i18n/i18n-types.ts index 524f84697..c70aa7afe 100644 --- a/web/src/i18n/i18n-types.ts +++ b/web/src/i18n/i18n-types.ts @@ -2101,6 +2101,10 @@ type RootTranslation = { * E​n​t​e​r​ ​a​ ​v​a​l​i​d​ ​a​d​d​r​e​s​s​. */ address: string + /** + * E​n​t​e​r​ ​a​ ​v​a​l​i​d​ ​a​d​d​r​e​s​s​ ​w​i​t​h​ ​a​ ​n​e​t​m​a​s​k​. + */ + addressNetmask: string /** * E​n​t​e​r​ ​a​ ​v​a​l​i​d​ ​p​o​r​t​. */ @@ -6356,6 +6360,10 @@ export type TranslationFunctions = { * Enter a valid address. */ address: () => LocalizedString + /** + * Enter a valid address with a netmask. + */ + addressNetmask: () => LocalizedString /** * Enter a valid port. */ diff --git a/web/src/i18n/pl/index.ts b/web/src/i18n/pl/index.ts index 17045129f..b367c0c91 100644 --- a/web/src/i18n/pl/index.ts +++ b/web/src/i18n/pl/index.ts @@ -841,8 +841,9 @@ Uwaga, podane tutaj konfiguracje nie posiadają klucza prywatnego. Musisz uzupe oneUppercase: 'Wymagana jedna duża litera.', oneLowercase: 'Wymagana jedna mała litera.', portMax: 'Maksymalny numer portu to 65535.', - endpoint: 'Wpisz prawidłowy punkt końcowy.', + endpoint: 'Wpisz poprawny adres.', address: 'Wprowadź poprawny adres.', + addressNetmask: 'Wprowadź poprawny adres IP oraz maskę sieci.', validPort: 'Wprowadź prawidłowy port.', validCode: 'Kod powinien mieć 6 cyfr.', allowedIps: 'Tylko poprawne adresy IP oraz domeny.', diff --git a/web/src/pages/network/NetworkEditForm/NetworkEditForm.tsx b/web/src/pages/network/NetworkEditForm/NetworkEditForm.tsx index 32875bcc9..33c358940 100644 --- a/web/src/pages/network/NetworkEditForm/NetworkEditForm.tsx +++ b/web/src/pages/network/NetworkEditForm/NetworkEditForm.tsx @@ -2,6 +2,7 @@ import './style.scss'; import { zodResolver } from '@hookform/resolvers/zod'; import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; +import ipaddr from 'ipaddr.js'; import { isNull, omit, omitBy } from 'lodash-es'; import { useEffect, useMemo, useRef, useState } from 'react'; import { SubmitHandler, useForm } from 'react-hook-form'; @@ -20,11 +21,7 @@ import { QueryKeys } from '../../../shared/queries'; import { Network } from '../../../shared/types'; import { titleCase } from '../../../shared/utils/titleCase'; import { trimObjectStrings } from '../../../shared/utils/trimObjectStrings.ts'; -import { - validateIp, - validateIpOrDomain, - validateIpOrDomainList, -} from '../../../shared/validators'; +import { validateIpOrDomain, validateIpOrDomainList } from '../../../shared/validators'; import { useNetworkPageStore } from '../hooks/useNetworkPageStore'; type FormFields = { @@ -155,17 +152,43 @@ export const NetworkEditForm = () => { if (!netmaskPresent) { return false; } - const ipValid = validateIp(value, true); - if (ipValid) { - const host = value.split('.')[3].split('/')[0]; - if (host === '0') return false; + const ipValid = ipaddr.isValidCIDR(value); + if (!ipValid) { + return false; + } + const [address] = ipaddr.parseCIDR(value); + if (address.kind() === 'ipv6') { + const networkAddress = ipaddr.IPv6.networkAddressFromCIDR(value); + const broadcastAddress = ipaddr.IPv6.broadcastAddressFromCIDR(value); + if ( + (address as ipaddr.IPv6).toNormalizedString() === + networkAddress.toNormalizedString() || + (address as ipaddr.IPv6).toNormalizedString() === + broadcastAddress.toNormalizedString() + ) { + return false; + } + } else { + const networkAddress = ipaddr.IPv4.networkAddressFromCIDR(value); + const broadcastAddress = ipaddr.IPv4.broadcastAddressFromCIDR(value); + if ( + (address as ipaddr.IPv4).toNormalizedString() === + networkAddress.toNormalizedString() || + (address as ipaddr.IPv4).toNormalizedString() === + broadcastAddress.toNormalizedString() + ) { + return false; + } } return ipValid; - }, LL.form.error.address()), + }, LL.form.error.addressNetmask()), endpoint: z .string() .min(1, LL.form.error.required()) - .refine((val) => validateIpOrDomain(val), LL.form.error.endpoint()), + .refine( + (val) => validateIpOrDomain(val, false, true), + LL.form.error.endpoint(), + ), port: z .number({ invalid_type_error: LL.form.error.required(), @@ -179,7 +202,7 @@ export const NetworkEditForm = () => { if (val === '' || !val) { return true; } - return validateIpOrDomainList(val, ',', true); + return validateIpOrDomainList(val, ',', false, true); }, LL.form.error.allowedIps()), allowed_groups: z.array(z.string().min(1, LL.form.error.minimumLength())), mfa_enabled: z.boolean(), diff --git a/web/src/pages/settings/components/SmtpSettings/components/SmtpSettingsForm/SmtpSettingsForm.tsx b/web/src/pages/settings/components/SmtpSettings/components/SmtpSettingsForm/SmtpSettingsForm.tsx index 13bf87296..ba22409d1 100644 --- a/web/src/pages/settings/components/SmtpSettings/components/SmtpSettingsForm/SmtpSettingsForm.tsx +++ b/web/src/pages/settings/components/SmtpSettings/components/SmtpSettingsForm/SmtpSettingsForm.tsx @@ -104,7 +104,7 @@ export const SmtpSettingsForm = () => { .string() .min(1, LL.form.error.required()) .refine( - (val) => (!val ? true : validateIpOrDomain(val)), + (val) => (!val ? true : validateIpOrDomain(val, false, true)), LL.form.error.endpoint(), ), smtp_port: z diff --git a/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx b/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx index fbd12098e..cce80f6d5 100644 --- a/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx +++ b/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx @@ -20,7 +20,7 @@ import { QueryKeys } from '../../../../shared/queries'; import { ModifyNetworkRequest } from '../../../../shared/types'; import { titleCase } from '../../../../shared/utils/titleCase'; import { trimObjectStrings } from '../../../../shared/utils/trimObjectStrings.ts'; -import { validateIp, validateIpOrDomainList } from '../../../../shared/validators'; +import { validateIpOrDomainList, validateIPv4 } from '../../../../shared/validators'; import { useWizardStore } from '../../hooks/useWizardStore'; type FormInputs = ModifyNetworkRequest['network']; @@ -91,7 +91,7 @@ export const WizardNetworkConfiguration = () => { if (!netmaskPresent) { return false; } - const ipValid = validateIp(value, true); + const ipValid = validateIPv4(value, true); if (ipValid) { const host = value.split('.')[3].split('/')[0]; if (host === '0') return false; diff --git a/web/src/shared/patterns.ts b/web/src/shared/patterns.ts index 1cb06ea8c..6c17b23c9 100644 --- a/web/src/shared/patterns.ts +++ b/web/src/shared/patterns.ts @@ -68,9 +68,6 @@ export const patternValidUrl = new RegExp( export const patternValidDomain = /^(?:(?:(?:[a-zA-z\-]+)\:\/{1,3})?(?:[a-zA-Z0-9])(?:[a-zA-Z0-9\-\.]){1,61}(?:\.[a-zA-Z]{2,})+|\[(?:(?:(?:[a-fA-F0-9]){1,4})(?::(?:[a-fA-F0-9]){1,4}){7}|::1|::)\]|(?:(?:[0-9]{1,3})(?:\.[0-9]{1,3}){3}))(?:\:[0-9]{1,5})?$/; -export const patternValidIp = - /^(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$/; - export const patternSafeUsernameCharacters = /^[a-zA-Z0-9]+[a-zA-Z0-9.\-_]+$/; export const patternSafePasswordCharacters = diff --git a/web/src/shared/validators.ts b/web/src/shared/validators.ts index 92e538863..c9aa3c3a4 100644 --- a/web/src/shared/validators.ts +++ b/web/src/shared/validators.ts @@ -1,8 +1,18 @@ -import { patternValidDomain, patternValidIp } from './patterns'; +import ipaddr from 'ipaddr.js'; + +import { patternValidDomain } from './patterns'; // Returns flase when invalid -export const validateIpOrDomain = (val: string, allowMask = false): boolean => { - return validateIp(val, allowMask) || patternValidDomain.test(val); +export const validateIpOrDomain = ( + val: string, + allowMask = false, + allowIPv6 = false, +): boolean => { + return ( + (allowIPv6 && validateIPv6(val, allowMask)) || + validateIPv4(val, allowMask) || + patternValidDomain.test(val) + ); }; // Returns flase when invalid @@ -14,7 +24,7 @@ export const validateIpList = ( const trimed = val.replace(' ', ''); const split = trimed.split(splitWith); for (const value of split) { - if (!validateIp(value, allowMasks)) { + if (!validateIPv4(value, allowMasks)) { return false; } } @@ -26,11 +36,17 @@ export const validateIpOrDomainList = ( val: string, splitWith = ',', allowMasks = false, + allowIPv6 = false, ): boolean => { const trimed = val.replace(' ', ''); const split = trimed.split(splitWith); for (const value of split) { - if (!validateIp(value, allowMasks) && !patternValidDomain.test(value)) { + console.log(allowIPv6 && !validateIPv6(value, allowMasks)); + if ( + !validateIPv4(value, allowMasks) && + !patternValidDomain.test(value) && + (!allowIPv6 || !validateIPv6(value, allowMasks)) + ) { return false; } } @@ -38,19 +54,22 @@ export const validateIpOrDomainList = ( }; // Returns flase when invalid -export const validateIp = (ip: string, allowMask = false): boolean => { +export const validateIPv4 = (ip: string, allowMask = false): boolean => { + if (allowMask) { + if (ip.includes('/')) { + ipaddr.IPv4.isValidCIDR(ip); + } + } + return ipaddr.IPv4.isValid(ip); +}; + +export const validateIPv6 = (ip: string, allowMask = false): boolean => { if (allowMask) { if (ip.includes('/')) { - const split = ip.split('/'); - if (split.length !== 2) return true; - const ipValid = patternValidIp.test(split[0]); - if (split[1] === '') return false; - const mask = Number(split[1]); - const maskValid = mask >= 0 && mask <= 32; - return ipValid && maskValid; + ipaddr.IPv6.isValidCIDR(ip); } } - return patternValidIp.test(ip); + return ipaddr.IPv6.isValid(ip); }; export const validatePort = (val: string) => { From 47836384f76fe83a2f28d1d7c23a5b2dc9ac115a Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Wed, 13 Nov 2024 14:33:13 +0100 Subject: [PATCH 2/8] Enable enterprise features when within certain limits (#852) * don't require license if not exceeding the limits * update counts, make basic frontend * code cleanup * cleanup part 2 * sqlx prepare * add limits info * add scopes to tests * polish translation * fix tests * fix frontend tests * fix frontend tests 2 --- ...c938dd921c3288849fdd46fc760ce3dd21882.json | 32 ++++ src/bin/defguard.rs | 7 +- .../db/models/enterprise_settings.rs | 8 +- src/enterprise/grpc/polling.rs | 6 +- src/enterprise/handlers/mod.rs | 27 ++-- src/enterprise/limits.rs | 142 ++++++++++++++++++ src/enterprise/mod.rs | 24 +++ src/grpc/mod.rs | 7 +- src/handlers/app_info.rs | 5 +- src/handlers/user.rs | 3 + src/handlers/wireguard.rs | 10 +- tests/common/mod.rs | 40 +++++ tests/enterprise_settings.rs | 7 + tests/openid_login.rs | 3 + web/src/i18n/en/index.ts | 2 + web/src/i18n/i18n-types.ts | 8 + web/src/i18n/pl/index.ts | 2 + .../EnterpriseSettings/EnterpriseSettings.tsx | 10 ++ .../LicenseSettings/LicenseSettings.tsx | 4 +- .../components/LicenseSettings/styles.scss | 4 + .../OpenIdSettings/OpenIdSettings.tsx | 10 ++ web/src/pages/settings/style.scss | 6 +- .../UsersList/components/UsersListGroups.tsx | 2 +- web/src/shared/defguard-ui | 2 +- web/src/shared/types.ts | 1 + 25 files changed, 339 insertions(+), 33 deletions(-) create mode 100644 .sqlx/query-d5b2165ab0cd9e32296dcfb4e4bc938dd921c3288849fdd46fc760ce3dd21882.json create mode 100644 src/enterprise/limits.rs diff --git a/.sqlx/query-d5b2165ab0cd9e32296dcfb4e4bc938dd921c3288849fdd46fc760ce3dd21882.json b/.sqlx/query-d5b2165ab0cd9e32296dcfb4e4bc938dd921c3288849fdd46fc760ce3dd21882.json new file mode 100644 index 000000000..7212f259f --- /dev/null +++ b/.sqlx/query-d5b2165ab0cd9e32296dcfb4e4bc938dd921c3288849fdd46fc760ce3dd21882.json @@ -0,0 +1,32 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT (SELECT count(*) FROM \"user\") \"user!\", (SELECT count(*) FROM device) \"device!\", (SELECT count(*) FROM wireguard_network) \"wireguard_network!\"\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "user!", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "device!", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "wireguard_network!", + "type_info": "Int8" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null, + null, + null + ] + }, + "hash": "d5b2165ab0cd9e32296dcfb4e4bc938dd921c3288849fdd46fc760ce3dd21882" +} diff --git a/src/bin/defguard.rs b/src/bin/defguard.rs index a25f50a83..ad7565d74 100644 --- a/src/bin/defguard.rs +++ b/src/bin/defguard.rs @@ -7,7 +7,10 @@ use defguard::{ auth::failed_login::FailedLoginMap, config::{Command, DefGuardConfig}, db::{init_db, AppEvent, GatewayEvent, Settings, User}, - enterprise::license::{run_periodic_license_check, set_cached_license, License}, + enterprise::{ + license::{run_periodic_license_check, set_cached_license, License}, + limits::update_counts, + }, grpc::{run_grpc_bidi_stream, run_grpc_server, GatewayMap, WorkerState}, headers::create_user_agent_parser, init_dev_env, init_vpn_location, @@ -101,6 +104,8 @@ async fn main() -> Result<(), anyhow::Error> { let failed_logins = FailedLoginMap::new(); let failed_logins = Arc::new(Mutex::new(failed_logins)); + update_counts(&pool).await?; + debug!("Checking enterprise license status"); match License::load_or_renew(&pool).await { Ok(license) => { diff --git a/src/enterprise/db/models/enterprise_settings.rs b/src/enterprise/db/models/enterprise_settings.rs index 1d14b761c..81a53a20f 100644 --- a/src/enterprise/db/models/enterprise_settings.rs +++ b/src/enterprise/db/models/enterprise_settings.rs @@ -1,7 +1,7 @@ use sqlx::{query, query_as, PgExecutor}; use struct_patch::Patch; -use crate::enterprise::license::{get_cached_license, validate_license}; +use crate::enterprise::is_enterprise_enabled; #[derive(Debug, Deserialize, Patch, Serialize)] #[patch(attribute(derive(Deserialize, Serialize)))] @@ -34,11 +34,7 @@ impl EnterpriseSettings { { // avoid holding the rwlock across await, makes the future !Send // and therefore unusable in axum handlers - let is_valid = { - let license = get_cached_license(); - validate_license(license.as_ref()).is_ok() - }; - if is_valid { + if is_enterprise_enabled() { let settings = query_as!( Self, "SELECT admin_device_management, \ diff --git a/src/enterprise/grpc/polling.rs b/src/enterprise/grpc/polling.rs index 1bf322566..3ea1c1e33 100644 --- a/src/enterprise/grpc/polling.rs +++ b/src/enterprise/grpc/polling.rs @@ -3,7 +3,7 @@ use tonic::Status; use crate::{ db::{models::polling_token::PollingToken, Device, Id, User}, - enterprise::license::{get_cached_license, validate_license}, + enterprise::is_enterprise_enabled, grpc::{ proto::{InstanceInfoRequest, InstanceInfoResponse}, utils::build_device_config_response, @@ -25,8 +25,8 @@ impl PollingServer { debug!("Validating polling token. Token: {token}"); // Polling service is enterprise-only, check the lincense - if validate_license(get_cached_license().as_ref()).is_err() { - debug!("No valid license, denying instance polling info"); + if !is_enterprise_enabled() { + debug!("Instance has enterprise features disabled, denying instance polling info"); return Err(Status::failed_precondition("no valid license")); } diff --git a/src/enterprise/handlers/mod.rs b/src/enterprise/handlers/mod.rs index fb5bba483..651a175bd 100644 --- a/src/enterprise/handlers/mod.rs +++ b/src/enterprise/handlers/mod.rs @@ -1,6 +1,5 @@ use crate::{ auth::SessionInfo, - enterprise::license::validate_license, handlers::{ApiResponse, ApiResult}, }; @@ -14,7 +13,10 @@ use axum::{ http::{request::Parts, StatusCode}, }; -use super::{db::models::enterprise_settings::EnterpriseSettings, license::get_cached_license}; +use super::{ + db::models::enterprise_settings::EnterpriseSettings, is_enterprise_enabled, + license::get_cached_license, needs_enterprise_license, +}; use crate::{appstate::AppState, error::WebError}; pub struct LicenseInfo { @@ -33,19 +35,20 @@ where type Rejection = WebError; async fn from_request_parts(_parts: &mut Parts, _state: &S) -> Result { - let license = get_cached_license(); - - match validate_license(license.as_ref()) { - // Useless struct, but may come in handy later - Ok(()) => Ok(LicenseInfo { valid: true }), - Err(e) => Err(WebError::Forbidden(e.to_string())), + if is_enterprise_enabled() { + Ok(LicenseInfo { valid: true }) + } else { + Err(WebError::Forbidden( + "Enterprise features are disabled".into(), + )) } } } pub async fn check_enterprise_status() -> ApiResult { + let enterprise_enabled = is_enterprise_enabled(); + let needs_license = needs_enterprise_license(); let license = get_cached_license(); - let valid = validate_license((license).as_ref()).is_ok(); let license_info = license.as_ref().map(|license| { serde_json::json!( { @@ -55,8 +58,10 @@ pub async fn check_enterprise_status() -> ApiResult { ) }); Ok(ApiResponse { - json: serde_json::json!({ "enabled": valid, - "license_info": license_info + json: serde_json::json!({ + "enabled": enterprise_enabled, + "needs_license": needs_license, + "license_info": license_info }), status: StatusCode::OK, }) diff --git a/src/enterprise/limits.rs b/src/enterprise/limits.rs new file mode 100644 index 000000000..a9ebfe724 --- /dev/null +++ b/src/enterprise/limits.rs @@ -0,0 +1,142 @@ +use sqlx::{error::Error as SqlxError, query_as, PgPool}; +use std::sync::{RwLock, RwLockReadGuard}; + +#[derive(Debug)] +pub(crate) struct Counts { + user: i64, + device: i64, + wireguard_network: i64, +} + +static COUNTS: RwLock = RwLock::new(Counts { + user: 0, + device: 0, + wireguard_network: 0, +}); + +fn set_counts(new_counts: Counts) { + *COUNTS + .write() + .expect("Failed to acquire lock on the enterprise limit counts.") = new_counts; +} + +pub(crate) fn get_counts() -> RwLockReadGuard<'static, Counts> { + COUNTS + .read() + .expect("Failed to acquire lock on the enterprise limit counts.") +} + +/// Update the counts of users, devices, and wireguard networks stored in the memory. +// TODO: Use it with database triggers when they are implemented +pub async fn update_counts(pool: &PgPool) -> Result<(), SqlxError> { + debug!("Updating device, user, and wireguard network counts."); + let counts = query_as!( + Counts, + "SELECT \ + (SELECT count(*) FROM \"user\") \"user!\", \ + (SELECT count(*) FROM device) \"device!\", \ + (SELECT count(*) FROM wireguard_network) \"wireguard_network!\" + " + ) + .fetch_one(pool) + .await?; + + set_counts(counts); + debug!( + "Updated device, user, and wireguard network counts stored in memory, new counts: {:?}", + get_counts() + ); + + Ok(()) +} + +impl Counts { + pub(crate) fn is_over_limit(&self) -> bool { + self.user > 5 || self.device > 10 || self.wireguard_network > 1 + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_counts() { + let counts = Counts { + user: 1, + device: 2, + wireguard_network: 3, + }; + + set_counts(counts); + + let counts = get_counts(); + + assert_eq!(counts.user, 1); + assert_eq!(counts.device, 2); + assert_eq!(counts.wireguard_network, 3); + } + + #[test] + fn test_is_over_limit() { + // User limit + { + let counts = Counts { + user: 6, + device: 1, + wireguard_network: 1, + }; + set_counts(counts); + let counts = get_counts(); + assert!(counts.is_over_limit()); + } + + // Device limit + { + let counts = Counts { + user: 1, + device: 11, + wireguard_network: 1, + }; + set_counts(counts); + let counts = get_counts(); + assert!(counts.is_over_limit()); + } + + // Wireguard network limit + { + let counts = Counts { + user: 1, + device: 1, + wireguard_network: 2, + }; + set_counts(counts); + let counts = get_counts(); + assert!(counts.is_over_limit()); + } + + // No limit + { + let counts = Counts { + user: 1, + device: 1, + wireguard_network: 1, + }; + set_counts(counts); + let counts = get_counts(); + assert!(!counts.is_over_limit()); + } + + // All limits + { + let counts = Counts { + user: 6, + device: 11, + wireguard_network: 2, + }; + set_counts(counts); + let counts = get_counts(); + assert!(counts.is_over_limit()); + } + } +} diff --git a/src/enterprise/mod.rs b/src/enterprise/mod.rs index dadca5b4c..3c4c5adfa 100644 --- a/src/enterprise/mod.rs +++ b/src/enterprise/mod.rs @@ -2,3 +2,27 @@ pub mod db; pub mod grpc; pub mod handlers; pub mod license; +pub mod limits; +use license::{get_cached_license, validate_license}; +use limits::get_counts; + +pub(crate) fn needs_enterprise_license() -> bool { + get_counts().is_over_limit() +} + +pub(crate) fn is_enterprise_enabled() -> bool { + debug!("Checking if enterprise is enabled"); + match needs_enterprise_license() { + true => { + debug!("User is over limit, checking his license"); + let license = get_cached_license(); + let validation_result = validate_license(license.as_ref()); + debug!("License validation result: {:?}", validation_result); + validation_result.is_ok() + } + false => { + debug!("User is not over limit, allowing enterprise features"); + true + } + } +} diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index ea2a26cbd..802eb398c 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -48,9 +48,8 @@ use crate::{ auth::failed_login::FailedLoginMap, db::{AppEvent, Id, Settings}, enterprise::{ - db::models::enterprise_settings::EnterpriseSettings, - grpc::polling::PollingServer, - license::{get_cached_license, validate_license}, + db::models::enterprise_settings::EnterpriseSettings, grpc::polling::PollingServer, + is_enterprise_enabled, }, handlers::mail::send_gateway_disconnected_email, mail::Mail, @@ -679,7 +678,7 @@ impl InstanceInfo { proxy_url: config.enrollment_url.clone(), username: username.into(), disable_all_traffic: enterprise_settings.disable_all_traffic, - enterprise_enabled: validate_license(get_cached_license().as_ref()).is_ok(), + enterprise_enabled: is_enterprise_enabled(), } } } diff --git a/src/handlers/app_info.rs b/src/handlers/app_info.rs index a8bc4294c..4cc3cb642 100644 --- a/src/handlers/app_info.rs +++ b/src/handlers/app_info.rs @@ -6,7 +6,7 @@ use crate::{ appstate::AppState, auth::SessionInfo, db::{Settings, WireguardNetwork}, - enterprise::license::{get_cached_license, validate_license}, + enterprise::is_enterprise_enabled, }; /// Additional information about core state. @@ -25,8 +25,7 @@ pub(crate) async fn get_app_info( ) -> ApiResult { let networks = WireguardNetwork::all(&appstate.pool).await?; let settings = Settings::get_settings(&appstate.pool).await?; - let license = get_cached_license(); - let enterprise = validate_license((license).as_ref()).is_ok(); + let enterprise = is_enterprise_enabled(); let res = AppInfo { network_present: !networks.is_empty(), smtp_enabled: settings.smtp_configured(), diff --git a/src/handlers/user.rs b/src/handlers/user.rs index 51e88154f..9365d61ea 100644 --- a/src/handlers/user.rs +++ b/src/handlers/user.rs @@ -22,6 +22,7 @@ use crate::{ AppEvent, GatewayEvent, MFAMethod, OAuth2AuthorizedApp, Settings, User, UserDetails, UserInfo, Wallet, WebAuthn, WireguardNetwork, }, + enterprise::limits::update_counts, error::WebError, ldap::utils::{ldap_add_user, ldap_change_password, ldap_delete_user, ldap_modify_user}, mail::Mail, @@ -336,6 +337,7 @@ pub async fn add_user( ) .save(&appstate.pool) .await?; + update_counts(&appstate.pool).await?; if let Some(password) = user_data.password { let _result = ldap_add_user(&appstate.pool, &user, &password).await; @@ -734,6 +736,7 @@ pub async fn delete_user( let _result = ldap_delete_user(&mut *transaction, &username).await; appstate.trigger_action(AppEvent::UserDeleted(username.clone())); transaction.commit().await?; + update_counts(&appstate.pool).await?; info!("User {} deleted user {}", session.user.username, &username); Ok(ApiResponse::default()) diff --git a/src/handlers/wireguard.rs b/src/handlers/wireguard.rs index 149433e2d..13fad84dd 100644 --- a/src/handlers/wireguard.rs +++ b/src/handlers/wireguard.rs @@ -29,7 +29,7 @@ use crate::{ }, AddDevice, Device, GatewayEvent, Id, WireguardNetwork, }, - enterprise::handlers::CanManageDevices, + enterprise::{handlers::CanManageDevices, limits::update_counts}, grpc::GatewayMap, handlers::mail::send_new_device_added_email, server_config, @@ -135,6 +135,7 @@ pub async fn create_network( "User {} created WireGuard network {network_name}", session.user.username ); + update_counts(&appstate.pool).await?; Ok(ApiResponse { json: json!(network), @@ -218,6 +219,7 @@ pub async fn delete_network( "User {} deleted WireGuard network {network_id}", session.user.username, ); + update_counts(&appstate.pool).await?; Ok(ApiResponse::default()) } @@ -374,6 +376,8 @@ pub async fn import_network( info!("Imported network {network} with {} devices", devices.len()); + update_counts(&appstate.pool).await?; + Ok(ApiResponse { json: json!(ImportedNetworkData { network, devices }), status: StatusCode::CREATED, @@ -419,6 +423,7 @@ pub async fn add_user_devices( "User {} mapped {device_count} devices for {network_id} network", user.username, ); + update_counts(&appstate.pool).await?; Ok(ApiResponse { json: json!({}), @@ -592,6 +597,8 @@ pub async fn add_device( let result = AddDeviceResult { configs, device }; + update_counts(&appstate.pool).await?; + Ok(ApiResponse { json: json!(result), status: StatusCode::CREATED, @@ -762,6 +769,7 @@ pub async fn delete_device( )); device.delete(&appstate.pool).await?; info!("User {} deleted device {device_id}", session.user.username); + update_counts(&appstate.pool).await?; Ok(ApiResponse::default()) } diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 394bedce5..c0eef8a4e 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -9,12 +9,14 @@ use defguard::{ db::{init_db, AppEvent, GatewayEvent, Id, User, UserDetails}, enterprise::license::{set_cached_license, License}, grpc::{GatewayMap, WorkerState}, + handlers::Auth, headers::create_user_agent_parser, mail::Mail, SERVER_CONFIG, }; use reqwest::{header::HeaderName, StatusCode}; use secrecy::ExposeSecret; +use serde_json::json; use sqlx::{postgres::PgConnectOptions, query, types::Uuid, PgPool}; use tokio::sync::{ broadcast::{self, Receiver}, @@ -184,3 +186,41 @@ pub async fn fetch_user_details(client: &TestClient, username: &str) -> UserDeta assert_eq!(response.status(), StatusCode::OK); response.json().await } + +pub async fn exceed_enterprise_limits(client: &TestClient) { + let auth = Auth::new("admin", "pass123"); + client.post("/api/v1/auth").json(&auth).send().await; + client + .post("/api/v1/network") + .json(&json!({ + "name": "network1", + "address": "10.1.1.1/24", + "port": 55555, + "endpoint": "192.168.4.14", + "allowed_ips": "10.1.1.0/24", + "dns": "1.1.1.1", + "allowed_groups": [], + "mfa_enabled": false, + "keepalive_interval": 25, + "peer_disconnect_threshold": 180 + })) + .send() + .await; + + client + .post("/api/v1/network") + .json(&json!({ + "name": "network2", + "address": "10.1.1.1/24", + "port": 55555, + "endpoint": "192.168.4.14", + "allowed_ips": "10.1.1.0/24", + "dns": "1.1.1.1", + "allowed_groups": [], + "mfa_enabled": false, + "keepalive_interval": 25, + "peer_disconnect_threshold": 180 + })) + .send() + .await; +} diff --git a/tests/enterprise_settings.rs b/tests/enterprise_settings.rs index 66d63f5c3..78fb11496 100644 --- a/tests/enterprise_settings.rs +++ b/tests/enterprise_settings.rs @@ -1,5 +1,6 @@ mod common; +use common::exceed_enterprise_limits; use defguard::{ enterprise::{ db::models::enterprise_settings::EnterpriseSettings, @@ -35,6 +36,8 @@ async fn test_only_enterprise_can_modify() { let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); + exceed_enterprise_limits(&client).await; + // unset the license let license = get_cached_license().clone(); set_cached_license(None); @@ -75,6 +78,8 @@ async fn test_admin_devices_management_is_enforced() { let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); + exceed_enterprise_limits(&client).await; + // create network let response = client .post("/api/v1/network") @@ -152,6 +157,8 @@ async fn test_regular_user_device_management() { let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); + exceed_enterprise_limits(&client).await; + // create network let response = client .post("/api/v1/network") diff --git a/tests/openid_login.rs b/tests/openid_login.rs index beb62c7e5..7c9d8def1 100644 --- a/tests/openid_login.rs +++ b/tests/openid_login.rs @@ -1,4 +1,5 @@ use chrono::{Duration, Utc}; +use common::exceed_enterprise_limits; use defguard::{ config::DefGuardConfig, enterprise::{ @@ -33,6 +34,8 @@ async fn test_openid_providers() { let response = client.post("/api/v1/auth").json(&auth).send().await; assert_eq!(response.status(), StatusCode::OK); + exceed_enterprise_limits(&client).await; + let provider_data = AddProviderData::new( "test", "https://accounts.google.com", diff --git a/web/src/i18n/en/index.ts b/web/src/i18n/en/index.ts index ed010dc60..cf73b36bd 100644 --- a/web/src/i18n/en/index.ts +++ b/web/src/i18n/en/index.ts @@ -1089,6 +1089,8 @@ const en: BaseTranslation = { licenseInfo: { title: 'License information', noLicense: 'No license', + licenseNotRequired: + "

You have access to this enterprise feature, as you haven't exceeded any of the usage limits yet. Check the documentation for more information.

", types: { subscription: { label: 'Subscription', diff --git a/web/src/i18n/i18n-types.ts b/web/src/i18n/i18n-types.ts index c70aa7afe..def7dc3d1 100644 --- a/web/src/i18n/i18n-types.ts +++ b/web/src/i18n/i18n-types.ts @@ -2649,6 +2649,10 @@ type RootTranslation = { * N​o​ ​l​i​c​e​n​s​e */ noLicense: string + /** + * <​p​>​Y​o​u​ ​h​a​v​e​ ​a​c​c​e​s​s​ ​t​o​ ​t​h​i​s​ ​e​n​t​e​r​p​r​i​s​e​ ​f​e​a​t​u​r​e​,​ ​a​s​ ​y​o​u​ ​h​a​v​e​n​'​t​ ​e​x​c​e​e​d​e​d​ ​a​n​y​ ​o​f​ ​t​h​e​ ​u​s​a​g​e​ ​l​i​m​i​t​s​ ​y​e​t​.​ ​C​h​e​c​k​ ​t​h​e​ ​<​a​ ​h​r​e​f​=​'​h​t​t​p​s​:​/​/​d​o​c​s​.​d​e​f​g​u​a​r​d​.​n​e​t​/​e​n​t​e​r​p​r​i​s​e​/​l​i​c​e​n​s​e​'​>​d​o​c​u​m​e​n​t​a​t​i​o​n​<​/​a​>​ ​f​o​r​ ​m​o​r​e​ ​i​n​f​o​r​m​a​t​i​o​n​.​<​/​p​> + */ + licenseNotRequired: string types: { subscription: { /** @@ -6903,6 +6907,10 @@ export type TranslationFunctions = { * No license */ noLicense: () => LocalizedString + /** + *

You have access to this enterprise feature, as you haven't exceeded any of the usage limits yet. Check the documentation for more information.

+ */ + licenseNotRequired: () => LocalizedString types: { subscription: { /** diff --git a/web/src/i18n/pl/index.ts b/web/src/i18n/pl/index.ts index b367c0c91..84d4e03d1 100644 --- a/web/src/i18n/pl/index.ts +++ b/web/src/i18n/pl/index.ts @@ -1078,6 +1078,8 @@ Uwaga, podane tutaj konfiguracje nie posiadają klucza prywatnego. Musisz uzupe licenseInfo: { title: 'Informacje o licencji', noLicense: 'Brak licencji', + licenseNotRequired: + "

Posiadasz dostęp do tej funkcji enterprise, ponieważ nie przekroczyłeś jeszcze żadnych limitów. Sprawdź dokumentację, aby uzyskać więcej informacji.

", types: { subscription: { label: 'Subskrypcja', diff --git a/web/src/pages/settings/components/EnterpriseSettings/EnterpriseSettings.tsx b/web/src/pages/settings/components/EnterpriseSettings/EnterpriseSettings.tsx index af50d0056..dda18754c 100644 --- a/web/src/pages/settings/components/EnterpriseSettings/EnterpriseSettings.tsx +++ b/web/src/pages/settings/components/EnterpriseSettings/EnterpriseSettings.tsx @@ -1,4 +1,7 @@ +import parse from 'html-react-parser'; + import { useI18nContext } from '../../../../i18n/i18n-react'; +import { BigInfoBox } from '../../../../shared/defguard-ui/components/Layout/BigInfoBox/BigInfoBox'; import { useAppStore } from '../../../../shared/hooks/store/useAppStore'; import { EnterpriseForm } from './components/EnterpriseForm'; @@ -26,6 +29,13 @@ export const EnterpriseSettings = () => { )} + {!enterpriseStatus?.needs_license && !enterpriseStatus?.license_info && ( +
+ +
+ )}
diff --git a/web/src/pages/settings/components/GlobalSettings/components/LicenseSettings/LicenseSettings.tsx b/web/src/pages/settings/components/GlobalSettings/components/LicenseSettings/LicenseSettings.tsx index da16ed7af..2c06f99e0 100644 --- a/web/src/pages/settings/components/GlobalSettings/components/LicenseSettings/LicenseSettings.tsx +++ b/web/src/pages/settings/components/GlobalSettings/components/LicenseSettings/LicenseSettings.tsx @@ -187,7 +187,9 @@ export const LicenseSettings = () => { ) : ( -

{LL.settingsPage.license.licenseInfo.noLicense()}

+ <> +

{LL.settingsPage.license.licenseInfo.noLicense()}

+ )} diff --git a/web/src/pages/settings/components/GlobalSettings/components/LicenseSettings/styles.scss b/web/src/pages/settings/components/GlobalSettings/components/LicenseSettings/styles.scss index f874bb076..a14d6b109 100644 --- a/web/src/pages/settings/components/GlobalSettings/components/LicenseSettings/styles.scss +++ b/web/src/pages/settings/components/GlobalSettings/components/LicenseSettings/styles.scss @@ -44,3 +44,7 @@ #no-license { text-align: center; } + +#license-not-required { + text-align: center; +} diff --git a/web/src/pages/settings/components/OpenIdSettings/OpenIdSettings.tsx b/web/src/pages/settings/components/OpenIdSettings/OpenIdSettings.tsx index 8ba04f5bf..37a3f1a0f 100644 --- a/web/src/pages/settings/components/OpenIdSettings/OpenIdSettings.tsx +++ b/web/src/pages/settings/components/OpenIdSettings/OpenIdSettings.tsx @@ -1,6 +1,9 @@ import './style.scss'; +import parse from 'html-react-parser'; + import { useI18nContext } from '../../../../i18n/i18n-react'; +import { BigInfoBox } from '../../../../shared/defguard-ui/components/Layout/BigInfoBox/BigInfoBox'; import { useAppStore } from '../../../../shared/hooks/store/useAppStore'; import { OpenIdGeneralSettings } from './components/OpenIdGeneralSettings'; import { OpenIdSettingsForm } from './components/OpenIdSettingsForm'; @@ -30,6 +33,13 @@ export const OpenIdSettings = () => { )} + {!enterpriseStatus?.needs_license && !enterpriseStatus?.license_info && ( +
+ +
+ )}
diff --git a/web/src/pages/settings/style.scss b/web/src/pages/settings/style.scss index a76e77987..7677dc130 100644 --- a/web/src/pages/settings/style.scss +++ b/web/src/pages/settings/style.scss @@ -105,7 +105,6 @@ & > .left, & > .right { - grid-row: 1; width: 100%; max-width: 750px; display: flex; @@ -114,4 +113,9 @@ } } } + + .license-not-required-container { + grid-column: 1 / -1; + width: 100%; + } } diff --git a/web/src/pages/users/UsersOverview/components/UsersList/components/UsersListGroups.tsx b/web/src/pages/users/UsersOverview/components/UsersList/components/UsersListGroups.tsx index 8f49fd9d4..7ec85a908 100644 --- a/web/src/pages/users/UsersOverview/components/UsersList/components/UsersListGroups.tsx +++ b/web/src/pages/users/UsersOverview/components/UsersList/components/UsersListGroups.tsx @@ -88,7 +88,7 @@ export const UsersListGroups = ({ groups }: Props) => { > {displayGroups.map((g, index) => (
- +
))} {enabledModal && ( diff --git a/web/src/shared/defguard-ui b/web/src/shared/defguard-ui index 52a2f6d9b..b61bef8c8 160000 --- a/web/src/shared/defguard-ui +++ b/web/src/shared/defguard-ui @@ -1 +1 @@ -Subproject commit 52a2f6d9bf70d5cb497467f1caf4aa7a36d5d910 +Subproject commit b61bef8c893b4a27f62a3463d847274591520398 diff --git a/web/src/shared/types.ts b/web/src/shared/types.ts index fd02238e0..384a4c824 100644 --- a/web/src/shared/types.ts +++ b/web/src/shared/types.ts @@ -884,6 +884,7 @@ export type EnterpriseStatus = { enabled: boolean; // If there is no license, there is no license info license_info?: LicenseInfo; + needs_license: boolean; }; export interface Webhook { From 8baf58ef86ff32c577b0ca4c2a7c2051e350d93b Mon Sep 17 00:00:00 2001 From: Aleksander <170264518+t-aleksander@users.noreply.github.com> Date: Wed, 13 Nov 2024 18:52:48 +0100 Subject: [PATCH 3/8] Fix network setup wizard IP input / e2e tests (#854) * fix network wizard * fix tests --- .../WizardNetworkConfiguration.tsx | 36 +++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx b/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx index cce80f6d5..7c7f734ab 100644 --- a/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx +++ b/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx @@ -2,6 +2,7 @@ import './style.scss'; import { zodResolver } from '@hookform/resolvers/zod'; import { useMutation, useQuery } from '@tanstack/react-query'; +import ipaddr from 'ipaddr.js'; import { useEffect, useMemo, useRef, useState } from 'react'; import { SubmitHandler, useForm } from 'react-hook-form'; import { z } from 'zod'; @@ -20,7 +21,7 @@ import { QueryKeys } from '../../../../shared/queries'; import { ModifyNetworkRequest } from '../../../../shared/types'; import { titleCase } from '../../../../shared/utils/titleCase'; import { trimObjectStrings } from '../../../../shared/utils/trimObjectStrings.ts'; -import { validateIpOrDomainList, validateIPv4 } from '../../../../shared/validators'; +import { validateIpOrDomainList } from '../../../../shared/validators'; import { useWizardStore } from '../../hooks/useWizardStore'; type FormInputs = ModifyNetworkRequest['network']; @@ -91,13 +92,36 @@ export const WizardNetworkConfiguration = () => { if (!netmaskPresent) { return false; } - const ipValid = validateIPv4(value, true); - if (ipValid) { - const host = value.split('.')[3].split('/')[0]; - if (host === '0') return false; + const ipValid = ipaddr.isValidCIDR(value); + if (!ipValid) { + return false; + } + const [address] = ipaddr.parseCIDR(value); + if (address.kind() === 'ipv6') { + const networkAddress = ipaddr.IPv6.networkAddressFromCIDR(value); + const broadcastAddress = ipaddr.IPv6.broadcastAddressFromCIDR(value); + if ( + (address as ipaddr.IPv6).toNormalizedString() === + networkAddress.toNormalizedString() || + (address as ipaddr.IPv6).toNormalizedString() === + broadcastAddress.toNormalizedString() + ) { + return false; + } + } else { + const networkAddress = ipaddr.IPv4.networkAddressFromCIDR(value); + const broadcastAddress = ipaddr.IPv4.broadcastAddressFromCIDR(value); + if ( + (address as ipaddr.IPv4).toNormalizedString() === + networkAddress.toNormalizedString() || + (address as ipaddr.IPv4).toNormalizedString() === + broadcastAddress.toNormalizedString() + ) { + return false; + } } return ipValid; - }), + }, LL.form.error.addressNetmask()), endpoint: z.string().min(1, LL.form.error.required()), port: z .number({ From 57d19ec435599aa7be2cb98a668ee0d7fa5f91d9 Mon Sep 17 00:00:00 2001 From: Adam Date: Fri, 15 Nov 2024 19:37:46 +0100 Subject: [PATCH 4/8] OpenID via Proxy (#845) * Cleanup * Shape AuthInfo RPC * Return AuthInfoResponse * Handle AuthInfoRequest * Get ready for AuthCallback * Now User-Agent header is required to authenticate * Put UserAgentParser behind LazyLock * Fold common code * Fold common logic into user_from_claims() * Return empty response for AuthCallback * Add migration to make openid_sub unique * deny auth info when not enterprise * disallow disabled users * log an error if id_token or state is missing * Switch to authorization code flow * Handle "aud" in token claims * add oidc button display name, pass token and url to proxy * update protobufs * Merge last migrations * preserve full error * login -> sign in, change flow to auth code in proxy * add useeffect dependency * cleanup * log error details * sqlx prepare * add display name to test * remove user agent parser --------- Co-authored-by: Aleksander <170264518+t-aleksander@users.noreply.github.com> --- ...c7e64b747ea250e6e27a5177f89116cb4775.json} | 12 +- ...2c0315892c7d2c012c7cc4c399410189814b.json} | 5 +- ...d798082c152c7c629190b6c6edd3f113c544.json} | 12 +- ...b8a0a34a7376b86242814bda2f8e004d1589.json} | 12 +- ...00053846b927c80463b3fd836ea0af11cf07.json} | 5 +- ...601bc4b08439d2516802d06dfd5da3692fd3.json} | 5 +- ...748410f91dd5775508cf403b88eabf26c515.json} | 4 +- ...44eb5e7cd0c3bdbb808b7898997328c91a1d.json} | 12 +- Cargo.toml | 2 +- .../20241112105513_openid_sub_unique.down.sql | 2 + .../20241112105513_openid_sub_unique.up.sql | 4 + proto | 2 +- src/appstate.rs | 4 - src/bin/defguard.rs | 6 +- src/config.rs | 34 +- src/db/models/device_login.rs | 2 +- src/db/models/enrollment.rs | 29 +- src/db/models/user.rs | 24 +- src/enterprise/db/models/openid_provider.rs | 19 +- src/enterprise/grpc/polling.rs | 2 +- src/enterprise/handlers/openid_login.rs | 654 +++++++++--------- src/enterprise/handlers/openid_providers.rs | 11 +- src/grpc/enrollment.rs | 14 +- src/grpc/mod.rs | 121 +++- src/grpc/password_reset.rs | 2 +- src/handlers/auth.rs | 232 ++++--- src/handlers/mod.rs | 10 +- src/handlers/user.rs | 2 +- src/headers.rs | 70 +- src/lib.rs | 5 - tests/auth.rs | 14 +- tests/common/client.rs | 6 +- tests/common/mod.rs | 4 - tests/openid.rs | 16 +- tests/openid_login.rs | 10 +- web/src/i18n/en/index.ts | 6 + web/src/i18n/i18n-types.ts | 28 + web/src/i18n/pl/index.ts | 6 + web/src/pages/auth/Callback/Callback.tsx | 16 +- web/src/pages/auth/Login/Login.tsx | 5 +- .../auth/Login/components/OidcButtons.tsx | 22 +- .../components/OpenIdSettingsForm.tsx | 28 +- web/src/shared/types.ts | 4 +- 43 files changed, 875 insertions(+), 608 deletions(-) rename .sqlx/{query-dea2f3d8b9508ef1df84e204816a9cdc53103547d3d273803313bc091c72c323.json => query-0d2c77d57bab4410b7cf7a79bcadc7e64b747ea250e6e27a5177f89116cb4775.json} (68%) rename .sqlx/{query-a0700f9701a61fc64af20165feb4627ef6f053bcb59d81e41dd1cf1bd199b543.json => query-24b61173ea347abd382b1839446f2c0315892c7d2c012c7cc4c399410189814b.json} (63%) rename .sqlx/{query-ca74174177efd38a84835e809b91a7b39b9389ae01437b7d9405ffa074388279.json => query-45953e3820712bf534e2696f6e50d798082c152c7c629190b6c6edd3f113c544.json} (74%) rename .sqlx/{query-dafe0c3d80ed8e09771cf910d6a7696bb16daaece311438358321c8e8ea3b65f.json => query-4ad2544f4b65e4c037f8b574ae50b8a0a34a7376b86242814bda2f8e004d1589.json} (73%) rename .sqlx/{query-a6fe220739875c6894e1a678d4c24de34b7c3ca8bd5e48ed9f313d20cbe8f8e4.json => query-a79fb5b30b7366e7145c147458bd00053846b927c80463b3fd836ea0af11cf07.json} (67%) rename .sqlx/{query-6ae26abc026e5fe59e8505b2563fca6caf4a42bb9ad01d6f44db41a572abce95.json => query-ab4d3df8aa0e1401824c5ed291a0601bc4b08439d2516802d06dfd5da3692fd3.json} (60%) rename .sqlx/{query-b0d21b63dc414e3738a85e9e5ed8a71cbb3019f86f5d69e186efba7f0fabd4d1.json => query-b2cfe32ecd399a152cc3f9ac5a23748410f91dd5775508cf403b88eabf26c515.json} (96%) rename .sqlx/{query-72958cb69e1b737ca4a3eb2915bbf8bde2ff8f3db10090cef071aeb932c99ab2.json => query-fedc6441d5b2371ac365cca09a9b44eb5e7cd0c3bdbb808b7898997328c91a1d.json} (68%) create mode 100644 migrations/20241112105513_openid_sub_unique.down.sql create mode 100644 migrations/20241112105513_openid_sub_unique.up.sql diff --git a/.sqlx/query-dea2f3d8b9508ef1df84e204816a9cdc53103547d3d273803313bc091c72c323.json b/.sqlx/query-0d2c77d57bab4410b7cf7a79bcadc7e64b747ea250e6e27a5177f89116cb4775.json similarity index 68% rename from .sqlx/query-dea2f3d8b9508ef1df84e204816a9cdc53103547d3d273803313bc091c72c323.json rename to .sqlx/query-0d2c77d57bab4410b7cf7a79bcadc7e64b747ea250e6e27a5177f89116cb4775.json index 83298bd6b..6fde8ac13 100644 --- a/.sqlx/query-dea2f3d8b9508ef1df84e204816a9cdc53103547d3d273803313bc091c72c323.json +++ b/.sqlx/query-0d2c77d57bab4410b7cf7a79bcadc7e64b747ea250e6e27a5177f89116cb4775.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, name, base_url, client_id, client_secret FROM openidprovider", + "query": "SELECT id, name, base_url, client_id, client_secret, display_name FROM openidprovider LIMIT 1", "describe": { "columns": [ { @@ -27,6 +27,11 @@ "ordinal": 4, "name": "client_secret", "type_info": "Text" + }, + { + "ordinal": 5, + "name": "display_name", + "type_info": "Text" } ], "parameters": { @@ -37,8 +42,9 @@ false, false, false, - false + false, + true ] }, - "hash": "dea2f3d8b9508ef1df84e204816a9cdc53103547d3d273803313bc091c72c323" + "hash": "0d2c77d57bab4410b7cf7a79bcadc7e64b747ea250e6e27a5177f89116cb4775" } diff --git a/.sqlx/query-a0700f9701a61fc64af20165feb4627ef6f053bcb59d81e41dd1cf1bd199b543.json b/.sqlx/query-24b61173ea347abd382b1839446f2c0315892c7d2c012c7cc4c399410189814b.json similarity index 63% rename from .sqlx/query-a0700f9701a61fc64af20165feb4627ef6f053bcb59d81e41dd1cf1bd199b543.json rename to .sqlx/query-24b61173ea347abd382b1839446f2c0315892c7d2c012c7cc4c399410189814b.json index 94ec29743..ae65d7c0d 100644 --- a/.sqlx/query-a0700f9701a61fc64af20165feb4627ef6f053bcb59d81e41dd1cf1bd199b543.json +++ b/.sqlx/query-24b61173ea347abd382b1839446f2c0315892c7d2c012c7cc4c399410189814b.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "UPDATE openidprovider SET name = $1, base_url = $2, client_id = $3, client_secret = $4 WHERE id = $5", + "query": "UPDATE openidprovider SET name = $1, base_url = $2, client_id = $3, client_secret = $4, display_name = $5 WHERE id = $6", "describe": { "columns": [], "parameters": { @@ -9,10 +9,11 @@ "Text", "Text", "Text", + "Text", "Int8" ] }, "nullable": [] }, - "hash": "a0700f9701a61fc64af20165feb4627ef6f053bcb59d81e41dd1cf1bd199b543" + "hash": "24b61173ea347abd382b1839446f2c0315892c7d2c012c7cc4c399410189814b" } diff --git a/.sqlx/query-ca74174177efd38a84835e809b91a7b39b9389ae01437b7d9405ffa074388279.json b/.sqlx/query-45953e3820712bf534e2696f6e50d798082c152c7c629190b6c6edd3f113c544.json similarity index 74% rename from .sqlx/query-ca74174177efd38a84835e809b91a7b39b9389ae01437b7d9405ffa074388279.json rename to .sqlx/query-45953e3820712bf534e2696f6e50d798082c152c7c629190b6c6edd3f113c544.json index bc0e41c84..075d69c6b 100644 --- a/.sqlx/query-ca74174177efd38a84835e809b91a7b39b9389ae01437b7d9405ffa074388279.json +++ b/.sqlx/query-45953e3820712bf534e2696f6e50d798082c152c7c629190b6c6edd3f113c544.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, \"name\",\"base_url\",\"client_id\",\"client_secret\" FROM \"openidprovider\"", + "query": "SELECT id, \"name\",\"base_url\",\"client_id\",\"client_secret\",\"display_name\" FROM \"openidprovider\"", "describe": { "columns": [ { @@ -27,6 +27,11 @@ "ordinal": 4, "name": "client_secret", "type_info": "Text" + }, + { + "ordinal": 5, + "name": "display_name", + "type_info": "Text" } ], "parameters": { @@ -37,8 +42,9 @@ false, false, false, - false + false, + true ] }, - "hash": "ca74174177efd38a84835e809b91a7b39b9389ae01437b7d9405ffa074388279" + "hash": "45953e3820712bf534e2696f6e50d798082c152c7c629190b6c6edd3f113c544" } diff --git a/.sqlx/query-dafe0c3d80ed8e09771cf910d6a7696bb16daaece311438358321c8e8ea3b65f.json b/.sqlx/query-4ad2544f4b65e4c037f8b574ae50b8a0a34a7376b86242814bda2f8e004d1589.json similarity index 73% rename from .sqlx/query-dafe0c3d80ed8e09771cf910d6a7696bb16daaece311438358321c8e8ea3b65f.json rename to .sqlx/query-4ad2544f4b65e4c037f8b574ae50b8a0a34a7376b86242814bda2f8e004d1589.json index 461929642..ae2f3d7e9 100644 --- a/.sqlx/query-dafe0c3d80ed8e09771cf910d6a7696bb16daaece311438358321c8e8ea3b65f.json +++ b/.sqlx/query-4ad2544f4b65e4c037f8b574ae50b8a0a34a7376b86242814bda2f8e004d1589.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, \"name\",\"base_url\",\"client_id\",\"client_secret\" FROM \"openidprovider\" WHERE id = $1", + "query": "SELECT id, \"name\",\"base_url\",\"client_id\",\"client_secret\",\"display_name\" FROM \"openidprovider\" WHERE id = $1", "describe": { "columns": [ { @@ -27,6 +27,11 @@ "ordinal": 4, "name": "client_secret", "type_info": "Text" + }, + { + "ordinal": 5, + "name": "display_name", + "type_info": "Text" } ], "parameters": { @@ -39,8 +44,9 @@ false, false, false, - false + false, + true ] }, - "hash": "dafe0c3d80ed8e09771cf910d6a7696bb16daaece311438358321c8e8ea3b65f" + "hash": "4ad2544f4b65e4c037f8b574ae50b8a0a34a7376b86242814bda2f8e004d1589" } diff --git a/.sqlx/query-a6fe220739875c6894e1a678d4c24de34b7c3ca8bd5e48ed9f313d20cbe8f8e4.json b/.sqlx/query-a79fb5b30b7366e7145c147458bd00053846b927c80463b3fd836ea0af11cf07.json similarity index 67% rename from .sqlx/query-a6fe220739875c6894e1a678d4c24de34b7c3ca8bd5e48ed9f313d20cbe8f8e4.json rename to .sqlx/query-a79fb5b30b7366e7145c147458bd00053846b927c80463b3fd836ea0af11cf07.json index e32f0cf87..52b9b0b16 100644 --- a/.sqlx/query-a6fe220739875c6894e1a678d4c24de34b7c3ca8bd5e48ed9f313d20cbe8f8e4.json +++ b/.sqlx/query-a79fb5b30b7366e7145c147458bd00053846b927c80463b3fd836ea0af11cf07.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "INSERT INTO \"openidprovider\" (\"name\",\"base_url\",\"client_id\",\"client_secret\") VALUES ($1,$2,$3,$4) RETURNING id", + "query": "INSERT INTO \"openidprovider\" (\"name\",\"base_url\",\"client_id\",\"client_secret\",\"display_name\") VALUES ($1,$2,$3,$4,$5) RETURNING id", "describe": { "columns": [ { @@ -14,6 +14,7 @@ "Text", "Text", "Text", + "Text", "Text" ] }, @@ -21,5 +22,5 @@ false ] }, - "hash": "a6fe220739875c6894e1a678d4c24de34b7c3ca8bd5e48ed9f313d20cbe8f8e4" + "hash": "a79fb5b30b7366e7145c147458bd00053846b927c80463b3fd836ea0af11cf07" } diff --git a/.sqlx/query-6ae26abc026e5fe59e8505b2563fca6caf4a42bb9ad01d6f44db41a572abce95.json b/.sqlx/query-ab4d3df8aa0e1401824c5ed291a0601bc4b08439d2516802d06dfd5da3692fd3.json similarity index 60% rename from .sqlx/query-6ae26abc026e5fe59e8505b2563fca6caf4a42bb9ad01d6f44db41a572abce95.json rename to .sqlx/query-ab4d3df8aa0e1401824c5ed291a0601bc4b08439d2516802d06dfd5da3692fd3.json index cc029dcf7..ee841e1f8 100644 --- a/.sqlx/query-6ae26abc026e5fe59e8505b2563fca6caf4a42bb9ad01d6f44db41a572abce95.json +++ b/.sqlx/query-ab4d3df8aa0e1401824c5ed291a0601bc4b08439d2516802d06dfd5da3692fd3.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "UPDATE \"openidprovider\" SET \"name\" = $2,\"base_url\" = $3,\"client_id\" = $4,\"client_secret\" = $5 WHERE id = $1", + "query": "UPDATE \"openidprovider\" SET \"name\" = $2,\"base_url\" = $3,\"client_id\" = $4,\"client_secret\" = $5,\"display_name\" = $6 WHERE id = $1", "describe": { "columns": [], "parameters": { @@ -9,10 +9,11 @@ "Text", "Text", "Text", + "Text", "Text" ] }, "nullable": [] }, - "hash": "6ae26abc026e5fe59e8505b2563fca6caf4a42bb9ad01d6f44db41a572abce95" + "hash": "ab4d3df8aa0e1401824c5ed291a0601bc4b08439d2516802d06dfd5da3692fd3" } diff --git a/.sqlx/query-b0d21b63dc414e3738a85e9e5ed8a71cbb3019f86f5d69e186efba7f0fabd4d1.json b/.sqlx/query-b2cfe32ecd399a152cc3f9ac5a23748410f91dd5775508cf403b88eabf26c515.json similarity index 96% rename from .sqlx/query-b0d21b63dc414e3738a85e9e5ed8a71cbb3019f86f5d69e186efba7f0fabd4d1.json rename to .sqlx/query-b2cfe32ecd399a152cc3f9ac5a23748410f91dd5775508cf403b88eabf26c515.json index 84f36c6df..1012d51ae 100644 --- a/.sqlx/query-b0d21b63dc414e3738a85e9e5ed8a71cbb3019f86f5d69e186efba7f0fabd4d1.json +++ b/.sqlx/query-b2cfe32ecd399a152cc3f9ac5a23748410f91dd5775508cf403b88eabf26c515.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, username, password_hash, last_name, first_name, email, phone, mfa_enabled, totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub FROM \"user\" WHERE openid_sub = $1", + "query": "SELECT id, username, password_hash, last_name, first_name, email, phone, mfa_enabled, totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub FROM \"user\" WHERE openid_sub = $1 LIMIT 1", "describe": { "columns": [ { @@ -121,5 +121,5 @@ true ] }, - "hash": "b0d21b63dc414e3738a85e9e5ed8a71cbb3019f86f5d69e186efba7f0fabd4d1" + "hash": "b2cfe32ecd399a152cc3f9ac5a23748410f91dd5775508cf403b88eabf26c515" } diff --git a/.sqlx/query-72958cb69e1b737ca4a3eb2915bbf8bde2ff8f3db10090cef071aeb932c99ab2.json b/.sqlx/query-fedc6441d5b2371ac365cca09a9b44eb5e7cd0c3bdbb808b7898997328c91a1d.json similarity index 68% rename from .sqlx/query-72958cb69e1b737ca4a3eb2915bbf8bde2ff8f3db10090cef071aeb932c99ab2.json rename to .sqlx/query-fedc6441d5b2371ac365cca09a9b44eb5e7cd0c3bdbb808b7898997328c91a1d.json index 46c6f250e..ea086a87e 100644 --- a/.sqlx/query-72958cb69e1b737ca4a3eb2915bbf8bde2ff8f3db10090cef071aeb932c99ab2.json +++ b/.sqlx/query-fedc6441d5b2371ac365cca09a9b44eb5e7cd0c3bdbb808b7898997328c91a1d.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, name, base_url, client_id, client_secret FROM openidprovider WHERE name = $1", + "query": "SELECT id, name, base_url, client_id, client_secret, display_name FROM openidprovider WHERE name = $1", "describe": { "columns": [ { @@ -27,6 +27,11 @@ "ordinal": 4, "name": "client_secret", "type_info": "Text" + }, + { + "ordinal": 5, + "name": "display_name", + "type_info": "Text" } ], "parameters": { @@ -39,8 +44,9 @@ false, false, false, - false + false, + true ] }, - "hash": "72958cb69e1b737ca4a3eb2915bbf8bde2ff8f3db10090cef071aeb932c99ab2" + "hash": "fedc6441d5b2371ac365cca09a9b44eb5e7cd0c3bdbb808b7898997328c91a1d" } diff --git a/Cargo.toml b/Cargo.toml index 7e53e12ec..6ef29f19c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" license = "Apache-2.0" homepage = "https://defguard.net/" repository = "https://github.com/DefGuard/defguard" -rust-version = "1.76" +rust-version = "1.80" [workspace] diff --git a/migrations/20241112105513_openid_sub_unique.down.sql b/migrations/20241112105513_openid_sub_unique.down.sql new file mode 100644 index 000000000..bd2cf690a --- /dev/null +++ b/migrations/20241112105513_openid_sub_unique.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE openidprovider DROP COLUMN display_name; +ALTER TABLE "user" DROP CONSTRAINT "user_openid_sub_key"; diff --git a/migrations/20241112105513_openid_sub_unique.up.sql b/migrations/20241112105513_openid_sub_unique.up.sql new file mode 100644 index 000000000..e2fc2bfdf --- /dev/null +++ b/migrations/20241112105513_openid_sub_unique.up.sql @@ -0,0 +1,4 @@ +-- Make openid_sub unique. +-- This migration may fail if duplicate openid_subs exist in the database. +ALTER TABLE "user" ADD CONSTRAINT "user_openid_sub_key" UNIQUE (openid_sub); +ALTER TABLE openidprovider ADD COLUMN display_name TEXT DEFAULT NULL; diff --git a/proto b/proto index 8309982b9..b9adb0bc8 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 8309982b94e82a7cbe39dd529967f43e49b3ef1d +Subproject commit b9adb0bc87228c88c42f144caa47c8b69a3fb98c diff --git a/src/appstate.rs b/src/appstate.rs index 78556a756..713ed78c6 100644 --- a/src/appstate.rs +++ b/src/appstate.rs @@ -13,7 +13,6 @@ use tokio::{ }, task::spawn, }; -use uaparser::UserAgentParser; use webauthn_rs::prelude::*; use crate::{ @@ -30,7 +29,6 @@ pub struct AppState { wireguard_tx: Sender, pub mail_tx: UnboundedSender, pub webauthn: Arc, - pub user_agent_parser: Arc, pub failed_logins: Arc>, key: Key, } @@ -103,7 +101,6 @@ impl AppState { rx: UnboundedReceiver, wireguard_tx: Sender, mail_tx: UnboundedSender, - user_agent_parser: Arc, failed_logins: Arc>, ) -> Self { spawn(Self::handle_triggers(pool.clone(), rx)); @@ -131,7 +128,6 @@ impl AppState { wireguard_tx, mail_tx, webauthn, - user_agent_parser, failed_logins, key, } diff --git a/src/bin/defguard.rs b/src/bin/defguard.rs index ad7565d74..a52be806f 100644 --- a/src/bin/defguard.rs +++ b/src/bin/defguard.rs @@ -12,7 +12,6 @@ use defguard::{ limits::update_counts, }, grpc::{run_grpc_bidi_stream, run_grpc_server, GatewayMap, WorkerState}, - headers::create_user_agent_parser, init_dev_env, init_vpn_location, mail::{run_mail_handler, Mail}, run_web_server, @@ -82,7 +81,6 @@ async fn main() -> Result<(), anyhow::Error> { let (mail_tx, mail_rx) = unbounded_channel::(); let worker_state = Arc::new(Mutex::new(WorkerState::new(webhook_tx.clone()))); let gateway_state = Arc::new(Mutex::new(GatewayMap::new())); - let user_agent_parser = create_user_agent_parser(); // initialize admin user User::init_admin_user(&pool, config.default_admin_password.expose_secret()).await?; @@ -119,9 +117,9 @@ async fn main() -> Result<(), anyhow::Error> { // run services tokio::select! { - res = run_grpc_bidi_stream(pool.clone(), wireguard_tx.clone(), mail_tx.clone(), user_agent_parser.clone()), if config.proxy_url.is_some() => error!("Proxy gRPC stream returned early: {res:#?}"), + res = run_grpc_bidi_stream(pool.clone(), wireguard_tx.clone(), mail_tx.clone()), if config.proxy_url.is_some() => error!("Proxy gRPC stream returned early: {res:#?}"), res = run_grpc_server(Arc::clone(&worker_state), pool.clone(), Arc::clone(&gateway_state), wireguard_tx.clone(), mail_tx.clone(), grpc_cert, grpc_key, failed_logins.clone()) => error!("gRPC server returned early: {res:#?}"), - res = run_web_server(worker_state, gateway_state, webhook_tx, webhook_rx, wireguard_tx.clone(), mail_tx, pool.clone(), user_agent_parser, failed_logins) => error!("Web server returned early: {res:#?}"), + res = run_web_server(worker_state, gateway_state, webhook_tx, webhook_rx, wireguard_tx.clone(), mail_tx, pool.clone(), failed_logins) => error!("Web server returned early: {res:#?}"), res = run_mail_handler(mail_rx, pool.clone()) => error!("Mail handler returned early: {res:#?}"), res = run_periodic_peer_disconnect(pool.clone(), wireguard_tx) => error!("Periodic peer disconnect task returned early: {res:#?}"), res = run_periodic_stats_purge(pool.clone(), config.stats_purge_frequency.into(), config.stats_purge_threshold.into()), if !config.disable_stats_purge => error!("Periodic stats purge task returned early: {res:#?}"), diff --git a/src/config.rs b/src/config.rs index 8e1b025ce..6bbe54981 100644 --- a/src/config.rs +++ b/src/config.rs @@ -286,6 +286,17 @@ impl DefGuardConfig { None } } + + /// Returns configured URL with "auth/callback" appended to the path. + #[must_use] + pub(crate) fn callback_url(&self) -> Url { + let mut url = self.url.clone(); + // Append "auth/callback" to the URL. + if let Ok(mut path_segments) = url.path_segments_mut() { + path_segments.extend(&["auth", "callback"]); + } + url + } } impl Default for DefGuardConfig { @@ -308,10 +319,7 @@ mod tests { #[test] fn test_generate_rp_id() { - // unset variables - env::remove_var("DEFGUARD_URL"); env::remove_var("DEFGUARD_WEBAUTHN_RP_ID"); - env::set_var("DEFGUARD_URL", "https://defguard.example.com"); let config = DefGuardConfig::new(); @@ -330,10 +338,7 @@ mod tests { #[test] fn test_generate_cookie_domain() { - // unset variables - env::remove_var("DEFGUARD_URL"); env::remove_var("DEFGUARD_COOKIE_DOMAIN"); - env::set_var("DEFGUARD_URL", "https://defguard.example.com"); let config = DefGuardConfig::new(); @@ -349,4 +354,21 @@ mod tests { assert_eq!(config.cookie_domain, Some("example.com".to_string())); } + + #[test] + fn test_callback_url() { + env::set_var("DEFGUARD_URL", "https://defguard.example.com"); + let config = DefGuardConfig::new(); + assert_eq!( + config.callback_url().as_str(), + "https://defguard.example.com/auth/callback" + ); + + env::set_var("DEFGUARD_URL", "https://defguard.example.com:8443/path"); + let config = DefGuardConfig::new(); + assert_eq!( + config.callback_url().as_str(), + "https://defguard.example.com:8443/path/auth/callback" + ); + } } diff --git a/src/db/models/device_login.rs b/src/db/models/device_login.rs index 5cbc371b8..f94dc1a35 100644 --- a/src/db/models/device_login.rs +++ b/src/db/models/device_login.rs @@ -72,7 +72,7 @@ impl DeviceLoginEvent { } } - pub async fn find_device_login_event( + pub(crate) async fn find_device_login_event( &self, pool: &PgPool, ) -> Result>, SqlxError> { diff --git a/src/db/models/enrollment.rs b/src/db/models/enrollment.rs index c4b871adb..8555f6a7a 100644 --- a/src/db/models/enrollment.rs +++ b/src/db/models/enrollment.rs @@ -112,7 +112,10 @@ impl Token { } } - pub async fn save(&self, transaction: &mut PgConnection) -> Result<(), TokenError> { + pub async fn save<'e, E>(&self, executor: E) -> Result<(), TokenError> + where + E: PgExecutor<'e>, + { query!( "INSERT INTO token (id, user_id, admin_id, email, created_at, expires_at, used_at, token_type) \ VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", @@ -125,7 +128,7 @@ impl Token { self.used_at, self.token_type, ) - .execute(transaction) + .execute(executor) .await?; Ok(()) } @@ -253,10 +256,13 @@ impl Token { Ok(user) } - pub async fn delete_unused_user_tokens( - transaction: &mut PgConnection, + pub async fn delete_unused_user_tokens<'e, E>( + executor: E, user_id: Id, - ) -> Result<(), TokenError> { + ) -> Result<(), TokenError> + where + E: PgExecutor<'e>, + { debug!("Deleting unused tokens for the user."); let result = query!( "DELETE FROM token \ @@ -264,7 +270,7 @@ impl Token { AND used_at IS NULL", user_id ) - .execute(transaction) + .execute(executor) .await?; info!( "Deleted {} unused enrollment tokens for the user.", @@ -588,12 +594,15 @@ impl User { } // Remove unused tokens when triggering user enrollment - async fn clear_unused_enrollment_tokens( + pub(crate) async fn clear_unused_enrollment_tokens<'e, E>( &self, - transaction: &mut PgConnection, - ) -> Result<(), TokenError> { + executor: E, + ) -> Result<(), TokenError> + where + E: PgExecutor<'e>, + { info!("Removing unused tokens for user {}.", self.username); - Token::delete_unused_user_tokens(transaction, self.id).await + Token::delete_unused_user_tokens(executor, self.id).await } } diff --git a/src/db/models/user.rs b/src/db/models/user.rs index 21fee1ac1..77c76dbb1 100644 --- a/src/db/models/user.rs +++ b/src/db/models/user.rs @@ -80,6 +80,7 @@ pub struct User { pub is_active: bool, /// The user's sub claim returned by the OpenID provider. Also indicates whether the user has /// used OpenID to log in. + // FIXME: must be unique pub openid_sub: Option, // secret has been verified and TOTP can be used pub(crate) totp_enabled: bool, @@ -136,7 +137,7 @@ impl User { self.password_hash = hash_password(password).ok(); } - pub fn verify_password(&self, password: &str) -> Result<(), HashError> { + pub(crate) fn verify_password(&self, password: &str) -> Result<(), HashError> { match &self.password_hash { Some(hash) => { let parsed_hash = PasswordHash::new(hash)?; @@ -150,12 +151,12 @@ impl User { } #[must_use] - pub fn has_password(&self) -> bool { + pub(crate) fn has_password(&self) -> bool { self.password_hash.is_some() } #[must_use] - pub fn name(&self) -> String { + pub(crate) fn name(&self) -> String { format!("{} {}", self.first_name, self.last_name) } @@ -163,7 +164,7 @@ impl User { /// We assume the user is enrolled if they have a password set /// or they have logged in using an external OIDC. #[must_use] - pub fn is_enrolled(&self) -> bool { + pub(crate) fn is_enrolled(&self) -> bool { self.password_hash.is_some() || self.openid_sub.is_some() } } @@ -635,9 +636,9 @@ impl User { { query_as!( Self, - "SELECT id, username, password_hash, last_name, first_name, email, \ - phone, mfa_enabled, totp_enabled, email_mfa_enabled, \ - totp_secret, email_mfa_secret, mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub \ + "SELECT id, username, password_hash, last_name, first_name, email, phone, \ + mfa_enabled, totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, \ + mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub \ FROM \"user\" WHERE email = $1", email ) @@ -645,16 +646,17 @@ impl User { .await } + // FIXME: Remove `LIMIT 1` when `openid_sub` is unique. pub async fn find_by_sub<'e, E>(executor: E, sub: &str) -> Result, SqlxError> where E: PgExecutor<'e>, { query_as!( Self, - "SELECT id, username, password_hash, last_name, first_name, email, \ - phone, mfa_enabled, totp_enabled, email_mfa_enabled, \ - totp_secret, email_mfa_secret, mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub \ - FROM \"user\" WHERE openid_sub = $1", + "SELECT id, username, password_hash, last_name, first_name, email, phone, \ + mfa_enabled, totp_enabled, email_mfa_enabled, totp_secret, email_mfa_secret, \ + mfa_method \"mfa_method: _\", recovery_codes, is_active, openid_sub \ + FROM \"user\" WHERE openid_sub = $1 LIMIT 1", sub ) .fetch_optional(executor) diff --git a/src/enterprise/db/models/openid_provider.rs b/src/enterprise/db/models/openid_provider.rs index 099fd0f22..5861e7698 100644 --- a/src/enterprise/db/models/openid_provider.rs +++ b/src/enterprise/db/models/openid_provider.rs @@ -10,29 +10,38 @@ pub struct OpenIdProvider { pub base_url: String, pub client_id: String, pub client_secret: String, + pub display_name: Option, } impl OpenIdProvider { #[must_use] - pub fn new>(name: S, base_url: S, client_id: S, client_secret: S) -> Self { + pub fn new>( + name: S, + base_url: S, + client_id: S, + client_secret: S, + display_name: Option, + ) -> Self { Self { id: NoId, name: name.into(), base_url: base_url.into(), client_id: client_id.into(), client_secret: client_secret.into(), + display_name, } } pub async fn upsert(self, pool: &PgPool) -> Result, SqlxError> { if let Some(provider) = OpenIdProvider::::get_current(pool).await? { query!( - "UPDATE openidprovider SET name = $1, base_url = $2, client_id = $3, client_secret = $4 WHERE id = $5", + "UPDATE openidprovider SET name = $1, base_url = $2, client_id = $3, client_secret = $4, display_name = $5 WHERE id = $6", self.name, self.base_url, self.client_id, self.client_secret, - provider.id + self.display_name, + provider.id, ) .execute(pool) .await?; @@ -48,7 +57,7 @@ impl OpenIdProvider { pub async fn find_by_name(pool: &PgPool, name: &str) -> Result, SqlxError> { query_as!( OpenIdProvider, - "SELECT id, name, base_url, client_id, client_secret FROM openidprovider WHERE name = $1", + "SELECT id, name, base_url, client_id, client_secret, display_name FROM openidprovider WHERE name = $1", name ) .fetch_optional(pool) @@ -58,7 +67,7 @@ impl OpenIdProvider { pub async fn get_current(pool: &PgPool) -> Result, SqlxError> { query_as!( OpenIdProvider, - "SELECT id, name, base_url, client_id, client_secret FROM openidprovider" + "SELECT id, name, base_url, client_id, client_secret, display_name FROM openidprovider LIMIT 1" ) .fetch_optional(pool) .await diff --git a/src/enterprise/grpc/polling.rs b/src/enterprise/grpc/polling.rs index 3ea1c1e33..9c54a9865 100644 --- a/src/enterprise/grpc/polling.rs +++ b/src/enterprise/grpc/polling.rs @@ -41,7 +41,7 @@ impl PollingServer { }; // Polling tokens are valid indefinitely - info!("Token validation successful {token:?}."); + debug!("Token validation successful {token:?}."); Ok(token) } diff --git a/src/enterprise/handlers/openid_login.rs b/src/enterprise/handlers/openid_login.rs index d07150353..847e93c1b 100644 --- a/src/enterprise/handlers/openid_login.rs +++ b/src/enterprise/handlers/openid_login.rs @@ -9,141 +9,316 @@ use axum_extra::{ TypedHeader, }; use openidconnect::{ - core::{ - CoreClient, CoreGenderClaim, CoreJsonWebKeyType, CoreJweContentEncryptionAlgorithm, - CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseType, - }, + core::{CoreAuthenticationFlow, CoreClient, CoreProviderMetadata}, reqwest::async_http_client, - AuthenticationFlow, ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, IdToken, - IssuerUrl, Nonce, ProviderMetadata, RedirectUrl, Scope, + AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, RedirectUrl, Scope, }; +use reqwest::Url; use serde_json::json; use sqlx::PgPool; use time::Duration; +const COOKIE_MAX_AGE: Duration = Duration::days(1); +static CSRF_COOKIE_NAME: &str = "csrf"; +static NONCE_COOKIE_NAME: &str = "nonce"; + use super::LicenseInfo; use crate::{ appstate::AppState, - db::{MFAInfo, Session, SessionState, Settings, User, UserInfo}, + db::{Id, Settings, User}, enterprise::db::models::openid_provider::OpenIdProvider, error::WebError, handlers::{ + auth::create_session, user::{check_username, prune_username}, ApiResponse, AuthResponse, SESSION_COOKIE_NAME, SIGN_IN_COOKIE_NAME, }, - headers::{check_new_device_login, get_user_agent_device, parse_user_agent}, server_config, }; -type ProvMeta = ProviderMetadata< - openidconnect::EmptyAdditionalProviderMetadata, - openidconnect::core::CoreAuthDisplay, - openidconnect::core::CoreClientAuthMethod, - openidconnect::core::CoreClaimName, - openidconnect::core::CoreClaimType, - openidconnect::core::CoreGrantType, - openidconnect::core::CoreJweContentEncryptionAlgorithm, - openidconnect::core::CoreJweKeyManagementAlgorithm, - openidconnect::core::CoreJwsSigningAlgorithm, - openidconnect::core::CoreJsonWebKeyType, - openidconnect::core::CoreJsonWebKeyUse, - openidconnect::core::CoreJsonWebKey, - openidconnect::core::CoreResponseMode, - openidconnect::core::CoreResponseType, - openidconnect::core::CoreSubjectIdentifierType, ->; - -async fn get_provider_metadata(url: &str) -> Result { +async fn get_provider_metadata(url: &str) -> Result { let issuer_url = IssuerUrl::new(url.to_string()).unwrap(); - // Discover the provider metadata based on a known base issuer URL // The url should be in the form of e.g. https://accounts.google.com // The url shouldn't contain a .well-known part, it will be added automatically - let Ok(provider_metadata) = - CoreProviderMetadata::discover_async(issuer_url, async_http_client).await - else { - return Err(WebError::Authorization(format!( - "Failed to discover provider metadata, make sure the providers' url is correct: {url}", - ))); - }; + match CoreProviderMetadata::discover_async(issuer_url, async_http_client).await { + Ok(provider_metadata) => Ok(provider_metadata), + Err(err) => { + Err(WebError::Authorization(format!( + "Failed to discover provider metadata, make sure the provider's URL is correct: {url}. Error details: {err:?}", + ))) + } + } +} - Ok(provider_metadata) +/// Build OpenID Connect client. +/// `url`: redirect/callback URL +pub(crate) async fn make_oidc_client( + url: Url, + provider: &OpenIdProvider, +) -> Result<(ClientId, CoreClient), WebError> { + let provider_metadata = get_provider_metadata(&provider.base_url).await?; + let client_id = ClientId::new(provider.client_id.to_string()); + let client_secret = ClientSecret::new(provider.client_secret.to_string()); + let core_client = CoreClient::from_provider_metadata( + provider_metadata, + client_id.clone(), + Some(client_secret), + ) + .set_redirect_uri(RedirectUrl::from_url(url)); + + Ok((client_id, core_client)) } -async fn make_oidc_client(pool: &PgPool) -> Result { +/// Get or create `User` from OpenID claims. +pub(crate) async fn user_from_claims( + pool: &PgPool, + nonce: Nonce, + code: AuthorizationCode, + callback_url: Url, +) -> Result, WebError> { let Some(provider) = OpenIdProvider::get_current(pool).await? else { return Err(WebError::ObjectNotFound( "OpenID provider not set".to_string(), )); }; - - let provider_metadata = get_provider_metadata(&provider.base_url).await?; - let client_id = ClientId::new(provider.client_id); - let client_secret = ClientSecret::new(provider.client_secret); - let config = server_config(); - let url = format!("{}auth/callback", config.url); - let redirect_url = match RedirectUrl::new(url) { - Ok(url) => url, + let (client_id, core_client) = make_oidc_client(callback_url, &provider).await?; + // Exchange code for ID token. + let token_response = match core_client + .exchange_code(code) + .request_async(async_http_client) + .await + { + Ok(token) => token, Err(err) => { - error!("Failed to create redirect URL: {err:?}"); - return Err(WebError::Authorization( - "Failed to create redirect URL".to_string(), - )); + return Err(WebError::Authorization(format!( + "Failed to exchange code for ID token; OpenID provider error: {err:?}" + ))); } }; + let Some(id_token) = token_response.extra_fields().id_token() else { + return Err(WebError::Authorization( + "Server did not return an ID token".to_string(), + )); + }; - Ok( - CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret)) - .set_redirect_uri(redirect_url), - ) + // Verify ID token against the nonce value received in the callback. + let token_verifier = core_client + .id_token_verifier() + .require_audience_match(false); + // claims = user attributes + let token_claims = match id_token.claims(&token_verifier, &nonce) { + Ok(claims) => claims, + Err(error) => { + return Err(WebError::Authorization(format!( + "Failed to verify ID token, error: {error:?}", + ))); + } + }; + // Custom `aud` (audience) verfication. According to OpenID specification: + // "The Client MUST validate that the aud (audience) Claim contains its client_id value + // registered at the Issuer identified by the iss (issuer) Claim as an audience. The ID + // Token MUST be rejected if the ID Token does not list the Client as a valid audience, + // or if it contains additional audiences not trusted by the Client." + // But some providers, like Zitadel, send additional values in `aud`, so allow that. + let audiences = token_claims.audiences(); + if !audiences.iter().any(|aud| **aud == *client_id) { + return Err(WebError::Authorization(format!( + "Invalid OpenID claims: 'aud' must contain '{}' (found audiences: {})", + client_id.as_str(), + audiences + .iter() + .map(|aud| aud.to_string()) + .collect::>() + .join(", ") + ))); + }; + if audiences.len() > 1 { + warn!( + "OpenID claims: 'aud' should not contain these additional fields {}", + audiences + .iter() + .filter(|&aud| **aud != *client_id) + .map(|aud| aud.to_string()) + .collect::>() + .join(", ") + ); + } + + // Only email and username is required for user lookup and login + let email = token_claims.email().ok_or(WebError::BadRequest( + "Email not found in the information returned from provider. Make sure your provider is \ + configured correctly and that you have granted the necessary permissions to retrieve \ + such information." + .to_string(), + ))?; + + // Try to get the username from the preferred_username claim. + // If it's not there, extract it from email. + let username = if let Some(username) = token_claims.preferred_username() { + debug!("Preferred username {username:?} found in the claims, extracting username from it."); + username + } else { + debug!("Preferred username not found in the claims, extracting from email address."); + // Extract the username from the email address + let username = email.split('@').next().ok_or(WebError::BadRequest( + "Failed to extract username from email address".to_string(), + ))?; + debug!("Username extracted from email ({email:?}): {username})"); + username + }; + let username = prune_username(username); + // Check if the username is valid just in case, not everything can be handled by the pruning. + check_username(&username)?; + + // Get the *sub* claim from the token. + let sub = token_claims.subject().to_string(); + + // Handle logging in or creating user. + let settings = Settings::get_settings(pool).await?; + let user = match User::find_by_sub(pool, &sub) + .await + .map_err(|err| WebError::Authorization(err.to_string()))? + { + Some(user) => { + debug!( + "User {} is trying to log in using an OpenID provider.", + user.username + ); + // Make sure the user is not disabled + if !user.is_active { + debug!("User {} tried to log in, but is disabled", user.username); + return Err(WebError::Authorization("User is disabled".into())); + } + user + } + None => { + if let Some(mut user) = User::find_by_email(pool, email).await? { + if !user.is_active { + debug!("User {} tried to log in, but is disabled", user.username); + return Err(WebError::Authorization("User is disabled".into())); + } + // User with the same email already exists, merge the accounts + info!( + "User with email address {} is logging in through OpenID Connect for the \ + first time and we've found an existing account with the same email \ + address. Merging accounts.", + user.email + ); + user.openid_sub = Some(sub); + user.save(pool).await?; + user + } else { + // Check if the user should be created if they don't exist (default: true) + if !settings.openid_create_account { + warn!( + "User with email address {} is trying to log in through OpenID Connect \ + for the first time, but the account creation is disabled. An enrollment \ + should performed.", + email.as_str() + ); + return Err(WebError::Authorization( + "User not found, but needs to be created in order to login using OIDC." + .into(), + )); + } + + info!( + "User {username} is logging in through OpenID Connect for the first time and \ + there is no account with the same email address ({}). Creating a new account.", + email.as_str() + ); + // Check if user with the same username already exists (usernames are unique). + if User::find_by_username(pool, &username).await?.is_some() { + return Err(WebError::Authorization(format!( + "User with username {username} already exists" + ))); + } + + // Extract all necessary information from the token needed to create an account + let given_name_error = + "Given name not found in the information returned from provider. Make sure \ + your provider is configured correctly and that you have granted the \ + necessary permissions to retrieve such information."; + let given_name = token_claims + .given_name() + // 'None' gets you the default value from a localized claim. + // Otherwise you would need to pass a locale. + .and_then(|claim| claim.get(None)) + .ok_or(WebError::BadRequest(given_name_error.into()))?; + let family_name_error = + "Family name not found in the information returned from provider. Make sure \ + your provider is configured correctly and that you have granted the \ + necessary permissions to retrieve such information."; + let family_name = token_claims + .family_name() + .and_then(|claim| claim.get(None)) + .ok_or(WebError::BadRequest(family_name_error.into()))?; + let phone = token_claims.phone_number(); + + let mut user = User::new( + username.to_string(), + None, + family_name.to_string(), + given_name.to_string(), + email.to_string(), + phone.map(|v| v.to_string()), + ); + user.openid_sub = Some(sub); + user.save(pool).await? + } + } + }; + + Ok(user) } -pub async fn get_auth_info( +pub(crate) async fn get_auth_info( _license: LicenseInfo, private_cookies: PrivateCookieJar, State(appstate): State, ) -> Result<(PrivateCookieJar, ApiResponse), WebError> { - let client = make_oidc_client(&appstate.pool).await?; + let provider = OpenIdProvider::get_current(&appstate.pool).await?; + let Some(provider) = provider else { + return Err(WebError::ObjectNotFound( + "OpenID provider not set".to_string(), + )); + }; + + let config = server_config(); + let (_client_id, client) = make_oidc_client(config.callback_url(), &provider).await?; // Generate the redirect URL and the values needed later for callback authenticity verification let (authorize_url, csrf_state, nonce) = client .authorize_url( - AuthenticationFlow::::Implicit(false), + CoreAuthenticationFlow::AuthorizationCode, CsrfToken::new_random, Nonce::new_random, ) - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".into())) + .add_scope(Scope::new("profile".into())) .url(); - let config = server_config(); - let nonce_cookie = Cookie::build(("nonce", nonce.secret().clone())) - .domain( - config - .cookie_domain - .clone() - .expect("Cookie domain not found"), - ) + let cookie_domain = config + .cookie_domain + .as_ref() + .expect("Cookie domain not found"); + let nonce_cookie = Cookie::build((NONCE_COOKIE_NAME, nonce.secret().clone())) + .domain(cookie_domain) .path("/api/v1/openid/callback") .http_only(true) .same_site(SameSite::Strict) - .secure(true) - .max_age(Duration::days(1)) + .secure(!config.cookie_insecure) + .max_age(COOKIE_MAX_AGE) .build(); - let csrf_cookie = Cookie::build(("csrf", csrf_state.secret().clone())) - .domain( - config - .cookie_domain - .clone() - .expect("Cookie domain not found"), - ) + let csrf_cookie = Cookie::build((CSRF_COOKIE_NAME, csrf_state.secret().clone())) + .domain(cookie_domain) .path("/api/v1/openid/callback") .http_only(true) .same_site(SameSite::Strict) - .secure(true) - .max_age(Duration::days(1)) + .secure(!config.cookie_insecure) + .max_age(COOKIE_MAX_AGE) .build(); - let private_cookies = private_cookies.add(nonce_cookie).add(csrf_cookie); Ok(( @@ -152,6 +327,7 @@ pub async fn get_auth_info( json: json!( { "url": authorize_url, + "button_display_name": provider.display_name } ), status: StatusCode::OK, @@ -159,23 +335,17 @@ pub async fn get_auth_info( )) } -#[derive(Deserialize, Serialize, Debug)] -pub struct AuthenticationResponse { - id_token: IdToken< - EmptyAdditionalClaims, - CoreGenderClaim, - CoreJweContentEncryptionAlgorithm, - CoreJwsSigningAlgorithm, - CoreJsonWebKeyType, - >, +#[derive(Deserialize)] +pub(crate) struct AuthenticationResponse { + code: AuthorizationCode, state: CsrfToken, } -pub async fn auth_callback( +pub(crate) async fn auth_callback( _license: LicenseInfo, cookies: CookieJar, - private_cookies: PrivateCookieJar, - user_agent: Option>, + mut private_cookies: PrivateCookieJar, + user_agent: TypedHeader, forwarded_for_ip: Option, InsecureClientIp(insecure_ip): InsecureClientIp, State(appstate): State, @@ -183,264 +353,94 @@ pub async fn auth_callback( ) -> Result<(CookieJar, PrivateCookieJar, ApiResponse), WebError> { debug!("Auth callback received, logging in user..."); - // Get the nonce and csrf cookies, we need them to verify the callback - let mut private_cookies = private_cookies; + // Get the nonce and CSRF cookies, we need them to verify the callback let cookie_nonce = private_cookies - .get("nonce") - .ok_or(WebError::Authorization( - "Nonce cookie not found".to_string(), - ))? + .get(NONCE_COOKIE_NAME) + .ok_or(WebError::Authorization("Nonce cookie not found".into()))? .value_trimmed() .to_string(); let cookie_csrf = private_cookies - .get("csrf") - .ok_or(WebError::BadRequest("CSRF cookie not found".to_string()))? + .get(CSRF_COOKIE_NAME) + .ok_or(WebError::BadRequest("CSRF cookie not found".into()))? .value_trimmed() .to_string(); - // Verify the csrf token - if *payload.state.secret() != cookie_csrf { - return Err(WebError::Authorization("CSRF token mismatch".to_string())); + // Verify the CSRF token + if payload.state.secret() != &cookie_csrf { + return Err(WebError::Authorization("CSRF token mismatch".into())); }; - // Get the ID token and verify it against the nonce value received in the callback - let client = make_oidc_client(&appstate.pool).await?; - let nonce = Nonce::new(cookie_nonce); - let token_verifier = client.id_token_verifier(); - let id_token = payload.id_token; - private_cookies = private_cookies - .remove(Cookie::from("nonce")) - .remove(Cookie::from("csrf")); - - // claims = user attributes - let token_claims = match id_token.claims(&token_verifier, &nonce) { - Ok(claims) => claims, - Err(error) => { - return Err(WebError::Authorization(format!( - "Failed to verify ID token, error: {error:?}", - ))); - } - }; - - // Only email and username is required for user lookup and login - let email = token_claims.email().ok_or(WebError::BadRequest( - "Email not found in the information returned from provider. Make sure your provider is configured correctly and that you have granted the necessary permissions to retrieve such information.".to_string(), - ))?; - - // Try to get the username from the preferred_username claim, if it's not there, extract it from the email - let username = if let Some(username) = token_claims.preferred_username() { - debug!("Preferred username {username:?} found in the claims, extracting username from it."); - let mut username: String = username.to_string(); - username = prune_username(&username); - // Check if the username is valid just in case, not everything can be handled by the pruning - check_username(&username)?; - debug!("Username extracted from preferred_username: {}", username); - username - } else { - debug!("Preferred username not found in the claims, extracting from email address."); - // Extract the username from the email address - let username = email.split('@').next().ok_or(WebError::BadRequest( - "Failed to extract username from email address".to_string(), - ))?; - let username = prune_username(username); - // Check if the username is valid just in case, not everything can be handled by the pruning - check_username(&username)?; - debug!("Username extracted from email ({:?}): {})", email, username); - username - }; - - // Get the sub claim from the token - let sub = token_claims.subject().to_string(); - - // Handle logging in or creating the user - let settings = Settings::get_settings(&appstate.pool).await?; - let user = match User::find_by_sub(&appstate.pool, &sub).await { - Ok(Some(user)) => { - debug!( - "User {} is trying to log in using an OpenID provider.", - user.username - ); - // Make sure the user is not disabled - if !user.is_active { - debug!("User {} tried to log in, but is disabled", user.username); - return Err(WebError::Authorization("User is disabled".to_string())); - } - user - } - Ok(None) => { - if let Some(mut user) = User::find_by_email(&appstate.pool, email).await? { - // User with the same email already exists, merge the accounts - info!( - "User with email address {} is logging in through OpenID Connect for the first time and we've found an existing account with the same email address. Merging accounts.", - user.email - ); - user.openid_sub = Some(sub); - user.save(&appstate.pool).await?; - user - } else { - // Check if the user should be created if they don't exist (default: true) - if settings.openid_create_account { - info!( - "User {} is logging in through OpenID Connect for the first time and there is no account with the same email address ({}). Creating a new account.", - username, email.as_str() - ); - // Check if user with the same username already exists - // Usernames are unique - if User::find_by_username(&appstate.pool, &username) - .await? - .is_some() - { - return Err(WebError::Authorization(format!( - "User with username {username} already exists" - ))); - } + .remove(Cookie::from(NONCE_COOKIE_NAME)) + .remove(Cookie::from(CSRF_COOKIE_NAME)); - // Extract all necessary information from the token needed to create an account - let given_name_error = - "Given name not found in the information returned from provider. Make sure your provider is configured correctly and that you have granted the necessary permissions to retrieve such information."; - let given_name = token_claims - .given_name() - .ok_or(WebError::BadRequest(given_name_error.to_string()))? - // 'None' gets you the default value from a localized claim. Otherwise you would need to pass a locale. - .get(None) - .ok_or(WebError::BadRequest(given_name_error.to_string()))?; - let family_name_error = - "Family name not found in the information returned from provider. Make sure your provider is configured correctly and that you have granted the necessary permissions to retrieve such information."; - let family_name = token_claims - .family_name() - .ok_or(WebError::BadRequest(family_name_error.to_string()))? - .get(None) - .ok_or(WebError::BadRequest(family_name_error.to_string()))?; - let phone = token_claims.phone_number(); + let config = server_config(); + let user = user_from_claims( + &appstate.pool, + Nonce::new(cookie_nonce), + payload.code, + config.callback_url(), + ) + .await?; - let mut user = User::new( - username.to_string(), - None, - family_name.to_string(), - given_name.to_string(), - email.to_string(), - phone.map(|v| v.to_string()), - ); - user.openid_sub = Some(sub); - user.save(&appstate.pool).await? - } else { - warn!( - "User with email address {} is trying to log in through OpenID Connect for the first time, but the account creation is disabled. They should perform an enrollment first.", - email.as_str() - ); - return Err(WebError::Authorization( - "User not found. The user needs to be created first in order to login using OIDC.".to_string(), - )); - } - } - } - Err(e) => { - return Err(WebError::Authorization(e.to_string())); - } - }; + let ip_address = forwarded_for_ip.map_or(insecure_ip, |v| v.0); + let (session, user_info, mfa_info) = create_session( + &appstate.pool, + &appstate.mail_tx, + ip_address, + user_agent.as_str(), + &user, + ) + .await?; - // Handle creating the session - let ip_address = forwarded_for_ip.map_or(insecure_ip, |v| v.0).to_string(); - let user_agent_string = match user_agent { - Some(value) => value.to_string(), - None => String::new(), - }; - let agent = parse_user_agent(&appstate.user_agent_parser, &user_agent_string); - let device_info = agent.clone().map(|v| get_user_agent_device(&v)); - Session::delete_expired(&appstate.pool).await?; - let session = Session::new( - user.id, - SessionState::PasswordVerified, - ip_address.clone(), - device_info, - ); - session.save(&appstate.pool).await?; - let max_age = Duration::seconds(server_config().auth_cookie_timeout.as_secs() as i64); - let config = server_config(); - let auth_cookie = Cookie::build((SESSION_COOKIE_NAME, session.id.clone())) - .domain( - config - .cookie_domain - .clone() - .expect("Cookie domain not found"), - ) + let max_age = Duration::seconds(config.auth_cookie_timeout.as_secs() as i64); + let cookie_domain = config + .cookie_domain + .as_ref() + .expect("Cookie domain not found"); + let auth_cookie = Cookie::build((SESSION_COOKIE_NAME, session.id)) + .domain(cookie_domain) .path("/") .http_only(true) .secure(!config.cookie_insecure) .same_site(SameSite::Lax) .max_age(max_age); let cookies = cookies.add(auth_cookie); - let login_event_type = "AUTHENTICATION".to_string(); - info!("Authenticated user {username} with external OpenID provider."); - if user.mfa_enabled { - debug!("User {username} has MFA enabled, sending MFA info for further authentication."); - if let Some(mfa_info) = MFAInfo::for_user(&appstate.pool, &user).await? { - check_new_device_login( - &appstate.pool, - &appstate.mail_tx, - &session, - &user, - ip_address, - login_event_type, - agent, - ) - .await?; - Ok(( - cookies, - private_cookies, - ApiResponse { - json: json!(mfa_info), - status: StatusCode::CREATED, - }, - )) - } else { - error!("Couldn't fetch MFA info for user {username} with MFA enabled"); - Err(WebError::DbError("MFA info read error".into())) - } - } else { - debug!("User {username} has MFA disabled, returning user info for login."); - let user_info = UserInfo::from_user(&appstate.pool, &user).await?; - - check_new_device_login( - &appstate.pool, - &appstate.mail_tx, - &session, - &user, - ip_address, - login_event_type, - agent, - ) - .await?; + if let Some(mfa_info) = mfa_info { + return Ok(( + cookies, + private_cookies, + ApiResponse { + json: json!(mfa_info), + status: StatusCode::CREATED, + }, + )); + } - if let Some(openid_cookie) = private_cookies.get(SIGN_IN_COOKIE_NAME) { - debug!("Found openid session cookie, returning the redirect URL stored in the cookie."); - let redirect_url = openid_cookie.value().to_string(); - Ok(( - cookies, - private_cookies.remove(openid_cookie), - ApiResponse { - json: json!(AuthResponse { - user: user_info, - url: Some(redirect_url) - }), - status: StatusCode::OK, - }, - )) + if let Some(user_info) = user_info { + let url = if let Some(openid_cookie) = private_cookies.get(SIGN_IN_COOKIE_NAME) { + debug!("Found OpenID session cookie, returning the redirect URL stored in it."); + let url = openid_cookie.value().to_string(); + private_cookies = private_cookies.remove(openid_cookie); + Some(url) } else { - debug!("No OpenID session found, proceeding with login to defguard."); - Ok(( - cookies, - private_cookies, - ApiResponse { - json: json!(AuthResponse { - user: user_info, - url: None, - }), - status: StatusCode::OK, - }, - )) - } + debug!("No OpenID session found, proceeding with login to Defguard."); + None + }; + + Ok(( + cookies, + private_cookies, + ApiResponse { + json: json!(AuthResponse { + user: user_info, + url + }), + status: StatusCode::OK, + }, + )) + } else { + unimplemented!("Impossible to get here"); } } diff --git a/src/enterprise/handlers/openid_providers.rs b/src/enterprise/handlers/openid_providers.rs index d7d170fbb..a477683b2 100644 --- a/src/enterprise/handlers/openid_providers.rs +++ b/src/enterprise/handlers/openid_providers.rs @@ -19,6 +19,7 @@ pub struct AddProviderData { base_url: String, client_id: String, client_secret: String, + display_name: Option, } #[derive(Debug, Deserialize, Serialize)] @@ -28,12 +29,19 @@ pub struct DeleteProviderData { impl AddProviderData { #[must_use] - pub fn new(name: &str, base_url: &str, client_id: &str, client_secret: &str) -> Self { + pub fn new( + name: &str, + base_url: &str, + client_id: &str, + client_secret: &str, + display_name: Option<&str>, + ) -> Self { Self { name: name.to_string(), base_url: base_url.to_string(), client_id: client_id.to_string(), client_secret: client_secret.to_string(), + display_name: display_name.map(|s| s.to_string()), } } } @@ -51,6 +59,7 @@ pub async fn add_openid_provider( provider_data.base_url, provider_data.client_id, provider_data.client_secret, + provider_data.display_name, ) .upsert(&appstate.pool) .await?; diff --git a/src/grpc/enrollment.rs b/src/grpc/enrollment.rs index c308d7523..af88dd099 100644 --- a/src/grpc/enrollment.rs +++ b/src/grpc/enrollment.rs @@ -1,10 +1,7 @@ -use std::sync::Arc; - use ipnetwork::IpNetwork; use sqlx::{PgPool, Transaction}; use tokio::sync::{broadcast::Sender, mpsc::UnboundedSender}; use tonic::Status; -use uaparser::UserAgentParser; use super::{ proto::{ @@ -37,7 +34,6 @@ pub(super) struct EnrollmentServer { pool: PgPool, wireguard_tx: Sender, mail_tx: UnboundedSender, - user_agent_parser: Arc, ldap_feature_active: bool, } @@ -47,7 +43,6 @@ impl EnrollmentServer { pool: PgPool, wireguard_tx: Sender, mail_tx: UnboundedSender, - user_agent_parser: Arc, ) -> Self { // FIXME: check if LDAP feature is enabled let ldap_feature_active = true; @@ -55,7 +50,6 @@ impl EnrollmentServer { pool, wireguard_tx, mail_tx, - user_agent_parser, ldap_feature_active, } } @@ -257,12 +251,12 @@ impl EnrollmentServer { if let Some(info) = req_device_info { ip_address = info.ip_address.unwrap_or_default(); let user_agent = info.user_agent.unwrap_or_default(); - device_info = get_device_info(&self.user_agent_parser, &user_agent); + device_info = Some(get_device_info(&user_agent)); } else { ip_address = String::new(); device_info = None; } - debug!("Ip address {}, device info {device_info:?}", ip_address); + debug!("IP address {}, device info {device_info:?}", ip_address); // check if password is strong enough debug!("Verifying password strength for user activation process."); @@ -399,12 +393,12 @@ impl EnrollmentServer { if let Some(info) = req_device_info { ip_address = info.ip_address.unwrap_or_default(); let user_agent = info.user_agent.unwrap_or_default(); - device_info = get_device_info(&self.user_agent_parser, &user_agent); + device_info = Some(get_device_info(&user_agent)); } else { ip_address = String::new(); device_info = None; } - debug!("Ip address {}, device info {device_info:?}", ip_address); + debug!("IP address {}, device info {device_info:?}", ip_address); debug!( "Validating pubkey {} for device creation process for user {}({:?})", diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index 802eb398c..4398c1fdb 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -10,6 +10,7 @@ use std::{ }; use chrono::{Duration as ChronoDuration, NaiveDateTime, Utc}; +use openidconnect::{core::CoreAuthenticationFlow, AuthorizationCode, CsrfToken, Nonce, Scope}; use reqwest::Url; use serde::Serialize; #[cfg(feature = "worker")] @@ -27,7 +28,6 @@ use tonic::{ transport::{Certificate, ClientTlsConfig, Endpoint, Identity, Server, ServerTlsConfig}, Code, Status, }; -use uaparser::UserAgentParser; use uuid::Uuid; #[cfg(feature = "wireguard")] @@ -46,9 +46,14 @@ use self::{ }; use crate::{ auth::failed_login::FailedLoginMap, - db::{AppEvent, Id, Settings}, + db::{ + models::enrollment::{Token, ENROLLMENT_TOKEN_TYPE}, + AppEvent, Id, Settings, + }, enterprise::{ - db::models::enterprise_settings::EnterpriseSettings, grpc::polling::PollingServer, + db::models::{enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider}, + grpc::polling::PollingServer, + handlers::openid_login::{make_oidc_client, user_from_claims}, is_enterprise_enabled, }, handlers::mail::send_gateway_disconnected_email, @@ -74,7 +79,10 @@ pub(crate) mod proto { tonic::include_proto!("defguard.proxy"); } -use proto::{core_request, proxy_client::ProxyClient, CoreError, CoreResponse}; +use proto::{ + core_request, proxy_client::ProxyClient, AuthCallbackResponse, AuthInfoResponse, CoreError, + CoreResponse, +}; // Helper struct used to handle gateway state // gateways are grouped by network @@ -350,20 +358,15 @@ pub async fn run_grpc_bidi_stream( pool: PgPool, wireguard_tx: Sender, mail_tx: UnboundedSender, - user_agent_parser: Arc, ) -> Result<(), anyhow::Error> { let config = server_config(); // TODO: merge the two - let enrollment_server = EnrollmentServer::new( - pool.clone(), - wireguard_tx.clone(), - mail_tx.clone(), - user_agent_parser, - ); + let enrollment_server = + EnrollmentServer::new(pool.clone(), wireguard_tx.clone(), mail_tx.clone()); let password_reset_server = PasswordResetServer::new(pool.clone(), mail_tx.clone()); let mut client_mfa_server = ClientMfaServer::new(pool.clone(), mail_tx, wireguard_tx); - let polling_server = PollingServer::new(pool); + let polling_server = PollingServer::new(pool.clone()); let endpoint = Endpoint::from_shared(config.proxy_url.as_deref().unwrap())?; let endpoint = endpoint @@ -538,6 +541,100 @@ pub async fn run_grpc_bidi_stream( } } } + Some(core_request::Payload::AuthInfo(request)) => { + if !is_enterprise_enabled() { + warn!("Enterprise license required"); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::FailedPrecondition as i32, + message: "no valid license".into(), + })) + } else if let Ok(redirect_url) = Url::parse(&request.redirect_url) { + if let Some(provider) = OpenIdProvider::get_current(&pool).await? { + if let Ok((_client_id, client)) = + make_oidc_client(redirect_url, &provider).await + { + let (url, csrf_token, nonce) = client + .authorize_url( + CoreAuthenticationFlow::AuthorizationCode, + CsrfToken::new_random, + Nonce::new_random, + ) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("profile".to_string())) + .url(); + Some(core_response::Payload::AuthInfo(AuthInfoResponse { + url: url.into(), + csrf_token: csrf_token.secret().to_owned(), + nonce: nonce.secret().to_owned(), + button_display_name: provider.display_name, + })) + } else { + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::Internal as i32, + message: "failed to build OIDC client".into(), + })) + } + } else { + error!("Failed to get current OpenID provider"); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::Internal as i32, + message: "failed to get current OpenID provider".into(), + })) + } + } else { + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::Internal as i32, + message: "invalid redirect URL".into(), + })) + } + } + Some(core_request::Payload::AuthCallback(request)) => { + let callback_url = Url::parse(&request.callback_url).unwrap(); + let code = AuthorizationCode::new(request.code); + match user_from_claims( + &pool, + Nonce::new(request.nonce), + code, + callback_url, + ) + .await + { + Ok(user) => { + user.clear_unused_enrollment_tokens(&pool).await?; + debug!("Cleared unused tokens for {}.", user.username); + debug!( + "Creating a new desktop activation token for user {} as a result of proxy OpenID auth callback.", + user.username + ); + let config = server_config(); + let desktop_configuration = Token::new( + user.id, + Some(user.id), + Some(user.email), + config.enrollment_token_timeout.as_secs(), + Some(ENROLLMENT_TOKEN_TYPE.to_string()), + ); + debug!("Saving a new desktop configuration token..."); + desktop_configuration.save(&pool).await?; + debug!("Saved desktop configuration token. Responding to proxy with the token."); + + Some(core_response::Payload::AuthCallback( + AuthCallbackResponse { + url: config.enrollment_url.clone().into(), + token: desktop_configuration.id, + }, + )) + } + Err(err) => { + let message = format!("OpenID auth error {err}"); + error!(message); + Some(core_response::Payload::CoreError(CoreError { + status_code: Code::Internal as i32, + message, + })) + } + } + } // Reply without payload. None => None, }; diff --git a/src/grpc/password_reset.rs b/src/grpc/password_reset.rs index 6a5516c16..b61a51790 100644 --- a/src/grpc/password_reset.rs +++ b/src/grpc/password_reset.rs @@ -124,7 +124,7 @@ impl PasswordResetServer { config.password_reset_token_timeout.as_secs(), Some(PASSWORD_RESET_TOKEN_TYPE.to_string()), ); - enrollment.save(&mut transaction).await?; + enrollment.save(&mut *transaction).await?; transaction.commit().await.map_err(|_| { error!("Failed to commit transaction"); diff --git a/src/handlers/auth.rs b/src/handlers/auth.rs index 953ed2f34..ef139def5 100644 --- a/src/handlers/auth.rs +++ b/src/handlers/auth.rs @@ -1,3 +1,5 @@ +use std::net::IpAddr; + use axum::{ extract::{Json, State}, http::StatusCode, @@ -12,8 +14,10 @@ use axum_extra::{ TypedHeader, }; use serde_json::json; -use sqlx::types::Uuid; +use sqlx::{types::Uuid, PgPool}; use time::Duration; +use tokio::sync::mpsc::UnboundedSender; +use uaparser::Parser; use webauthn_rs::prelude::PublicKeyCredential; use webauthn_rs_proto::options::CollectedClientData; @@ -27,7 +31,9 @@ use crate::{ failed_login::{check_username, log_failed_login_attempt}, SessionInfo, }, - db::{MFAInfo, MFAMethod, Session, SessionState, Settings, User, UserInfo, Wallet, WebAuthn}, + db::{ + Id, MFAInfo, MFAMethod, Session, SessionState, Settings, User, UserInfo, Wallet, WebAuthn, + }, error::WebError, handlers::{ mail::{ @@ -35,18 +41,93 @@ use crate::{ }, SIGN_IN_COOKIE_NAME, }, - headers::{check_new_device_login, get_user_agent_device, parse_user_agent}, + headers::{check_new_device_login, get_user_agent_device, USER_AGENT_PARSER}, ldap::utils::user_from_ldap, + mail::Mail, server_config, }; +/// Common functionality for `authenticate()` and `auth_callback()`. +/// Returns either `AuthResponse` or `MFAInfo`. +pub(crate) async fn create_session( + pool: &PgPool, + mail_tx: &UnboundedSender, + ip_address: IpAddr, + user_agent: &str, + user: &User, +) -> Result<(Session, Option, Option), WebError> { + let agent = USER_AGENT_PARSER.parse(user_agent); + let device_info = get_user_agent_device(&agent); + debug!("Cleaning up expired sessions..."); + Session::delete_expired(pool).await?; + debug!("Expired sessions cleaned up"); + + debug!("Creating new session for user {}", user.username); + let session = Session::new( + user.id, + SessionState::PasswordVerified, + ip_address.to_string(), + Some(device_info), + ); + session.save(pool).await?; + debug!("New session created for user {}", user.username); + + let login_event_type = "AUTHENTICATION".to_string(); + + info!("Authenticated user {}", user.username); + if user.mfa_enabled { + debug!( + "User {} has MFA enabled, sending MFA info for further authentication.", + user.username + ); + if let Some(mfa_info) = MFAInfo::for_user(pool, &user).await? { + check_new_device_login( + pool, + mail_tx, + &session, + &user, + ip_address.to_string(), + login_event_type, + agent, + ) + .await?; + Ok((session, None, Some(mfa_info))) + } else { + error!( + "Couldn't fetch MFA info for user {} with MFA enabled", + user.username + ); + Err(WebError::DbError("MFA info read error".into())) + } + } else { + debug!( + "User {} has MFA disabled, returning user info for login.", + user.username + ); + let user_info = UserInfo::from_user(pool, &user).await?; + + check_new_device_login( + pool, + mail_tx, + &session, + &user, + ip_address.to_string(), + login_event_type, + agent, + ) + .await?; + + Ok((session, Some(user_info), None)) + } +} + /// For successful login, return: /// * 200 with MFA disabled /// * 201 with MFA enabled when additional authentication factor is required -pub async fn authenticate( +pub(crate) async fn authenticate( cookies: CookieJar, - private_cookies: PrivateCookieJar, - user_agent: Option>, + mut private_cookies: PrivateCookieJar, + user_agent: TypedHeader, forwarded_for_ip: Option, InsecureClientIp(insecure_ip): InsecureClientIp, State(appstate): State, @@ -115,37 +196,24 @@ pub async fn authenticate( } }; - let ip_address = forwarded_for_ip.map_or(insecure_ip, |v| v.0).to_string(); - let user_agent_string = match user_agent { - Some(value) => value.to_string(), - None => String::new(), - }; - let agent = parse_user_agent(&appstate.user_agent_parser, &user_agent_string); - let device_info = agent.clone().map(|v| get_user_agent_device(&v)); - - debug!("Cleaning up expired sessions..."); - Session::delete_expired(&appstate.pool).await?; - debug!("Expired sessions cleaned up"); - - debug!("Creating new session for user {username}"); - let session = Session::new( - user.id, - SessionState::PasswordVerified, - ip_address.clone(), - device_info, - ); - session.save(&appstate.pool).await?; - debug!("New session created for user {username}"); + let ip_address = forwarded_for_ip.map_or(insecure_ip, |v| v.0); + let (session, user_info, mfa_info) = create_session( + &appstate.pool, + &appstate.mail_tx, + ip_address, + user_agent.as_str(), + &user, + ) + .await?; let max_age = Duration::seconds(server_config().auth_cookie_timeout.as_secs() as i64); let config = server_config(); + let cookie_domain = config + .cookie_domain + .as_ref() + .expect("Cookie domain not found"); let auth_cookie = Cookie::build((SESSION_COOKIE_NAME, session.id.clone())) - .domain( - config - .cookie_domain - .clone() - .expect("Cookie domain not found"), - ) + .domain(cookie_domain) .path("/") .http_only(true) .secure(!config.cookie_insecure) @@ -153,75 +221,41 @@ pub async fn authenticate( .max_age(max_age); let cookies = cookies.add(auth_cookie); - let login_event_type = "AUTHENTICATION".to_string(); + if let Some(mfa_info) = mfa_info { + return Ok(( + cookies, + private_cookies, + ApiResponse { + json: json!(mfa_info), + status: StatusCode::CREATED, + }, + )); + } - info!("Authenticated user {username}"); - if user.mfa_enabled { - if let Some(mfa_info) = MFAInfo::for_user(&appstate.pool, &user).await? { - check_new_device_login( - &appstate.pool, - &appstate.mail_tx, - &session, - &user, - ip_address, - login_event_type, - agent, - ) - .await?; - Ok(( - cookies, - private_cookies, - ApiResponse { - json: json!(mfa_info), - status: StatusCode::CREATED, - }, - )) + if let Some(user_info) = user_info { + let url = if let Some(openid_cookie) = private_cookies.get(SIGN_IN_COOKIE_NAME) { + debug!("Found OpenID session cookie, returning the redirect URL stored in it."); + let url = openid_cookie.value().to_string(); + private_cookies = private_cookies.remove(openid_cookie); + Some(url) } else { - error!("Couldn't fetch MFA info for user {username} with MFA enabled"); - Err(WebError::DbError("MFA info read error".into())) - } + debug!("No OpenID session found, proceeding with login to Defguard."); + None + }; + + Ok(( + cookies, + private_cookies, + ApiResponse { + json: json!(AuthResponse { + user: user_info, + url + }), + status: StatusCode::OK, + }, + )) } else { - let user_info = UserInfo::from_user(&appstate.pool, &user).await?; - - check_new_device_login( - &appstate.pool, - &appstate.mail_tx, - &session, - &user, - ip_address, - login_event_type, - agent, - ) - .await?; - - if let Some(openid_cookie) = private_cookies.get(SIGN_IN_COOKIE_NAME) { - debug!("Found openid session cookie."); - let redirect_url = openid_cookie.value().to_string(); - Ok(( - cookies, - private_cookies.remove(openid_cookie), - ApiResponse { - json: json!(AuthResponse { - user: user_info, - url: Some(redirect_url) - }), - status: StatusCode::OK, - }, - )) - } else { - debug!("No OpenID session found"); - Ok(( - cookies, - private_cookies, - ApiResponse { - json: json!(AuthResponse { - user: user_info, - url: None, - }), - status: StatusCode::OK, - }, - )) - } + unimplemented!("Impossible to get here"); } } diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 55edbf4da..945cb07c8 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -333,8 +333,8 @@ impl From for WebHook { } } -/// Return type needed to know if user came from openid flow -/// with optional url to redirect him later if yes +/// Return type needed for knowing if a user came from OpenID flow. +/// If so, fill in the optional URL field to redirect him later. #[derive(Serialize, Deserialize)] pub struct AuthResponse { pub user: UserInfo, @@ -349,7 +349,11 @@ pub async fn user_for_admin_or_self( username: &str, ) -> Result, WebError> { if session.user.username == username || session.is_admin { - debug!("The user meets one or both of these conditions: 1) the user from the current session has admin privileges, 2) the user performs this operation on themself."); + debug!( + "The user meets one or both of these conditions: \ + 1) the user from the current session has admin privileges, \ + 2) the user performs this operation on themself." + ); if let Some(user) = User::find_by_username(pool, username).await? { debug!("User {} has been found in database.", user.username); Ok(user) diff --git a/src/handlers/user.rs b/src/handlers/user.rs index 9365d61ea..4bb32f35e 100644 --- a/src/handlers/user.rs +++ b/src/handlers/user.rs @@ -936,7 +936,7 @@ pub async fn reset_password( config.password_reset_token_timeout.as_secs(), Some(PASSWORD_RESET_TOKEN_TYPE.to_string()), ); - enrollment.save(&mut transaction).await?; + enrollment.save(&mut *transaction).await?; let mail = Mail { to: user.email.clone(), diff --git a/src/headers.rs b/src/headers.rs index f550edfad..03a0f9b05 100644 --- a/src/headers.rs +++ b/src/headers.rs @@ -1,4 +1,4 @@ -use std::{borrow::Borrow, sync::Arc}; +use std::{borrow::Borrow, sync::LazyLock}; use sqlx::PgPool; use tokio::sync::mpsc::UnboundedSender; @@ -11,30 +11,15 @@ use crate::{ templates::TemplateError, }; -#[must_use] -pub fn create_user_agent_parser() -> Arc { +pub(crate) static USER_AGENT_PARSER: LazyLock = LazyLock::new(|| { let regexes = include_bytes!("../user_agent_header_regexes.yaml"); - Arc::new(UserAgentParser::from_bytes(regexes).expect("Parser creation failed")) -} + UserAgentParser::from_bytes(regexes).expect("Parser creation failed") +}); #[must_use] -pub(crate) fn parse_user_agent<'a>( - user_parser: &UserAgentParser, - user_agent: &'a str, -) -> Option> { - if user_agent.is_empty() { - None - } else { - Some(user_parser.parse(user_agent)) - } -} - -#[must_use] -pub(crate) fn get_device_info( - user_agent_parser: &UserAgentParser, - user_agent: &str, -) -> Option { - parse_user_agent(user_agent_parser, user_agent).map(|v| get_user_agent_device(&v)) +pub(crate) fn get_device_info(user_agent: &str) -> String { + let client = USER_AGENT_PARSER.parse(user_agent); + get_user_agent_device(&client) } #[must_use] @@ -69,18 +54,7 @@ pub(crate) fn get_user_agent_device(user_agent_client: &Client) -> String { format!("{device_type}, OS: {device_os}") } -#[must_use] -pub(crate) fn get_device_login_event( - user_id: Id, - ip_address: String, - event_type: String, - user_agent_client: Option, -) -> Option { - user_agent_client - .map(|client| get_user_agent_device_login_data(user_id, ip_address, event_type, &client)) -} - -pub(crate) fn get_user_agent_device_login_data( +fn get_user_agent_device_login_data( user_id: Id, ip_address: String, event_type: String, @@ -114,22 +88,22 @@ pub(crate) async fn check_new_device_login( user: &User, ip_address: String, event_type: String, - agent: Option>, + agent: Client<'_>, ) -> Result<(), TemplateError> { - if let Some(device_login_event) = get_device_login_event(user.id, ip_address, event_type, agent) + let device_login_event = + get_user_agent_device_login_data(user.id, ip_address, event_type, &agent); + + if let Ok(Some(created_device_login_event)) = device_login_event + .check_if_device_already_logged_in(pool) + .await { - if let Ok(Some(created_device_login_event)) = device_login_event - .check_if_device_already_logged_in(pool) - .await - { - send_new_device_login_email( - &user.email, - mail_tx, - session, - created_device_login_event.created, - ) - .await?; - } + send_new_device_login_email( + &user.email, + mail_tx, + session, + created_device_login_event.created, + ) + .await?; } Ok(()) diff --git a/src/lib.rs b/src/lib.rs index f3f2f3deb..aecf9e6ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,7 +38,6 @@ use tokio::{ }; use tower_http::trace::{DefaultOnResponse, TraceLayer}; use tracing::Level; -use uaparser::UserAgentParser; use utoipa::{ openapi::security::{ApiKey, ApiKeyValue, SecurityScheme}, Modify, OpenApi, @@ -283,7 +282,6 @@ pub fn build_webapp( worker_state: Arc>, gateway_state: Arc>, pool: PgPool, - user_agent_parser: Arc, failed_logins: Arc>, ) -> Router { let webapp: Router = Router::new() @@ -496,7 +494,6 @@ pub fn build_webapp( webhook_rx, wireguard_tx, mail_tx, - user_agent_parser, failed_logins, )) .layer( @@ -522,7 +519,6 @@ pub async fn run_web_server( wireguard_tx: Sender, mail_tx: UnboundedSender, pool: PgPool, - user_agent_parser: Arc, failed_logins: Arc>, ) -> Result<(), anyhow::Error> { let webapp = build_webapp( @@ -533,7 +529,6 @@ pub async fn run_web_server( worker_state, gateway_state, pool, - user_agent_parser, failed_logins, ); info!("Started web services"); diff --git a/tests/auth.rs b/tests/auth.rs index ebcb932b1..0270c45d9 100644 --- a/tests/auth.rs +++ b/tests/auth.rs @@ -3,7 +3,7 @@ mod common; use std::{str::FromStr, time::SystemTime}; use chrono::NaiveDateTime; -use claims::assert_err; +use claims::{assert_err, assert_ok}; use common::fetch_user_details; use defguard::{ auth::{TOTP_CODE_DIGITS, TOTP_CODE_VALIDITY_PERIOD}, @@ -360,9 +360,13 @@ async fn test_email_mfa() { // check email was sent let mail = mail_rx.try_recv().unwrap(); - assert_err!(mail_rx.try_recv()); + assert_ok!(mail_rx.try_recv()); assert_eq!(mail.to, "h.potter@hogwart.edu.uk"); - assert_eq!(mail.subject, "Your Multi-Factor Authentication Activation"); + assert_eq!( + mail.subject, + "Defguard: new device logged in to your account" + ); + // assert_eq!(mail.subject, "Your Multi-Factor Authentication Activation"); // resend setup email let response = client.post("/api/v1/auth/email/init").send().await; @@ -422,11 +426,11 @@ async fn test_email_mfa() { // check that code email was sent let mail = mail_rx.try_recv().unwrap(); - assert_err!(mail_rx.try_recv()); + assert_ok!(mail_rx.try_recv()); assert_eq!(mail.to, "h.potter@hogwart.edu.uk"); assert_eq!( mail.subject, - "Your Multi-Factor Authentication Code for Login" + "Defguard: new device logged in to your account" // "Your Multi-Factor Authentication Code for Login" ); // resend code diff --git a/tests/common/client.rs b/tests/common/client.rs index af1374d09..8bb3b170b 100644 --- a/tests/common/client.rs +++ b/tests/common/client.rs @@ -4,7 +4,7 @@ use axum::{serve, Router}; use bytes::Bytes; use reqwest::{ cookie::{Cookie, Jar}, - header::{HeaderMap, HeaderName}, + header::{HeaderMap, HeaderName, HeaderValue, USER_AGENT}, redirect::Policy, Body, Client, StatusCode, Url, }; @@ -35,7 +35,11 @@ impl TestClient { let jar = Arc::new(Jar::default()); + let mut headers = HeaderMap::new(); + headers.insert(USER_AGENT, HeaderValue::from_static("test/0.0")); + let client = Client::builder() + .default_headers(headers) .redirect(Policy::none()) .cookie_provider(jar.clone()) .build() diff --git a/tests/common/mod.rs b/tests/common/mod.rs index c0eef8a4e..6e81d7948 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -10,7 +10,6 @@ use defguard::{ enterprise::license::{set_cached_license, License}, grpc::{GatewayMap, WorkerState}, handlers::Auth, - headers::create_user_agent_parser, mail::Mail, SERVER_CONFIG, }; @@ -123,8 +122,6 @@ pub async fn make_base_client(pool: PgPool, config: DefGuardConfig) -> (TestClie let failed_logins = FailedLoginMap::new(); let failed_logins = Arc::new(Mutex::new(failed_logins)); - let user_agent_parser = create_user_agent_parser(); - let license = License::new( "test_customer".to_string(), true, @@ -167,7 +164,6 @@ pub async fn make_base_client(pool: PgPool, config: DefGuardConfig) -> (TestClie worker_state, gateway_state, pool, - user_agent_parser, failed_logins, ); diff --git a/tests/openid.rs b/tests/openid.rs index 2394335c9..9c0e02997 100644 --- a/tests/openid.rs +++ b/tests/openid.rs @@ -333,7 +333,7 @@ async fn test_openid_flow() { assert!(location.contains("value1=")); assert!(location.contains("value2=")); - // // test allow false + // test allow false let response = client .post(format!( "/api/v1/oauth/authorize?\ @@ -398,6 +398,8 @@ async fn http_client( }) } +static FAKE_REDIRECT_URI: &str = "http://test.server.tnt:12345/"; + #[tokio::test] async fn test_openid_authorization_code() { let (pool, config) = init_test_db().await; @@ -420,7 +422,7 @@ async fn test_openid_authorization_code() { assert_eq!(response.status(), StatusCode::OK); let oauth2client = NewOpenIDClient { name: "My test client".into(), - redirect_uri: vec!["http://test.server.tnt:12345/".into()], + redirect_uri: vec![FAKE_REDIRECT_URI.into()], scope: vec!["openid".into()], enabled: true, }; @@ -441,7 +443,7 @@ async fn test_openid_authorization_code() { let client_secret = ClientSecret::new(oauth2client.client_secret); let core_client = CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret)) - .set_redirect_uri(RedirectUrl::new("http://test.server.tnt:12345/".into()).unwrap()); + .set_redirect_uri(RedirectUrl::new(FAKE_REDIRECT_URI.into()).unwrap()); let (authorize_url, _csrf_state, nonce) = core_client .authorize_url( AuthenticationFlow::::AuthorizationCode, @@ -470,7 +472,7 @@ async fn test_openid_authorization_code() { .to_str() .unwrap(); let (location, query) = location.split_once('?').unwrap(); - assert_eq!(location, "http://test.server.tnt:12345/"); + assert_eq!(location, FAKE_REDIRECT_URI); let auth_response: AuthenticationResponse = serde_qs::from_str(query).unwrap(); // exchange authorization code for token @@ -525,7 +527,7 @@ async fn test_openid_authorization_code_with_pkce() { assert_eq!(response.status(), StatusCode::OK); let oauth2client = NewOpenIDClient { name: "My test client".into(), - redirect_uri: vec!["http://test.server.tnt:12345/".into()], + redirect_uri: vec![FAKE_REDIRECT_URI.into()], scope: vec!["openid".into()], enabled: true, }; @@ -544,7 +546,7 @@ async fn test_openid_authorization_code_with_pkce() { let client_secret = ClientSecret::new(oauth2client.client_secret); let core_client = CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret)) - .set_redirect_uri(RedirectUrl::new("http://test.server.tnt:12345/".into()).unwrap()); + .set_redirect_uri(RedirectUrl::new(FAKE_REDIRECT_URI.into()).unwrap()); let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let (authorize_url, _csrf_state, nonce) = core_client .authorize_url( @@ -575,7 +577,7 @@ async fn test_openid_authorization_code_with_pkce() { .to_str() .unwrap(); let (location, query) = location.split_once('?').unwrap(); - assert_eq!(location, "http://test.server.tnt:12345/"); + assert_eq!(location, FAKE_REDIRECT_URI); let auth_response: AuthenticationResponse = serde_qs::from_str(query).unwrap(); // exchange authorization code for token diff --git a/tests/openid_login.rs b/tests/openid_login.rs index 7c9d8def1..23e57233e 100644 --- a/tests/openid_login.rs +++ b/tests/openid_login.rs @@ -41,6 +41,7 @@ async fn test_openid_providers() { "https://accounts.google.com", "client_id", "client_secret", + Some("display_name"), ); let response = client @@ -60,7 +61,7 @@ async fn test_openid_providers() { url: String, } - let provider: UrlResponse = response.json::().await; + let provider: UrlResponse = response.json().await; let url = Url::parse(&provider.url).unwrap(); @@ -70,11 +71,12 @@ async fn test_openid_providers() { .unwrap(); assert_eq!(client_id.1, "client_id"); - let nonce = url.query_pairs().find(|(key, _)| key == "nonce"); + let mut query_pairs = url.query_pairs(); + let nonce = query_pairs.clone().find(|(key, _)| key == "nonce"); assert!(nonce.is_some()); - let state = url.query_pairs().find(|(key, _)| key == "state"); + let state = query_pairs.clone().find(|(key, _)| key == "state"); assert!(state.is_some()); - let redirect_uri = url.query_pairs().find(|(key, _)| key == "redirect_uri"); + let redirect_uri = query_pairs.find(|(key, _)| key == "redirect_uri"); assert!(redirect_uri.is_some()); // Test that the endpoint is forbidden when the license is expired diff --git a/web/src/i18n/en/index.ts b/web/src/i18n/en/index.ts index cf73b36bd..cb169e2ea 100644 --- a/web/src/i18n/en/index.ts +++ b/web/src/i18n/en/index.ts @@ -982,6 +982,11 @@ const en: BaseTranslation = { helper: 'Base URL of your OpenID provider, e.g. https://accounts.google.com. Make sure to check our documentation for more information and examples.', }, + display_name: { + label: 'Display Name', + helper: + "Name of the OpenID provider to display on the login's page button. If not provided, the button will display generic 'Login with OIDC' text.", + }, }, }, }, @@ -1602,6 +1607,7 @@ const en: BaseTranslation = { }, loginPage: { pageTitle: 'Enter your credentials', + oidcLogin: 'Sign in with', callback: { return: 'Go back to login', error: 'An error occurred during external OpenID login', diff --git a/web/src/i18n/i18n-types.ts b/web/src/i18n/i18n-types.ts index def7dc3d1..86102479e 100644 --- a/web/src/i18n/i18n-types.ts +++ b/web/src/i18n/i18n-types.ts @@ -2431,6 +2431,16 @@ type RootTranslation = { */ helper: string } + display_name: { + /** + * D​i​s​p​l​a​y​ ​N​a​m​e + */ + label: string + /** + * N​a​m​e​ ​o​f​ ​t​h​e​ ​O​p​e​n​I​D​ ​p​r​o​v​i​d​e​r​ ​t​o​ ​d​i​s​p​l​a​y​ ​o​n​ ​t​h​e​ ​l​o​g​i​n​'​s​ ​p​a​g​e​ ​b​u​t​t​o​n​.​ ​I​f​ ​n​o​t​ ​p​r​o​v​i​d​e​d​,​ ​t​h​e​ ​b​u​t​t​o​n​ ​w​i​l​l​ ​d​i​s​p​l​a​y​ ​g​e​n​e​r​i​c​ ​'​L​o​g​i​n​ ​w​i​t​h​ ​O​I​D​C​'​ ​t​e​x​t​. + */ + helper: string + } } } } @@ -3814,6 +3824,10 @@ type RootTranslation = { * E​n​t​e​r​ ​y​o​u​r​ ​c​r​e​d​e​n​t​i​a​l​s */ pageTitle: string + /** + * S​i​g​n​ ​i​n​ ​w​i​t​h + */ + oidcLogin: string callback: { /** * G​o​ ​b​a​c​k​ ​t​o​ ​l​o​g​i​n @@ -6692,6 +6706,16 @@ export type TranslationFunctions = { */ helper: () => LocalizedString } + display_name: { + /** + * Display Name + */ + label: () => LocalizedString + /** + * Name of the OpenID provider to display on the login's page button. If not provided, the button will display generic 'Login with OIDC' text. + */ + helper: () => LocalizedString + } } } } @@ -8061,6 +8085,10 @@ export type TranslationFunctions = { * Enter your credentials */ pageTitle: () => LocalizedString + /** + * Sign in with + */ + oidcLogin: () => LocalizedString callback: { /** * Go back to login diff --git a/web/src/i18n/pl/index.ts b/web/src/i18n/pl/index.ts index 84d4e03d1..8fc91cd7b 100644 --- a/web/src/i18n/pl/index.ts +++ b/web/src/i18n/pl/index.ts @@ -971,6 +971,11 @@ Uwaga, podane tutaj konfiguracje nie posiadają klucza prywatnego. Musisz uzupe helper: 'Podstawowy adres URL twojego dostawcy OpenID, np. https://accounts.google.com. Sprawdź naszą dokumentację, aby uzyskać więcej informacji i zobaczyć przykłady.', }, + display_name: { + label: 'Wyświetlana nazwa', + helper: + 'Nazwa dostawcy OpenID, która będzie wyświetlana na przycisku logowania. Jeśli zostawisz to pole puste, przycisk będzie miał tekst "Zaloguj przez OIDC".', + }, }, }, }, @@ -1595,6 +1600,7 @@ Uwaga, podane tutaj konfiguracje nie posiadają klucza prywatnego. Musisz uzupe return: 'Powrót do logowania', error: 'Wystąpił błąd podczas logowania przez zewnętrznego dostawcę OpenID', }, + oidcLogin: 'Zaloguj się przez', mfa: { title: 'Autoryzacja dwuetapowa.', controls: { diff --git a/web/src/pages/auth/Callback/Callback.tsx b/web/src/pages/auth/Callback/Callback.tsx index 86699201f..3210a30b9 100644 --- a/web/src/pages/auth/Callback/Callback.tsx +++ b/web/src/pages/auth/Callback/Callback.tsx @@ -47,9 +47,9 @@ export const OpenIDCallback = () => { }); useEffect(() => { - if (window.location.hash && window.location.hash.length > 0) { - const hashFragment = window.location.hash.substring(1); - const params = new URLSearchParams(hashFragment); + if (window.location.search && window.location.search.length > 0) { + // const hashFragment = window.location.search.substring(1); + const params = new URLSearchParams(window.location.search); // check if error occured const error = params.get('error'); @@ -60,15 +60,19 @@ export const OpenIDCallback = () => { return; } - const id_token = params.get('id_token'); + const code = params.get('code'); const state = params.get('state'); - if (id_token && state) { + if (code && state) { const data: CallbackData = { - id_token, + code, state, }; callbackMutation.mutate(data); + } else { + setError('Expected data not returned by the OpenID provider'); + toaster.error(LL.messages.error()); + return; } } // eslint-disable-next-line react-hooks/exhaustive-deps diff --git a/web/src/pages/auth/Login/Login.tsx b/web/src/pages/auth/Login/Login.tsx index 6425fd4f9..3159ae619 100644 --- a/web/src/pages/auth/Login/Login.tsx +++ b/web/src/pages/auth/Login/Login.tsx @@ -147,7 +147,10 @@ export const Login = () => { data-testid="login-form-submit" /> {enterpriseEnabled && openIdInfo && ( - + )} diff --git a/web/src/pages/auth/Login/components/OidcButtons.tsx b/web/src/pages/auth/Login/components/OidcButtons.tsx index d2e339c58..9581d55e0 100644 --- a/web/src/pages/auth/Login/components/OidcButtons.tsx +++ b/web/src/pages/auth/Login/components/OidcButtons.tsx @@ -1,13 +1,20 @@ /* eslint-disable max-len */ import './style.scss'; +import { useI18nContext } from '../../../../i18n/i18n-react'; import { Button } from '../../../../shared/defguard-ui/components/Layout/Button/Button'; import { ButtonSize, ButtonStyleVariant, } from '../../../../shared/defguard-ui/components/Layout/Button/types'; -export const OpenIdLoginButton = ({ url }: { url: string }) => { +export const OpenIdLoginButton = ({ + url, + display_name, +}: { + url: string; + display_name?: string; +}) => { const { hostname } = new URL(url); if (hostname === 'accounts.google.com') { @@ -15,7 +22,7 @@ export const OpenIdLoginButton = ({ url }: { url: string }) => { } else if (hostname === 'login.microsoftonline.com') { return ; } else { - return ; + return ; } }; @@ -24,7 +31,7 @@ const GoogleButton = ({ url }: { url: string }) => {