diff --git a/e2e/config.ts b/e2e/config.ts index 1aa2d2bfb..6ca69f66e 100644 --- a/e2e/config.ts +++ b/e2e/config.ts @@ -42,6 +42,7 @@ export const routes = { users: '/admin/users', openid: '/admin/openid', overview: '/admin/overview', + settings: '/admin/settings', }, authorize: '/api/v1/oauth/authorize', }; diff --git a/e2e/tests/externalopenid.spec.ts b/e2e/tests/externalopenid.spec.ts new file mode 100644 index 000000000..e4dcc6e0a --- /dev/null +++ b/e2e/tests/externalopenid.spec.ts @@ -0,0 +1,99 @@ +import { expect, test } from '@playwright/test'; + +import { defaultUserAdmin, routes, testsConfig, testUserTemplate } from '../config'; +import { NetworkForm, OpenIdClient, User } from '../types'; +import { apiCreateUser } from '../utils/api/users'; +import { loginBasic } from '../utils/controllers/login'; +import { logout } from '../utils/controllers/logout'; +import { copyOpenIdClientIdAndSecret } from '../utils/controllers/openid/copyClientId'; +import { CreateExternalProvider } from '../utils/controllers/openid/createExternalProvider'; +import { CreateOpenIdClient } from '../utils/controllers/openid/createOpenIdClient'; +import { createNetwork } from '../utils/controllers/vpn/createNetwork'; +import { dockerDown, dockerRestart } from '../utils/docker'; +import { waitForBase } from '../utils/waitForBase'; +import { waitForPromise } from '../utils/waitForPromise'; +import { waitForRoute } from '../utils/waitForRoute'; + +test.describe('External OIDC.', () => { + const testUser: User = { ...testUserTemplate, username: 'test' }; + + const client: OpenIdClient = { + name: 'test 01', + redirectURL: [ + 'http://localhost:8000/auth/callback', + 'http://localhost:8080/openid/callback', + ], + scopes: ['openid', 'profile', 'email'], + }; + + const testNetwork: NetworkForm = { + name: 'test network', + address: '10.10.10.1/24', + endpoint: '127.0.0.1', + port: '5055', + }; + + test.beforeEach(async ({ browser }) => { + dockerRestart(); + await CreateOpenIdClient(browser, client); + [client.clientID, client.clientSecret] = await copyOpenIdClientIdAndSecret( + browser, + client.name + ); + const context = await browser.newContext(); + const page = await context.newPage(); + await CreateExternalProvider(browser, client); + await loginBasic(page, defaultUserAdmin); + await apiCreateUser(page, testUser); + await logout(page); + await createNetwork(browser, testNetwork); + context.close(); + }); + + test.afterAll(() => { + dockerDown(); + }); + + test('Login through external oidc.', async ({ page }) => { + expect(client.clientID).toBeDefined(); + expect(client.clientSecret).toBeDefined(); + await waitForBase(page); + const oidcLoginButton = await page.getByTestId('login-oidc'); + expect(oidcLoginButton).not.toBeNull(); + expect(await oidcLoginButton.textContent()).toBe(`Sign in with ${client.name}`); + await oidcLoginButton.click(); + await page.getByTestId('login-form-username').fill(testUser.username); + await page.getByTestId('login-form-password').fill(testUser.password); + await page.getByTestId('login-form-submit').click(); + await page.getByTestId('openid-allow').click(); + await waitForRoute(page, routes.me); + const authorizedApps = await page + .getByTestId('authorized-apps') + .locator('div') + .textContent(); + expect(authorizedApps).toContain(client.name); + }); + + test('Complete enrollment through external OIDC', async ({ page }) => { + await waitForBase(page); + await page.goto(testsConfig.ENROLLMENT_URL); + await waitForPromise(2000); + await page.getByTestId('select-enrollment').click(); + await page.getByTestId('login-oidc').click(); + await page.getByTestId('login-form-username').fill(testUser.username); + await page.getByTestId('login-form-password').fill(testUser.password); + await page.getByTestId('login-form-submit').click(); + await page.getByTestId('openid-allow').click(); + const instanceUrlBox = page + .locator('div') + .filter({ hasText: /^Instance URL$/ }) + .getByRole('textbox'); + + expect(await instanceUrlBox.inputValue()).toBe('http://localhost:8080/'); + const instanceTokenBox = page + .locator('div') + .filter({ hasText: /^Token$/ }) + .getByRole('textbox'); + expect((await instanceTokenBox.inputValue()).length).toBeGreaterThan(1); + }); +}); diff --git a/e2e/tests/openid.spec.ts b/e2e/tests/openid.spec.ts index 92ec33903..de3280465 100644 --- a/e2e/tests/openid.spec.ts +++ b/e2e/tests/openid.spec.ts @@ -20,7 +20,7 @@ test.describe('Authorize OpenID client.', () => { const client: OpenIdClient = { name: 'test 01', - redirectURL: 'https://oidcdebugger.com/debug', + redirectURL: ['https://oidcdebugger.com/debug'], scopes: ['openid'], }; diff --git a/e2e/types.ts b/e2e/types.ts index 056f23040..96a86143c 100644 --- a/e2e/types.ts +++ b/e2e/types.ts @@ -57,7 +57,8 @@ export type User = { export type OpenIdClient = { name: string; clientID?: string; - redirectURL: string; + clientSecret?: string; + redirectURL: string[]; scopes: OpenIdScope[]; }; diff --git a/e2e/utils/controllers/openid/copyClientId.ts b/e2e/utils/controllers/openid/copyClientId.ts index d867c5e3e..9c0f44f02 100644 --- a/e2e/utils/controllers/openid/copyClientId.ts +++ b/e2e/utils/controllers/openid/copyClientId.ts @@ -16,3 +16,26 @@ export const copyOpenIdClientId = async (browser: Browser, clientId: number) => const id = await getPageClipboard(page); return id; }; + +export const copyOpenIdClientIdAndSecret = async ( + browser: Browser, + clientName: string +) => { + const context = await browser.newContext(); + const page = await context.newPage(); + await waitForBase(page); + await loginBasic(page, defaultUserAdmin); + await page.goto(routes.base + routes.admin.openid, { waitUntil: 'networkidle' }); + await page + .locator('div') + .filter({ + hasText: new RegExp(`^${clientName}$`), + }) + .click(); + await page.getByTestId('copy-client-id').click(); + const id = await getPageClipboard(page); + await page.locator('.variant-copy').nth(1).click(); + const secret = await getPageClipboard(page); + await context.close(); + return [id, secret]; +}; diff --git a/e2e/utils/controllers/openid/createExternalProvider.ts b/e2e/utils/controllers/openid/createExternalProvider.ts new file mode 100644 index 000000000..97a4eb3d4 --- /dev/null +++ b/e2e/utils/controllers/openid/createExternalProvider.ts @@ -0,0 +1,22 @@ +import { Browser } from 'playwright'; + +import { defaultUserAdmin, routes } from '../../../config'; +import { OpenIdClient } from '../../../types'; +import { waitForBase } from '../../waitForBase'; +import { loginBasic } from '../login'; + +export const CreateExternalProvider = async (browser: Browser, client: OpenIdClient) => { + const context = await browser.newContext(); + const page = await context.newPage(); + await waitForBase(page); + await loginBasic(page, defaultUserAdmin); + await page.goto(routes.base + routes.admin.settings, { waitUntil: 'networkidle' }); + await page.getByRole('button', { name: 'OpenID' }).click(); + await page.locator('.content-frame').click(); + await page.getByRole('button', { name: 'Custom' }).click(); + await page.getByTestId('field-base_url').fill('http://localhost:8000/'); + await page.getByTestId('field-client_id').fill(client.clientID || ''); + await page.getByTestId('field-client_secret').fill(client.clientSecret || ''); + await page.getByTestId('field-display_name').fill(client.name); + await page.getByRole('button', { name: 'Save changes' }).click(); +}; diff --git a/e2e/utils/controllers/openid/createOpenIdClient.ts b/e2e/utils/controllers/openid/createOpenIdClient.ts index 2fc057c17..b7e319a8b 100644 --- a/e2e/utils/controllers/openid/createOpenIdClient.ts +++ b/e2e/utils/controllers/openid/createOpenIdClient.ts @@ -17,7 +17,16 @@ export const CreateOpenIdClient = async (browser: Browser, client: OpenIdClient) await modalElement.waitFor({ state: 'visible' }); const modalForm = modalElement.locator('form'); await modalForm.getByTestId('field-name').type(client.name); - await modalForm.getByTestId('field-redirect_uri.0.url').type(client.redirectURL); + const urls = client.redirectURL.length; + for (let i = 0; i < urls; i++) { + const isLast = i === urls - 1; + await modalForm + .getByTestId(`field-redirect_uri.${i}.url`) + .fill(client.redirectURL[i]); + if (!isLast) { + await modalForm.locator('button:has-text("Add URL")').click(); + } + } for (const scope of client.scopes) { await modalForm.getByTestId(`field-scope-${scope}`).click(); } diff --git a/tests/common/client.rs b/tests/common/client.rs index 8bb3b170b..2f0eacb64 100644 --- a/tests/common/client.rs +++ b/tests/common/client.rs @@ -19,10 +19,7 @@ pub struct TestClient { #[allow(dead_code)] impl TestClient { #[must_use] - pub async fn new(app: Router) -> Self { - let listener = TcpListener::bind("127.0.0.1:0") - .await - .expect("Could not bind ephemeral socket"); + pub async fn new(app: Router, listener: TcpListener) -> Self { let port = listener.local_addr().unwrap().port(); tokio::spawn(async move { @@ -58,7 +55,7 @@ impl TestClient { /// /// this is useful when trying to check if Location headers in responses /// are generated correctly as Location contains an absolute URL - fn base_url(&self) -> String { + pub fn base_url(&self) -> String { let mut s = String::from("http://localhost:"); s.push_str(&self.port.to_string()); s diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 6e81d7948..8d4e2e90b 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,6 +1,9 @@ pub(crate) mod client; -use std::sync::{Arc, Mutex}; +use std::{ + str::FromStr, + sync::{Arc, Mutex}, +}; use defguard::{ auth::failed_login::FailedLoginMap, @@ -13,10 +16,11 @@ use defguard::{ mail::Mail, SERVER_CONFIG, }; -use reqwest::{header::HeaderName, StatusCode}; +use reqwest::{header::HeaderName, StatusCode, Url}; use secrecy::ExposeSecret; use serde_json::json; use sqlx::{postgres::PgConnectOptions, query, types::Uuid, PgPool}; +use tokio::net::TcpListener; use tokio::sync::{ broadcast::{self, Receiver}, mpsc::{unbounded_channel, UnboundedReceiver}, @@ -31,9 +35,17 @@ pub const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for #[allow(dead_code, clippy::declare_interior_mutable_const)] pub const X_FORWARDED_URI: HeaderName = HeaderName::from_static("x-forwarded-uri"); -pub async fn init_test_db() -> (PgPool, DefGuardConfig) { - let config = DefGuardConfig::new_test_config(); +/// Allows overriding the default DefGuard URL for tests, as during the tests, the server has a random port, making the URL unpredictable beforehand. +// TODO: Allow customizing the whole config, not just the URL +pub fn init_config(custom_defguard_url: Option<&str>) -> DefGuardConfig { + let url = custom_defguard_url.unwrap_or("http://localhost:8000"); + let mut config = DefGuardConfig::new_test_config(); + config.url = Url::from_str(url).unwrap(); let _ = SERVER_CONFIG.set(config.clone()); + config +} + +pub async fn init_test_db(config: &DefGuardConfig) -> PgPool { let opts = PgConnectOptions::new() .host(&config.database_host) .port(config.database_port) @@ -57,9 +69,9 @@ pub async fn init_test_db() -> (PgPool, DefGuardConfig) { ) .await; - initialize_users(&pool, &config).await; + initialize_users(&pool, config).await; - (pool, config) + pool } async fn initialize_users(pool: &PgPool, config: &DefGuardConfig) { @@ -112,7 +124,11 @@ impl ClientState { } } -pub async fn make_base_client(pool: PgPool, config: DefGuardConfig) -> (TestClient, ClientState) { +pub async fn make_base_client( + pool: PgPool, + config: DefGuardConfig, + listener: TcpListener, +) -> (TestClient, ClientState) { let (tx, rx) = unbounded_channel::(); let worker_state = Arc::new(Mutex::new(WorkerState::new(tx.clone()))); let (wg_tx, wg_rx) = broadcast::channel::(16); @@ -124,7 +140,7 @@ pub async fn make_base_client(pool: PgPool, config: DefGuardConfig) -> (TestClie let license = License::new( "test_customer".to_string(), - true, + false, // Some(Utc.with_ymd_and_hms(2030, 1, 1, 0, 0, 0).unwrap()), // Permanent license None, @@ -167,13 +183,35 @@ pub async fn make_base_client(pool: PgPool, config: DefGuardConfig) -> (TestClie failed_logins, ); - (TestClient::new(webapp).await, client_state) + (TestClient::new(webapp, listener).await, client_state) +} + +/// Make an instance url based on the listener +fn get_test_url(listener: &TcpListener) -> String { + let port = listener.local_addr().unwrap().port(); + format!("http://localhost:{}", port) } #[allow(dead_code)] pub async fn make_test_client() -> (TestClient, ClientState) { - let (pool, config) = init_test_db().await; - make_base_client(pool, config).await + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("Could not bind ephemeral socket"); + let config = init_config(None); + let pool = init_test_db(&config).await; + make_base_client(pool, config, listener).await +} + +/// Makes a test client with a DEFGUARD_URL set to the random url of the listener. +/// This is useful when the instance's url real url needs to match the one set in the ENV variable. +#[allow(dead_code)] +pub async fn make_test_client_with_real_url() -> (TestClient, ClientState) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("Could not bind ephemeral socket"); + let config = init_config(Some(&get_test_url(&listener))); + let pool = init_test_db(&config).await; + make_base_client(pool, config, listener).await } #[allow(dead_code)] @@ -183,6 +221,8 @@ pub async fn fetch_user_details(client: &TestClient, username: &str) -> UserDeta response.json().await } +/// Exceeds enterprise free version limits by creating more than 1 network +#[allow(dead_code)] pub async fn exceed_enterprise_limits(client: &TestClient) { let auth = Auth::new("admin", "pass123"); client.post("/api/v1/auth").json(&auth).send().await; diff --git a/tests/openid.rs b/tests/openid.rs index 9c0e02997..7ad5c7724 100644 --- a/tests/openid.rs +++ b/tests/openid.rs @@ -2,6 +2,7 @@ use std::str::FromStr; use axum::http::header::ToStrError; use claims::assert_err; +use common::init_config; use defguard::{ config::DefGuardConfig, db::{ @@ -26,6 +27,7 @@ use reqwest::{ use rsa::RsaPrivateKey; use serde::Deserialize; use sqlx::PgPool; +use tokio::net::TcpListener; mod common; use self::common::{client::TestClient, init_test_db, make_base_client, make_test_client}; @@ -36,7 +38,10 @@ async fn make_client() -> TestClient { } async fn make_client_v2(pool: PgPool, config: DefGuardConfig) -> TestClient { - let (client, _) = make_base_client(pool, config).await; + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("Could not bind ephemeral socket"); + let (client, _) = make_base_client(pool, config, listener).await; client } @@ -402,7 +407,8 @@ static FAKE_REDIRECT_URI: &str = "http://test.server.tnt:12345/"; #[tokio::test] async fn test_openid_authorization_code() { - let (pool, config) = init_test_db().await; + let config = init_config(None); + let pool = init_test_db(&config).await; let issuer_url = IssuerUrl::from_url(config.url.clone()); let client = make_client_v2(pool.clone(), config.clone()).await; @@ -505,7 +511,8 @@ async fn test_openid_authorization_code() { #[tokio::test] async fn test_openid_authorization_code_with_pkce() { - let (pool, mut config) = init_test_db().await; + let mut config = init_config(None); + let pool = init_test_db(&config).await; let mut rng = rand::thread_rng(); config.openid_signing_key = RsaPrivateKey::new(&mut rng, 2048).ok(); diff --git a/tests/openid_login.rs b/tests/openid_login.rs index 23e57233e..c25a331b5 100644 --- a/tests/openid_login.rs +++ b/tests/openid_login.rs @@ -1,7 +1,9 @@ use chrono::{Duration, Utc}; -use common::exceed_enterprise_limits; +use common::{exceed_enterprise_limits, make_test_client, make_test_client_with_real_url}; +use defguard::db::{models::oauth2client::OAuth2Client, Id}; +use defguard::enterprise::license::get_cached_license; use defguard::{ - config::DefGuardConfig, + db::models::NewOpenIDClient, enterprise::{ handlers::openid_providers::AddProviderData, license::{set_cached_license, License}, @@ -9,20 +11,18 @@ use defguard::{ handlers::Auth, }; use reqwest::{StatusCode, Url}; -use serde::Deserialize; -use sqlx::PgPool; +use serde::{Deserialize, Serialize}; mod common; -use self::common::{client::TestClient, make_base_client, make_test_client}; +use self::common::client::TestClient; async fn make_client() -> TestClient { let (client, _) = make_test_client().await; client } -#[allow(dead_code)] -async fn make_client_v2(pool: PgPool, config: DefGuardConfig) -> TestClient { - let (client, _) = make_base_client(pool, config).await; +async fn make_client_with_real_url() -> TestClient { + let (client, _) = make_test_client_with_real_url().await; client } @@ -89,3 +89,153 @@ async fn test_openid_providers() { let response = client.get("/api/v1/openid/auth_info").send().await; assert_eq!(response.status(), StatusCode::FORBIDDEN); } + +#[tokio::test] +async fn test_openid_login() { + // Test setup + let client = make_client_with_real_url().await; + let auth = Auth::new("admin", "pass123"); + let response = client.post("/api/v1/auth").json(&auth).send().await; + assert_eq!(response.status(), StatusCode::OK); + let url = client.base_url(); + + // Add an OpenID client + let redirect_uri = format!("{}/auth/callback", &url); + let openid_client = NewOpenIDClient { + name: "Defguard".into(), + redirect_uri: vec![redirect_uri], + scope: vec!["openid".into(), "email".into(), "profile".into()], + enabled: true, + }; + let response = client + .post("/api/v1/oauth") + .json(&openid_client) + .send() + .await; + assert_eq!(response.status(), StatusCode::CREATED); + let response = client.get("/api/v1/oauth").send().await; + assert_eq!(response.status(), StatusCode::OK); + let openid_clients: Vec> = response.json().await; + assert_eq!(openid_clients.len(), 1); + let openid_client = openid_clients.first().unwrap(); + assert_eq!(openid_client.name, "Defguard"); + + // Add the provider (ourselves) + let (secret, id) = ( + openid_client.client_secret.clone(), + openid_client.client_id.clone(), + ); + let provider_data = AddProviderData::new( + "Custom", + format!("{}/", &url).as_str(), + id.to_string().as_str(), + &secret, + Some("Defguard"), + ); + let response = client + .post("/api/v1/openid/provider") + .json(&provider_data) + .send() + .await; + assert_eq!(response.status(), StatusCode::CREATED); + + // Logout to make sure we start from a clean slate + client.post("/api/v1/auth/logout").send().await; + + // Get the provider's authorization endpoint (and button display name) + let response = client.get("/api/v1/openid/auth_info").send().await; + assert_eq!(response.status(), StatusCode::OK); + #[derive(Deserialize, Debug)] + struct AuthInfoResponse { + button_display_name: String, + url: Url, + } + let response_body: AuthInfoResponse = response.json().await; + assert_eq!(response_body.button_display_name, "Defguard"); + + // Begin OIDC login at the provider's authorization endpoint + let url = format!( + "{}?{}", + response_body.url.path(), + response_body.url.query().unwrap() + ); + let response = client.get(&url).send().await; + assert_eq!(response.status(), StatusCode::FOUND); + + // A user should now be redirected to the login page + #[derive(Deserialize, Debug)] + struct LoginResponse { + url: String, + } + let response = client.post("/api/v1/auth").json(&auth).send().await; + let login_response: LoginResponse = response.json().await; + + // During the flow, the user may be first redirected to a consent page, simualte that here + let url = Url::parse(&login_response.url).unwrap(); + let path = url.path(); + let query = url.query().unwrap(); + let url = format!("{}?{}", path, query); + let response = client.get(&url).send().await; + assert_eq!(response.status(), StatusCode::FOUND); + let location = response.headers().get("location").unwrap(); + let location = location.to_str().unwrap(); + assert!(location.starts_with("/consent")); + + // Consent to everything by adding the allow=true query parameter and sending a post request this time + let url = Url::parse(&login_response.url).unwrap(); + let mut query_pairs = url + .query_pairs() + .into_owned() + .collect::>(); + query_pairs.push(("allow".to_string(), "true".to_string())); + let pairs = query_pairs + .iter() + .map(|(key, value)| format!("{}={}", key, value)) + .collect::>() + .join("&"); + let path = format!("{}?{}", url.path(), pairs); + let response = client.post(&path).send().await; + assert_eq!(response.status(), StatusCode::FOUND); + + // logout to make sure the session won't be carried over after the callback later + client.post("/api/v1/auth/logout").send().await; + + // Extract callback data from the response's location header + let location = response.headers().get("location").unwrap(); + let location = location.to_str().unwrap(); + let url = Url::parse(location).unwrap(); + let query_pairs = url + .query_pairs() + .into_owned() + .collect::>(); + let code = query_pairs + .iter() + .find(|(key, _)| key == "code") + .unwrap() + .1 + .clone(); + let state = query_pairs + .iter() + .find(|(key, _)| key == "state") + .unwrap() + .1 + .clone(); + + // Post the callback with the data inside a json payload + #[derive(Serialize, Debug)] + struct AuthResponse { + code: String, + state: String, + } + let auth_response = AuthResponse { code, state }; + let response = client + .post("/api/v1/openid/callback") + .json(&auth_response) + .send() + .await; + assert_eq!(response.status(), StatusCode::OK); + + // Am I logged in? + let response = client.get("/api/v1/me").send().await; + assert_eq!(response.status(), StatusCode::OK); +}