diff --git a/Cargo.lock b/Cargo.lock index 0822026fe5b..efa753841bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5742,6 +5742,7 @@ name = "nym-wireguard" version = "0.1.0" dependencies = [ "base64 0.21.7", + "chrono", "dashmap", "defguard_wireguard_rs", "ip_network", @@ -5750,6 +5751,7 @@ dependencies = [ "nym-network-defaults", "nym-task", "nym-wireguard-types", + "thiserror", "tokio", "tokio-stream", "x25519-dalek", diff --git a/common/authenticator-requests/src/v1/request.rs b/common/authenticator-requests/src/v1/request.rs index 15389998281..f8490312180 100644 --- a/common/authenticator-requests/src/v1/request.rs +++ b/common/authenticator-requests/src/v1/request.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use nym_sphinx::addressing::Recipient; -use nym_wireguard_types::{GatewayClient, InitMessage}; +use nym_wireguard_types::{GatewayClient, InitMessage, PeerPublicKey}; use serde::{Deserialize, Serialize}; use crate::make_bincode_serializer; @@ -57,6 +57,19 @@ impl AuthenticatorRequest { ) } + pub fn new_query_request(peer_public_key: PeerPublicKey, reply_to: Recipient) -> (Self, u64) { + let request_id = generate_random(); + ( + Self { + version: VERSION, + data: AuthenticatorRequestData::QueryBandwidth(peer_public_key), + reply_to, + request_id, + }, + request_id, + ) + } + pub fn to_bytes(&self) -> Result, bincode::Error> { use bincode::Options; make_bincode_serializer().serialize(self) @@ -67,4 +80,5 @@ impl AuthenticatorRequest { pub enum AuthenticatorRequestData { Initial(InitMessage), Final(GatewayClient), + QueryBandwidth(PeerPublicKey), } diff --git a/common/authenticator-requests/src/v1/response.rs b/common/authenticator-requests/src/v1/response.rs index 925ece872e5..83c5a71d509 100644 --- a/common/authenticator-requests/src/v1/response.rs +++ b/common/authenticator-requests/src/v1/response.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use nym_sphinx::addressing::Recipient; -use nym_wireguard_types::registration::RegistrationData; +use nym_wireguard_types::registration::{RegistrationData, RegistredData, RemainingBandwidthData}; use serde::{Deserialize, Serialize}; use crate::make_bincode_serializer; @@ -33,10 +33,31 @@ impl AuthenticatorResponse { } } - pub fn new_registered(reply_to: Recipient, request_id: u64) -> Self { + pub fn new_registered( + registred_data: RegistredData, + reply_to: Recipient, + request_id: u64, + ) -> Self { Self { version: VERSION, data: AuthenticatorResponseData::Registered(RegisteredResponse { + reply: registred_data, + reply_to, + request_id, + }), + reply_to, + } + } + + pub fn new_remaining_bandwidth( + remaining_bandwidth_data: Option, + reply_to: Recipient, + request_id: u64, + ) -> Self { + Self { + version: VERSION, + data: AuthenticatorResponseData::RemainingBandwidth(RemainingBandwidthResponse { + reply: remaining_bandwidth_data, reply_to, request_id, }), @@ -64,6 +85,7 @@ impl AuthenticatorResponse { match &self.data { AuthenticatorResponseData::PendingRegistration(response) => Some(response.request_id), AuthenticatorResponseData::Registered(response) => Some(response.request_id), + AuthenticatorResponseData::RemainingBandwidth(response) => Some(response.request_id), } } } @@ -72,6 +94,7 @@ impl AuthenticatorResponse { pub enum AuthenticatorResponseData { PendingRegistration(PendingRegistrationResponse), Registered(RegisteredResponse), + RemainingBandwidth(RemainingBandwidthResponse), } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -85,4 +108,12 @@ pub struct PendingRegistrationResponse { pub struct RegisteredResponse { pub request_id: u64, pub reply_to: Recipient, + pub reply: RegistredData, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RemainingBandwidthResponse { + pub request_id: u64, + pub reply_to: Recipient, + pub reply: Option, } diff --git a/common/wireguard-types/src/registration.rs b/common/wireguard-types/src/registration.rs index 5645c4f5de7..ae5a758d905 100644 --- a/common/wireguard-types/src/registration.rs +++ b/common/wireguard-types/src/registration.rs @@ -27,12 +27,15 @@ pub type HmacSha256 = Hmac; pub type Nonce = u64; pub type Taken = Option; +pub const BANDWIDTH_CAP_PER_DAY: u64 = 1024 * 1024 * 1024; // 1 GB + #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(tag = "type", rename_all = "camelCase")] #[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))] pub enum ClientMessage { Initial(InitMessage), Final(GatewayClient), + Query(PeerPublicKey), } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -60,6 +63,19 @@ pub struct RegistrationData { pub wg_port: u16, } +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RegistredData { + pub pub_key: PeerPublicKey, + pub private_ip: IpAddr, + pub wg_port: u16, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct RemainingBandwidthData { + pub available_bandwidth: u64, + pub suspended: bool, +} + /// Client that wants to register sends its PublicKey bytes mac digest encrypted with a DH shared secret. /// Gateway/Nym node can then verify pub_key payload using the same process #[derive(Serialize, Deserialize, Debug, Clone)] diff --git a/common/wireguard/Cargo.toml b/common/wireguard/Cargo.toml index 0a5e615d4cb..51a0b115a1e 100644 --- a/common/wireguard/Cargo.toml +++ b/common/wireguard/Cargo.toml @@ -12,6 +12,7 @@ license.workspace = true [dependencies] base64 = { workspace = true } +chrono = { workspace = true } dashmap = { workspace = true } defguard_wireguard_rs = { workspace = true } # The latest version on crates.io at the time of writing this (6.0.0) has a @@ -24,5 +25,6 @@ nym-crypto = { path = "../crypto", features = ["asymmetric"] } nym-network-defaults = { path = "../network-defaults" } nym-task = { path = "../task" } nym-wireguard-types = { path = "../wireguard-types" } +thiserror = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "net", "io-util"] } tokio-stream = { workspace = true } diff --git a/common/wireguard/src/error.rs b/common/wireguard/src/error.rs new file mode 100644 index 00000000000..da7452614fc --- /dev/null +++ b/common/wireguard/src/error.rs @@ -0,0 +1,11 @@ +// Copyright 2024 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("peers in wireguard don't match with in-memory ")] + PeerMismatch, + + #[error("{0}")] + Defguard(#[from] defguard_wireguard_rs::error::WireguardInterfaceError), +} diff --git a/common/wireguard/src/lib.rs b/common/wireguard/src/lib.rs index 3b1b0ea199d..6244320e540 100644 --- a/common/wireguard/src/lib.rs +++ b/common/wireguard/src/lib.rs @@ -1,3 +1,6 @@ +// Copyright 2024 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + #![cfg_attr(not(target_os = "linux"), allow(dead_code))] // #![warn(clippy::pedantic)] // #![warn(clippy::expect_used)] @@ -6,13 +9,14 @@ use dashmap::DashMap; use defguard_wireguard_rs::{host::Peer, key::Key, net::IpAddrMask, WGApi}; use nym_crypto::asymmetric::encryption::KeyPair; -use nym_wireguard_types::{Config, Error, GatewayClient, GatewayClientRegistry}; -use peer_controller::PeerControlMessage; +use nym_wireguard_types::{Config, Error, GatewayClient, GatewayClientRegistry, PeerPublicKey}; +use peer_controller::PeerControlRequest; use std::sync::Arc; -use tokio::sync::mpsc::{self, UnboundedReceiver}; +use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; const WG_TUN_NAME: &str = "nymwg"; +pub(crate) mod error; pub mod peer_controller; pub struct WgApiWrapper { @@ -39,14 +43,14 @@ pub struct WireguardGatewayData { config: Config, keypair: Arc, client_registry: Arc, - peer_tx: mpsc::UnboundedSender, + peer_tx: UnboundedSender, } impl WireguardGatewayData { pub fn new( config: Config, keypair: Arc, - ) -> (Self, mpsc::UnboundedReceiver) { + ) -> (Self, UnboundedReceiver) { let (peer_tx, peer_rx) = mpsc::unbounded_channel(); ( WireguardGatewayData { @@ -75,20 +79,26 @@ impl WireguardGatewayData { let mut peer = Peer::new(Key::new(client.pub_key.to_bytes())); peer.allowed_ips .push(IpAddrMask::new(client.private_ip, 32)); - let msg = PeerControlMessage::AddPeer(peer); + let msg = PeerControlRequest::AddPeer(peer); self.peer_tx.send(msg).map_err(|_| Error::PeerModifyStopped) } pub fn remove_peer(&self, client: &GatewayClient) -> Result<(), Error> { let key = Key::new(client.pub_key().to_bytes()); - let msg = PeerControlMessage::RemovePeer(key); + let msg = PeerControlRequest::RemovePeer(key); + self.peer_tx.send(msg).map_err(|_| Error::PeerModifyStopped) + } + + pub fn query_bandwidth(&self, peer_public_key: PeerPublicKey) -> Result<(), Error> { + let key = Key::new(peer_public_key.to_bytes()); + let msg = PeerControlRequest::QueryBandwidth(key); self.peer_tx.send(msg).map_err(|_| Error::PeerModifyStopped) } } pub struct WireguardData { pub inner: WireguardGatewayData, - pub peer_rx: UnboundedReceiver, + pub peer_rx: UnboundedReceiver, } /// Start wireguard device @@ -96,6 +106,7 @@ pub struct WireguardData { pub async fn start_wireguard( task_client: nym_task::TaskClient, wireguard_data: WireguardData, + control_tx: UnboundedSender, ) -> Result, Box> { use base64::{prelude::BASE64_STANDARD, Engine}; use defguard_wireguard_rs::{InterfaceConfiguration, WireguardInterfaceApi}; @@ -135,13 +146,13 @@ pub async fn start_wireguard( wg_api.configure_peer_routing(&[catch_all_peer])?; let wg_api = std::sync::Arc::new(WgApiWrapper::new(wg_api)); - let mut controller = PeerController::new(wg_api.clone(), wireguard_data.peer_rx); + let mut controller = PeerController::new( + wg_api.clone(), + interface_config.peers, + wireguard_data.peer_rx, + control_tx, + ); tokio::spawn(async move { controller.run(task_client).await }); Ok(wg_api) } - -#[cfg(not(target_os = "linux"))] -pub async fn start_wireguard() { - todo!("WireGuard is currently only supported on Linux"); -} diff --git a/common/wireguard/src/peer_controller.rs b/common/wireguard/src/peer_controller.rs index a325cf29dbe..fa11aa8a71d 100644 --- a/common/wireguard/src/peer_controller.rs +++ b/common/wireguard/src/peer_controller.rs @@ -1,91 +1,182 @@ // Copyright 2024 - Nym Technologies SA // SPDX-License-Identifier: Apache-2.0 -use std::{ - sync::Arc, - time::{Duration, SystemTime}, -}; - -use defguard_wireguard_rs::{ - host::{Host, Peer}, - key::Key, - WGApi, WireguardInterfaceApi, -}; +use chrono::{Timelike, Utc}; +use defguard_wireguard_rs::{host::Peer, key::Key, WireguardInterfaceApi}; +use nym_wireguard_types::registration::{RemainingBandwidthData, BANDWIDTH_CAP_PER_DAY}; +use std::time::SystemTime; +use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::sync::mpsc; use tokio_stream::{wrappers::IntervalStream, StreamExt}; +use crate::error::Error; use crate::WgApiWrapper; -const DEFAULT_PEER_TIMEOUT: Duration = Duration::from_secs(60 * 60); // 1 hour +// To avoid any problems, keep this stale check time bigger (>2x) then the bandwidth cap +// reset time (currently that one is 24h, at UTC midnight) +const DEFAULT_PEER_TIMEOUT: Duration = Duration::from_secs(60 * 60 * 24 * 3); // 3 days const DEFAULT_PEER_TIMEOUT_CHECK: Duration = Duration::from_secs(60); // 1 minute -pub enum PeerControlMessage { +pub enum PeerControlRequest { AddPeer(Peer), RemovePeer(Key), + QueryBandwidth(Key), +} + +pub enum PeerControlResponse { + AddPeer { + success: bool, + }, + RemovePeer { + success: bool, + }, + QueryBandwidth { + bandwidth_data: Option, + }, } pub struct PeerController { - peer_rx: mpsc::UnboundedReceiver, + request_rx: mpsc::UnboundedReceiver, + response_tx: mpsc::UnboundedSender, wg_api: Arc, timeout_check_interval: IntervalStream, + active_peers: HashMap, + suspended_peers: HashMap, + last_seen_bandwidth: HashMap, } impl PeerController { pub fn new( wg_api: Arc, - peer_rx: mpsc::UnboundedReceiver, + peers: Vec, + request_rx: mpsc::UnboundedReceiver, + response_tx: mpsc::UnboundedSender, ) -> Self { let timeout_check_interval = tokio_stream::wrappers::IntervalStream::new( tokio::time::interval(DEFAULT_PEER_TIMEOUT_CHECK), ); + let active_peers = peers + .into_iter() + .map(|peer| (peer.public_key.clone(), peer)) + .collect(); + PeerController { wg_api, - peer_rx, + request_rx, + response_tx, timeout_check_interval, + active_peers, + suspended_peers: HashMap::new(), + last_seen_bandwidth: HashMap::new(), } } - fn remove_stale_peers(wg_api: &WGApi, host: Host) { - let current_timestamp = SystemTime::now(); - for (key, peer) in host.peers.iter() { - if let Some(timestamp) = peer.last_handshake { - if let Ok(duration_since_handshake) = current_timestamp.duration_since(timestamp) { - if duration_since_handshake > DEFAULT_PEER_TIMEOUT { - if let Err(e) = wg_api.remove_peer(key) { - log::error!("Could not remove stale peer: {:?}", e); - } else { - log::debug!("Removed stale peer {:?}", key); - } - } + fn check_stale_peer(&self, peer: &Peer, current_timestamp: SystemTime) -> Result { + if let Some(timestamp) = peer.last_handshake { + if let Ok(duration_since_handshake) = current_timestamp.duration_since(timestamp) { + if duration_since_handshake > DEFAULT_PEER_TIMEOUT { + self.wg_api.inner.remove_peer(&peer.public_key)?; + return Ok(true); } } } + + Ok(false) + } + + fn check_suspend_peer(&mut self, peer: &Peer) -> Result<(), Error> { + let prev_peer = self + .active_peers + .get(&peer.public_key) + .ok_or(Error::PeerMismatch)?; + let data_usage = + (peer.rx_bytes + peer.tx_bytes).saturating_sub(prev_peer.rx_bytes + prev_peer.tx_bytes); + if data_usage > BANDWIDTH_CAP_PER_DAY { + self.wg_api.inner.remove_peer(&peer.public_key)?; + let (moved_key, moved_peer) = self + .active_peers + .remove_entry(&peer.public_key) + .ok_or(Error::PeerMismatch)?; + self.suspended_peers.insert(moved_key, moved_peer); + } + Ok(()) + } + + fn check_peers(&mut self) -> Result<(), Error> { + // Add 10 seconds to cover edge cases. At worst, we give ten free seconds worth of bandwidth + // by resetting the bandwidth twice + let reset = Utc::now().num_seconds_from_midnight() as u64 + <= DEFAULT_PEER_TIMEOUT_CHECK.as_secs() + 10; + + if reset { + for (_, peer) in self.suspended_peers.drain() { + self.wg_api.inner.configure_peer(&peer)?; + } + } + let host = self.wg_api.inner.read_interface_data()?; + self.last_seen_bandwidth = host + .peers + .iter() + .map(|(key, peer)| (key.clone(), peer.rx_bytes + peer.tx_bytes)) + .collect(); + if reset { + self.active_peers = host.peers; + } else { + let current_timestamp = SystemTime::now(); + for peer in host.peers.values() { + if !self.check_stale_peer(peer, current_timestamp)? { + self.check_suspend_peer(peer)?; + } + } + } + + Ok(()) } pub async fn run(&mut self, mut task_client: nym_task::TaskClient) { loop { tokio::select! { _ = self.timeout_check_interval.next() => { - match self.wg_api.inner.read_interface_data() { - Ok(host) => Self::remove_stale_peers(&self.wg_api.inner, host), - Err(e) => { log::error!("Could not read peer data: {:?}", e); }, + if let Err(e) = self.check_peers() { + log::error!("Error while periodically checking peers: {:?}", e); } } _ = task_client.recv() => { log::trace!("PeerController handler: Received shutdown"); break; } - msg = self.peer_rx.recv() => { + msg = self.request_rx.recv() => { match msg { - Some(PeerControlMessage::AddPeer(peer)) => { - if let Err(e) = self.wg_api.inner.configure_peer(&peer) { + Some(PeerControlRequest::AddPeer(peer)) => { + let success = if let Err(e) = self.wg_api.inner.configure_peer(&peer) { log::error!("Could not configure peer: {:?}", e); - } + false + } else { + self.active_peers.insert(peer.public_key.clone(), peer); + true + }; + self.response_tx.send(PeerControlResponse::AddPeer { success }).ok(); } - Some(PeerControlMessage::RemovePeer(peer_pubkey)) => { - if let Err(e) = self.wg_api.inner.remove_peer(&peer_pubkey) { + Some(PeerControlRequest::RemovePeer(peer_pubkey)) => { + let success = if let Err(e) = self.wg_api.inner.remove_peer(&peer_pubkey) { log::error!("Could not remove peer: {:?}", e); - } + false + } else { + self.active_peers.remove(&peer_pubkey); + self.suspended_peers.remove(&peer_pubkey); + true + }; + self.response_tx.send(PeerControlResponse::RemovePeer { success }).ok(); + } + Some(PeerControlRequest::QueryBandwidth(peer_pubkey)) => { + let msg = if self.suspended_peers.contains_key(&peer_pubkey) { + PeerControlResponse::QueryBandwidth { bandwidth_data: Some(RemainingBandwidthData{ available_bandwidth: 0, suspended: true }) } + } else if let Some(&consumed_bandwidth) = self.last_seen_bandwidth.get(&peer_pubkey) { + PeerControlResponse::QueryBandwidth { bandwidth_data: Some(RemainingBandwidthData{ available_bandwidth: BANDWIDTH_CAP_PER_DAY - consumed_bandwidth, suspended: false })} + } else { + PeerControlResponse::QueryBandwidth { bandwidth_data: None } + }; + self.response_tx.send(msg).ok(); } None => { log::trace!("PeerController [main loop]: stopping since channel closed"); diff --git a/gateway/src/node/mod.rs b/gateway/src/node/mod.rs index 3f8690dd79d..c17c9dd4a8d 100644 --- a/gateway/src/node/mod.rs +++ b/gateway/src/node/mod.rs @@ -257,6 +257,7 @@ impl Gateway { .ok_or(GatewayError::UnspecifiedAuthenticatorConfig)?; let (router_tx, mut router_rx) = oneshot::channel(); let (auth_mix_sender, auth_mix_receiver) = mpsc::unbounded(); + let (peer_response_tx, peer_response_rx) = tokio::sync::mpsc::unbounded_channel(); let router_shutdown = shutdown.fork("message_router"); let transceiver = LocalGateway::new( *self.identity_keypair.public_key(), @@ -269,6 +270,7 @@ impl Gateway { let mut authenticator_server = nym_authenticator::Authenticator::new( opts.config.clone(), wireguard_data.inner.clone(), + peer_response_rx, ) .with_custom_gateway_transceiver(Box::new(transceiver)) .with_shutdown(shutdown.fork("authenticator")) @@ -299,7 +301,8 @@ impl Gateway { MessageRouter::new(auth_mix_receiver, packet_router) .start_with_shutdown(router_shutdown); - let wg_api = nym_wireguard::start_wireguard(shutdown, wireguard_data).await?; + let wg_api = + nym_wireguard::start_wireguard(shutdown, wireguard_data, peer_response_tx).await?; Ok(StartedAuthenticator { wg_api, diff --git a/nym-node/src/node/mod.rs b/nym-node/src/node/mod.rs index 4bfbbe2febf..c851decf239 100644 --- a/nym-node/src/node/mod.rs +++ b/nym-node/src/node/mod.rs @@ -32,7 +32,7 @@ use nym_node_http_api::{NymNodeHTTPServer, NymNodeRouter}; use nym_sphinx_acknowledgements::AckKey; use nym_sphinx_addressing::Recipient; use nym_task::{TaskClient, TaskManager}; -use nym_wireguard::{peer_controller::PeerControlMessage, WireguardGatewayData}; +use nym_wireguard::{peer_controller::PeerControlRequest, WireguardGatewayData}; use rand::rngs::OsRng; use rand::{CryptoRng, RngCore}; use std::path::Path; @@ -264,7 +264,7 @@ impl ExitGatewayData { pub struct WireguardData { inner: WireguardGatewayData, - peer_rx: UnboundedReceiver, + peer_rx: UnboundedReceiver, } impl WireguardData { diff --git a/service-providers/authenticator/src/authenticator.rs b/service-providers/authenticator/src/authenticator.rs index e8faf1b3715..04e77896a22 100644 --- a/service-providers/authenticator/src/authenticator.rs +++ b/service-providers/authenticator/src/authenticator.rs @@ -8,7 +8,8 @@ use ipnetwork::IpNetwork; use nym_client_core::{HardcodedTopologyProvider, TopologyProvider}; use nym_sdk::{mixnet::Recipient, GatewayTransceiver}; use nym_task::{TaskClient, TaskHandle}; -use nym_wireguard::WireguardGatewayData; +use nym_wireguard::{peer_controller::PeerControlResponse, WireguardGatewayData}; +use tokio::sync::mpsc::UnboundedReceiver; use crate::{config::Config, error::AuthenticatorError}; @@ -30,18 +31,24 @@ pub struct Authenticator { custom_topology_provider: Option>, custom_gateway_transceiver: Option>, wireguard_gateway_data: WireguardGatewayData, + response_rx: UnboundedReceiver, shutdown: Option, on_start: Option>, } impl Authenticator { - pub fn new(config: Config, wireguard_gateway_data: WireguardGatewayData) -> Self { + pub fn new( + config: Config, + wireguard_gateway_data: WireguardGatewayData, + response_rx: UnboundedReceiver, + ) -> Self { Self { config, wait_for_gateway: false, custom_topology_provider: None, custom_gateway_transceiver: None, wireguard_gateway_data, + response_rx, shutdown: None, on_start: None, } @@ -129,6 +136,7 @@ impl Authenticator { self.config, private_ip_network, self.wireguard_gateway_data, + self.response_rx, mixnet_client, task_handle, ); diff --git a/service-providers/authenticator/src/cli/peer_handler.rs b/service-providers/authenticator/src/cli/peer_handler.rs index 599941c5fb9..b4d6e1faea6 100644 --- a/service-providers/authenticator/src/cli/peer_handler.rs +++ b/service-providers/authenticator/src/cli/peer_handler.rs @@ -2,21 +2,24 @@ // SPDX-License-Identifier: Apache-2.0 use nym_sdk::TaskClient; -use nym_wireguard::peer_controller::PeerControlMessage; +use nym_wireguard::peer_controller::{PeerControlRequest, PeerControlResponse}; use tokio::sync::mpsc; pub struct DummyHandler { - peer_rx: mpsc::UnboundedReceiver, + peer_rx: mpsc::UnboundedReceiver, + response_tx: mpsc::UnboundedSender, task_client: TaskClient, } impl DummyHandler { pub fn new( - peer_rx: mpsc::UnboundedReceiver, + peer_rx: mpsc::UnboundedReceiver, + response_tx: mpsc::UnboundedSender, task_client: TaskClient, ) -> Self { DummyHandler { peer_rx, + response_tx, task_client, } } @@ -27,11 +30,17 @@ impl DummyHandler { msg = self.peer_rx.recv() => { if let Some(msg) = msg { match msg { - PeerControlMessage::AddPeer(peer) => { + PeerControlRequest::AddPeer(peer) => { log::info!("[DUMMY] Adding peer {:?}", peer); + self.response_tx.send(PeerControlResponse::AddPeer { success: true }).ok(); } - PeerControlMessage::RemovePeer(key) => { + PeerControlRequest::RemovePeer(key) => { log::info!("[DUMMY] Removing peer {:?}", key); + self.response_tx.send(PeerControlResponse::RemovePeer { success: true }).ok(); + } + PeerControlRequest::QueryBandwidth(key) => { + log::info!("[DUMMY] Querying bandwidth for peer {:?}", key); + self.response_tx.send(PeerControlResponse::QueryBandwidth { bandwidth_data: None }).ok(); } } } else { diff --git a/service-providers/authenticator/src/cli/run.rs b/service-providers/authenticator/src/cli/run.rs index e1191449874..0abaa4b3cde 100644 --- a/service-providers/authenticator/src/cli/run.rs +++ b/service-providers/authenticator/src/cli/run.rs @@ -14,6 +14,7 @@ use nym_crypto::asymmetric::x25519::KeyPair; use nym_task::TaskHandle; use nym_wireguard::WireguardGatewayData; use rand::rngs::OsRng; +use tokio::sync::mpsc::unbounded_channel; #[allow(clippy::struct_excessive_bools)] #[derive(Args, Clone)] @@ -48,11 +49,13 @@ pub(crate) async fn execute(args: &Run) -> Result<(), AuthenticatorError> { Arc::new(KeyPair::new(&mut OsRng)), ); let task_handler = TaskHandle::default(); - let handler = DummyHandler::new(peer_rx, task_handler.fork("peer-handler")); + let (response_tx, response_rx) = unbounded_channel(); + let handler = DummyHandler::new(peer_rx, response_tx, task_handler.fork("peer-handler")); tokio::spawn(async move { handler.run().await; }); - let mut server = nym_authenticator::Authenticator::new(config, wireguard_gateway_data); + let mut server = + nym_authenticator::Authenticator::new(config, wireguard_gateway_data, response_rx); if let Some(custom_mixnet) = &args.common_args.custom_mixnet { server = server.with_stored_topology(custom_mixnet)? } diff --git a/service-providers/authenticator/src/mixnet_listener.rs b/service-providers/authenticator/src/mixnet_listener.rs index 355a92f28c6..a93b2185ef2 100644 --- a/service-providers/authenticator/src/mixnet_listener.rs +++ b/service-providers/authenticator/src/mixnet_listener.rs @@ -17,12 +17,13 @@ use nym_authenticator_requests::v1::{ use nym_sdk::mixnet::{InputMessage, MixnetMessageSender, Recipient, TransmissionLane}; use nym_sphinx::receiver::ReconstructedMessage; use nym_task::TaskHandle; -use nym_wireguard::WireguardGatewayData; +use nym_wireguard::{peer_controller::PeerControlResponse, WireguardGatewayData}; use nym_wireguard_types::{ - registration::{PendingRegistrations, PrivateIPs, RegistrationData}, + registration::{PendingRegistrations, PrivateIPs, RegistrationData, RegistredData}, GatewayClient, InitMessage, PeerPublicKey, }; use rand::{prelude::IteratorRandom, thread_rng}; +use tokio::sync::mpsc::UnboundedReceiver; use tokio_stream::wrappers::IntervalStream; use crate::{config::Config, error::*}; @@ -45,6 +46,8 @@ pub(crate) struct MixnetListener { pub(crate) wireguard_gateway_data: WireguardGatewayData, + pub(crate) response_rx: UnboundedReceiver, + pub(crate) free_private_network_ips: Arc, pub(crate) timeout_check_interval: IntervalStream, @@ -55,6 +58,7 @@ impl MixnetListener { config: Config, private_ip_network: IpNetwork, wireguard_gateway_data: WireguardGatewayData, + response_rx: UnboundedReceiver, mixnet_client: nym_sdk::mixnet::MixnetClient, task_handle: TaskHandle, ) -> Self { @@ -66,6 +70,7 @@ impl MixnetListener { task_handle, registration_in_progres: Default::default(), wireguard_gateway_data, + response_rx, free_private_network_ips: Arc::new( private_ip_network.iter().map(|ip| (ip, None)).collect(), ), @@ -73,22 +78,6 @@ impl MixnetListener { } } - fn remove_from_registry( - &self, - remote_public: &PeerPublicKey, - gateway_client: &GatewayClient, - ) -> Result<()> { - self.wireguard_gateway_data - .remove_peer(gateway_client) - .map_err(|err| { - AuthenticatorError::InternalError(format!("could not remove peer: {:?}", err)) - })?; - self.wireguard_gateway_data - .client_registry() - .remove(remote_public); - Ok(()) - } - fn remove_stale_registrations(&self) -> Result<()> { for reg in self.registration_in_progres.iter().map(|reg| reg.clone()) { let mut ip = self @@ -121,7 +110,7 @@ impl MixnetListener { Ok(()) } - fn on_initial_request( + async fn on_initial_request( &mut self, init_message: InitMessage, request_id: u64, @@ -136,24 +125,26 @@ impl MixnetListener { reply_to, )); } - let gateway_client_opt = if let Some(gateway_client) = self + if let Some(gateway_client) = self .wireguard_gateway_data .client_registry() .get(&remote_public) { - let mut private_ip_ref = self - .free_private_network_ips - .get_mut(&gateway_client.private_ip) - .ok_or(AuthenticatorError::InternalError(String::from( - "could not find private IP", - )))?; - *private_ip_ref = None; - Some(gateway_client.clone()) - } else { - None - }; - if let Some(gateway_client) = gateway_client_opt { - self.remove_from_registry(&remote_public, &gateway_client)?; + return Ok(AuthenticatorResponse::new_registered( + RegistredData { + pub_key: PeerPublicKey::new( + self.wireguard_gateway_data + .keypair() + .public_key() + .to_bytes() + .into(), + ), + private_ip: gateway_client.private_ip, + wg_port: self.config.authenticator.announced_port, + }, + reply_to, + request_id, + )); } let mut private_ip_ref = self .free_private_network_ips @@ -184,7 +175,7 @@ impl MixnetListener { )) } - fn on_final_request( + async fn on_final_request( &mut self, gateway_client: GatewayClient, request_id: u64, @@ -209,18 +200,77 @@ impl MixnetListener { .map_err(|err| { AuthenticatorError::InternalError(format!("could not add peer: {:?}", err)) })?; + + let PeerControlResponse::AddPeer { success } = + self.response_rx + .recv() + .await + .ok_or(AuthenticatorError::InternalError( + "no response for add peer".to_string(), + ))? + else { + return Err(AuthenticatorError::InternalError( + "unexpected response type".to_string(), + )); + }; + if !success { + return Err(AuthenticatorError::InternalError( + "adding peer could not be performed".to_string(), + )); + } self.registration_in_progres .remove(&gateway_client.pub_key()); self.wireguard_gateway_data .client_registry() .insert(gateway_client.pub_key(), gateway_client); - Ok(AuthenticatorResponse::new_registered(reply_to, request_id)) + Ok(AuthenticatorResponse::new_registered( + RegistredData { + pub_key: registration_data.gateway_data.pub_key, + private_ip: registration_data.gateway_data.private_ip, + wg_port: registration_data.wg_port, + }, + reply_to, + request_id, + )) } else { Err(AuthenticatorError::MacVerificationFailure) } } + async fn on_query_bandwidth_request( + &mut self, + peer_public_key: PeerPublicKey, + request_id: u64, + reply_to: Recipient, + ) -> AuthenticatorHandleResult { + self.wireguard_gateway_data + .query_bandwidth(peer_public_key) + .map_err(|err| { + AuthenticatorError::InternalError(format!( + "could not query peer bandwidth: {:?}", + err + )) + })?; + let PeerControlResponse::QueryBandwidth { bandwidth_data } = self + .response_rx + .recv() + .await + .ok_or(AuthenticatorError::InternalError( + "no response for query".to_string(), + ))? + else { + return Err(AuthenticatorError::InternalError( + "unexpected response type".to_string(), + )); + }; + Ok(AuthenticatorResponse::new_remaining_bandwidth( + bandwidth_data, + reply_to, + request_id, + )) + } + async fn on_reconstructed_message( &mut self, reconstructed: ReconstructedMessage, @@ -240,9 +290,19 @@ impl MixnetListener { match request.data { AuthenticatorRequestData::Initial(init_msg) => { self.on_initial_request(init_msg, request.request_id, request.reply_to) + .await } AuthenticatorRequestData::Final(client) => { self.on_final_request(client, request.request_id, request.reply_to) + .await + } + AuthenticatorRequestData::QueryBandwidth(peer_public_key) => { + self.on_query_bandwidth_request( + peer_public_key, + request.request_id, + request.reply_to, + ) + .await } } }