diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index b1cb629fc..27c9b12fa 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -23,7 +23,7 @@ pub(crate) mod auth; pub(crate) mod forward_auth; pub(crate) mod group; pub(crate) mod mail; -pub(crate) mod network_devices; +pub mod network_devices; #[cfg(feature = "openid")] pub(crate) mod openid_clients; #[cfg(feature = "openid")] diff --git a/src/handlers/network_devices.rs b/src/handlers/network_devices.rs index b3e8a41f0..ab59d003e 100644 --- a/src/handlers/network_devices.rs +++ b/src/handlers/network_devices.rs @@ -143,7 +143,6 @@ pub async fn get_network_device( pub(crate) async fn list_network_devices( _admin_role: AdminRole, - session: SessionInfo, State(appstate): State, ) -> ApiResult { debug!("Listing all network devices"); @@ -173,11 +172,11 @@ pub(crate) async fn list_network_devices( #[derive(Serialize, Deserialize, Debug)] pub struct AddNetworkDevice { - name: String, - description: Option, - location_id: i64, - assigned_ip: String, - wireguard_pubkey: String, + pub name: String, + pub description: Option, + pub location_id: i64, + pub assigned_ip: String, + pub wireguard_pubkey: String, } #[derive(Serialize)] diff --git a/tests/wireguard_network_devices.rs b/tests/wireguard_network_devices.rs new file mode 100644 index 000000000..db0c987ab --- /dev/null +++ b/tests/wireguard_network_devices.rs @@ -0,0 +1,274 @@ +mod common; + +use std::{net::IpAddr, str::FromStr}; + +use defguard::db::{Device, GatewayEvent, Id, WireguardNetwork}; +use defguard::handlers::{network_devices::AddNetworkDevice, Auth}; +use ipnetwork::IpNetwork; +use matches::assert_matches; +use reqwest::StatusCode; +use serde::Deserialize; +use serde_json::json; +use serde_json::Value; + +use self::common::make_test_client; + +fn make_network() -> Value { + json!({ + "name": "network", + "address": "10.1.1.1/24", + "port": 55555, + "endpoint": "192.168.4.14", + "allowed_ips": "10.1.1.0/24", + "dns": "1.1.1.1", + "allowed_groups": [], + "mfa_enabled": false, + "keepalive_interval": 25, + "peer_disconnect_threshold": 180 + }) +} + +fn make_second_network() -> Value { + json!({ + "name": "network-2", + "address": "10.6.1.1/24", + "port": 55555, + "endpoint": "192.168.4.14", + "allowed_ips": "10.6.1.0/24", + "dns": "1.1.1.1", + "allowed_groups": [], + "mfa_enabled": false, + "keepalive_interval": 25, + "peer_disconnect_threshold": 180 + }) +} + +#[tokio::test] +async fn test_network_devices() { + let (client, client_state) = make_test_client().await; + + let mut wg_rx = client_state.wireguard_rx; + + let auth = Auth::new("admin", "pass123"); + let response = &client.post("/api/v1/auth").json(&auth).send().await; + assert_eq!(response.status(), StatusCode::OK); + + // create networks + let response = client + .post("/api/v1/network") + .json(&make_network()) + .send() + .await; + assert_eq!(response.status(), StatusCode::CREATED); + let network_1: WireguardNetwork = response.json().await; + assert_eq!(network_1.name, "network"); + let event = wg_rx.try_recv().unwrap(); + assert_matches!(event, GatewayEvent::NetworkCreated(..)); + let response = client + .post("/api/v1/network") + .json(&make_second_network()) + .send() + .await; + assert_eq!(response.status(), StatusCode::CREATED); + let network_2: WireguardNetwork = response.json().await; + assert_eq!(network_2.name, "network-2"); + let event = wg_rx.try_recv().unwrap(); + assert_matches!(event, GatewayEvent::NetworkCreated(..)); + + // ip suggestions + let response = client.get("/api/v1/device/network/ip/1").send().await; + assert_eq!(response.status(), StatusCode::OK); + let res = response.json::().await; + let ip = res["ip"].as_str().unwrap(); + let ip = ip.parse::().unwrap(); + let net_ip = IpAddr::from_str("10.1.1.1").unwrap(); + let network_range = IpNetwork::new(net_ip, 24).unwrap(); + assert!(network_range.contains(ip)); + + // checking whether ip is valid/availble + #[derive(Deserialize)] + struct IpCheckRes { + available: bool, + valid: bool, + } + let ip_check = json!( + { + "ip": "10.1.1.2".to_string(), + } + ); + let response = client + .post("/api/v1/device/network/ip/1") + .json(&ip_check) + .send() + .await; + assert_eq!(response.status(), StatusCode::OK); + let res = response.json::().await; + assert!(res.available); + assert!(res.valid); + + let ip_check = json!( + { + "ip": "10.1.1.0".to_string(), + } + ); + let response = client + .post("/api/v1/device/network/ip/1") + .json(&ip_check) + .send() + .await; + assert_eq!(response.status(), StatusCode::OK); + let res = response.json::().await; + assert!(!res.available); + assert!(res.valid); + + let ip_check = json!( + { + "ip": "10.1.1.1".to_string(), + } + ); + let response = client + .post("/api/v1/device/network/ip/1") + .json(&ip_check) + .send() + .await; + assert_eq!(response.status(), StatusCode::OK); + let res = response.json::().await; + assert!(!res.available); + assert!(res.valid); + + let ip_check = json!( + { + "ip": "10.1.1.abc".to_string(), + } + ); + let response = client + .post("/api/v1/device/network/ip/1") + .json(&ip_check) + .send() + .await; + assert_eq!(response.status(), StatusCode::OK); + let res = response.json::().await; + assert!(!res.available); + assert!(!res.valid); + + // make network device (manual, WireGuard flow) + let network_device = AddNetworkDevice { + name: "device-1".into(), + wireguard_pubkey: "LQKsT6/3HWKuJmMulH63R8iK+5sI8FyYEL6WDIi6lQU=".into(), + assigned_ip: ip.to_string(), + location_id: 1, + description: None, + }; + let response = client + .post("/api/v1/device/network") + .json(&network_device) + .send() + .await; + assert_eq!(response.status(), StatusCode::CREATED); + let json = response.json::().await; + let device_id = json["device"]["id"].as_i64().unwrap(); + let configured = json["device"]["configured"].as_bool().unwrap(); + let config_text = json["config"]["config"].as_str().unwrap(); + assert!(configured); + let event = wg_rx.try_recv().unwrap(); + assert_matches!(event, GatewayEvent::DeviceCreated(..)); + + // download WG config + let response = client.get("/api/v1/device/network/1/config").send().await; + assert_eq!(response.status(), StatusCode::OK); + let response_config = response.text().await; + assert_eq!(response_config, config_text); + + // edit the device + let modify_device = json!({ + "name": "device-1", + "description": "new description", + "assigned_ip": "10.1.1.3" + }); + let response = client + .put(format!("/api/v1/device/network/{}", device_id)) + .json(&modify_device) + .send() + .await; + assert_eq!(response.status(), StatusCode::OK); + let json = response.json::().await; + let description = json["description"].as_str().unwrap(); + let assigned_ip = json["assigned_ip"].as_str().unwrap(); + assert_eq!(description, "new description"); + assert_eq!( + assigned_ip, + IpAddr::from_str("10.1.1.3").unwrap().to_string() + ); + let device = Device::find_by_id(&client_state.pool, device_id as i64) + .await + .unwrap() + .unwrap(); + assert_eq!(device.name, "device-1"); + assert_eq!(device.description, Some("new description".to_string())); + let event = wg_rx.try_recv().unwrap(); + assert_matches!(event, GatewayEvent::DeviceModified(..)); + + // Make sure the device is only in the selected network + let device_networks = device + .find_device_networks(&client_state.pool) + .await + .unwrap(); + assert_eq!(device_networks.len(), 1); + assert_eq!(network_1.id, device_networks[0].id); + + // Try making cli "enrollment" token for that device + let response = client + .post("/api/v1/device/network/start_cli/1") + .send() + .await; + assert_eq!(response.status(), StatusCode::CREATED); + let json = response.json::().await; + let token = json["enrollment_token"].as_str().unwrap(); + assert_eq!(token.len(), 32); + let enrollment_url = json["enrollment_url"].as_str().unwrap(); + assert_eq!(enrollment_url, "http://localhost:8080/"); + + // Enrollment flow for 2nd device + let setup_start = json!( + { + "name": "device-2", + "description": "new description", + "assigned_ip": "10.1.1.10", + "location_id": 1, + } + ); + let response = client + .post("/api/v1/device/network/start_cli") + .json(&setup_start) + .send() + .await; + assert_eq!(response.status(), StatusCode::CREATED); + let json = response.json::().await; + let token = json["enrollment_token"].as_str().unwrap(); + assert_eq!(token.len(), 32); + let enrollment_url = json["enrollment_url"].as_str().unwrap(); + assert_eq!(enrollment_url, "http://localhost:8080/"); + let device = Device::find_by_id(&client_state.pool, 2) + .await + .unwrap() + .unwrap(); + assert!(!device.configured); + assert_eq!(device.name, "device-2"); + let device_network = device + .find_device_networks(&client_state.pool) + .await + .unwrap(); + assert_eq!(device_network.len(), 1); + assert_eq!(device_network[0].id, network_1.id); + + // Deleting the device + let response = client + .delete(format!("/api/v1/device/network/{}", device_id)) + .send() + .await; + assert_eq!(response.status(), StatusCode::OK); + let device = Device::find_by_id(&client_state.pool, device_id as i64) + .await + .unwrap(); + assert!(device.is_none()); +}