diff --git a/Cargo.lock b/Cargo.lock index ca6f517837a..c536f31a218 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6410,6 +6410,7 @@ dependencies = [ "thiserror 2.0.12", "time", "tokio", + "tokio-stream", "tokio-util", "toml 0.8.23", "tower-http 0.5.2", @@ -7157,9 +7158,11 @@ dependencies = [ name = "nym-task" version = "0.1.0" dependencies = [ + "anyhow", "cfg-if", "futures", "log", + "nym-test-utils", "thiserror 2.0.12", "tokio", "tokio-util", diff --git a/clients/native/src/client/mod.rs b/clients/native/src/client/mod.rs index ce36a5eb849..ffe4daa9fb8 100644 --- a/clients/native/src/client/mod.rs +++ b/clients/native/src/client/mod.rs @@ -11,7 +11,7 @@ use nym_client_core::client::base_client::{ BaseClientBuilder, ClientInput, ClientOutput, ClientState, }; use nym_sphinx::params::PacketType; -use nym_task::TaskHandle; +use nym_task::ShutdownManager; use nym_validator_client::QueryHttpRpcNyxdClient; use std::error::Error; use std::path::PathBuf; @@ -29,6 +29,8 @@ pub struct SocketClient { /// Optional path to a .json file containing standalone network details. custom_mixnet: Option<PathBuf>, + + shutdown_manager: ShutdownManager, } impl SocketClient { @@ -40,6 +42,7 @@ impl SocketClient { SocketClient { config, custom_mixnet, + shutdown_manager: Default::default(), } } @@ -49,7 +52,7 @@ impl SocketClient { client_output: ClientOutput, client_state: ClientState, self_address: &Recipient, - task_client: nym_task::TaskClient, + shutdown_token: nym_task::ShutdownToken, packet_type: PacketType, ) { info!("Starting websocket listener..."); @@ -77,24 +80,24 @@ impl SocketClient { shared_lane_queue_lengths, reply_controller_sender, Some(packet_type), - task_client.fork("websocket_handler"), + shutdown_token.clone(), ); websocket::Listener::new( config.socket.host, config.socket.listening_port, - task_client.with_suffix("websocket_listener"), + shutdown_token.child_token(), ) .start(websocket_handler); } /// blocking version of `start_socket` method.
Will run forever (or until SIGINT is sent) pub async fn run_socket_forever(self) -> Result<(), Box> { - let shutdown = self.start_socket().await?; + let mut shutdown = self.start_socket().await?; - let res = shutdown.wait_for_shutdown().await; + shutdown.run_until_shutdown().await; log::info!("Stopping nym-client"); - res + Ok(()) } async fn initialise_storage(&self) -> Result { @@ -119,6 +122,7 @@ impl SocketClient { let mut base_client = BaseClientBuilder::new(self.config().base(), storage, dkg_query_client) + .with_shutdown(self.shutdown_manager.shutdown_tracker_owned()) .with_user_agent(user_agent); if let Some(custom_mixnet) = &self.custom_mixnet { @@ -128,7 +132,7 @@ impl SocketClient { Ok(base_client) } - pub async fn start_socket(self) -> Result { + pub async fn start_socket(self) -> Result { if !self.config.socket.socket_type.is_websocket() { return Err(ClientError::InvalidSocketMode); } @@ -147,13 +151,13 @@ impl SocketClient { client_output, client_state, &self_address, - started_client.task_handle.get_handle(), + self.shutdown_manager.child_shutdown_token(), packet_type, ); info!("Client startup finished!"); info!("The address of this client is: {self_address}"); - Ok(started_client.task_handle) + Ok(self.shutdown_manager) } } diff --git a/clients/native/src/websocket/handler.rs b/clients/native/src/websocket/handler.rs index 761e2b9531c..2ea7cac609d 100644 --- a/clients/native/src/websocket/handler.rs +++ b/clients/native/src/websocket/handler.rs @@ -19,7 +19,7 @@ use nym_sphinx::receiver::ReconstructedMessage; use nym_task::connections::{ ConnectionCommand, ConnectionCommandSender, ConnectionId, LaneQueueLengths, TransmissionLane, }; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use std::time::Duration; use tokio::net::TcpStream; use tokio::time::Instant; @@ -44,7 +44,7 @@ pub(crate) struct HandlerBuilder { lane_queue_lengths: LaneQueueLengths, reply_controller_sender: ReplyControllerSender, packet_type: Option, - task_client: TaskClient, + shutdown_token: ShutdownToken, } impl HandlerBuilder { @@ -57,7 +57,7 @@ impl HandlerBuilder { lane_queue_lengths: LaneQueueLengths, reply_controller_sender: ReplyControllerSender, packet_type: Option, - task_client: TaskClient, + shutdown_token: ShutdownToken, ) -> Self { Self { msg_input, @@ -67,14 +67,13 @@ impl HandlerBuilder { lane_queue_lengths, reply_controller_sender, packet_type, - task_client, + shutdown_token, } } // TODO: make sure we only ever have one active handler pub fn create_active_handler(&self) -> Handler { - let mut task_client = self.task_client.fork("active_handler"); - task_client.disarm(); + let shutdown_token = self.shutdown_token.clone(); Handler { msg_input: self.msg_input.clone(), client_connection_tx: self.client_connection_tx.clone(), @@ -85,7 +84,7 @@ impl HandlerBuilder { lane_queue_lengths: self.lane_queue_lengths.clone(), reply_controller_sender: self.reply_controller_sender.clone(), packet_type: self.packet_type, - task_client, + shutdown_token, } } } @@ -100,19 +99,14 @@ pub(crate) struct Handler { lane_queue_lengths: LaneQueueLengths, reply_controller_sender: ReplyControllerSender, packet_type: Option, - task_client: TaskClient, + shutdown_token: ShutdownToken, } impl Drop for Handler { fn drop(&mut self) { - if let Err(err) = self + let _ = self .buffer_requester - .unbounded_send(ReceivedBufferMessage::ReceiverDisconnect) - { - if !self.task_client.is_shutdown_poll() { - error!("failed to disconnect the receiver from the buffer: {err}"); - } - } + 
.unbounded_send(ReceivedBufferMessage::ReceiverDisconnect); } } @@ -142,7 +136,7 @@ impl Handler { { Ok(length) => length, Err(err) => { - if !self.task_client.is_shutdown_poll() { + if !self.shutdown_token.is_cancelled() { error!( "Failed to get reply queue length for connection {connection_id}: {err}" ); @@ -192,7 +186,7 @@ impl Handler { // the ack control is now responsible for chunking, etc. let input_msg = InputMessage::new_regular(recipient, message, lane, self.packet_type); if let Err(err) = self.msg_input.send(input_msg).await { - if !self.task_client.is_shutdown_poll() { + if !self.shutdown_token.is_cancelled() { error!("Failed to send message to the input buffer: {err}"); } } @@ -225,7 +219,7 @@ impl Handler { let input_msg = InputMessage::new_anonymous(recipient, message, reply_surbs, lane, self.packet_type); if let Err(err) = self.msg_input.send(input_msg).await { - if !self.task_client.is_shutdown_poll() { + if !self.shutdown_token.is_cancelled() { error!("Failed to send anonymous message to the input buffer: {err}"); } } @@ -253,7 +247,7 @@ impl Handler { let input_msg = InputMessage::new_reply(recipient_tag, message, lane, self.packet_type); if let Err(err) = self.msg_input.send(input_msg).await { - if !self.task_client.is_shutdown_poll() { + if !self.shutdown_token.is_cancelled() { error!("Failed to send reply message to the input buffer: {err}"); } } @@ -275,7 +269,7 @@ impl Handler { .client_connection_tx .unbounded_send(ConnectionCommand::Close(connection_id)) { - if !self.task_client.is_shutdown_poll() { + if !self.shutdown_token.is_cancelled() { error!("Failed to send close connection command: {err}"); } } @@ -394,11 +388,14 @@ impl Handler { } async fn listen_for_requests(&mut self, mut msg_receiver: ReconstructedMessagesReceiver) { - let mut task_client = self.task_client.fork("select"); - task_client.disarm(); + let shutdown_token = self.shutdown_token.clone(); - while !task_client.is_shutdown() { + loop { tokio::select! 
{ + _ = shutdown_token.cancelled() => { + log::trace!("Websocket handler: Received shutdown"); + break; + } // we can either get a client request from the websocket socket_msg = self.next_websocket_request() => { if socket_msg.is_none() { @@ -436,9 +433,6 @@ impl Handler { break; } } - _ = task_client.recv() => { - log::trace!("Websocket handler: Received shutdown"); - } } } log::debug!("Websocket handler: Exiting"); @@ -464,7 +458,7 @@ impl Handler { reconstructed_sender, )) { - if !self.task_client.is_shutdown_poll() { + if !self.shutdown_token.is_cancelled() { error!("failed to announce the receiver to the buffer: {err}"); } } diff --git a/clients/native/src/websocket/listener.rs b/clients/native/src/websocket/listener.rs index a1b430a9301..f0723c18241 100644 --- a/clients/native/src/websocket/listener.rs +++ b/clients/native/src/websocket/listener.rs @@ -3,7 +3,7 @@ use super::handler::HandlerBuilder; use log::*; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use std::net::IpAddr; use std::{net::SocketAddr, process, sync::Arc}; use tokio::io::AsyncWriteExt; @@ -23,15 +23,15 @@ impl State { pub(crate) struct Listener { address: SocketAddr, state: State, - task_client: TaskClient, + shutdown_token: ShutdownToken, } impl Listener { - pub(crate) fn new(host: IpAddr, port: u16, task_client: TaskClient) -> Self { + pub(crate) fn new(host: IpAddr, port: u16, shutdown_token: ShutdownToken) -> Self { Listener { address: SocketAddr::new(host, port), state: State::AwaitingConnection, - task_client, + shutdown_token, } } @@ -46,11 +46,11 @@ impl Listener { let notify = Arc::new(Notify::new()); - while !self.task_client.is_shutdown() { + while !self.shutdown_token.is_cancelled() { tokio::select! { // When the handler finishes we check if shutdown is signalled _ = notify.notified() => { - if self.task_client.is_shutdown() { + if self.shutdown_token.is_cancelled() { log::trace!("Websocket listener: detected shutdown after connection closed"); break; } @@ -59,7 +59,7 @@ impl Listener { } // ... but when there is no connected client at the time of shutdown being // signalled, we handle it here. - _ = self.task_client.recv() => { + _ = self.shutdown_token.cancelled() => { if !self.state.is_connected() { log::trace!("Not connected: shutting down"); break; diff --git a/common/client-core/src/client/base_client/helpers.rs b/common/client-core/src/client/base_client/helpers.rs index 4259ead0462..68c0cb61bcb 100644 --- a/common/client-core/src/client/base_client/helpers.rs +++ b/common/client-core/src/client/base_client/helpers.rs @@ -1,7 +1,9 @@ // Copyright 2023 - Nym Technologies SA // SPDX-License-Identifier: Apache-2.0 +use crate::error::ClientCoreError; use crate::{client::replies::reply_storage, config::DebugConfig}; +use nym_task::{ShutdownManager, ShutdownToken, ShutdownTracker}; pub fn setup_empty_reply_surb_backend(debug_config: &DebugConfig) -> reply_storage::Empty { reply_storage::Empty { @@ -13,3 +15,49 @@ pub fn setup_empty_reply_surb_backend(debug_config: &DebugConfig) -> reply_stora .maximum_reply_surb_storage_threshold, } } + +// old 'TaskHandle' +pub(crate) enum ShutdownHelper { + Internal(ShutdownManager), + External(ShutdownTracker), +} + +fn new_shutdown_manager() -> Result { + cfg_if::cfg_if! { + if #[cfg(not(target_arch = "wasm32"))] { + Ok(ShutdownManager::build_new_default()?) 
+ } else { + Ok(ShutdownManager::new_without_signals()) + } + } +} + +impl ShutdownHelper { + pub(crate) fn new(shutdown_tracker: Option<ShutdownTracker>) -> Result<Self, ClientCoreError> { + match shutdown_tracker { + None => Ok(ShutdownHelper::Internal(new_shutdown_manager()?)), + Some(shutdown_tracker) => Ok(ShutdownHelper::External(shutdown_tracker)), + } + } + + pub(crate) fn into_internal(self) -> Option<ShutdownManager> { + match self { + ShutdownHelper::Internal(manager) => Some(manager), + ShutdownHelper::External(_) => None, + } + } + + pub(crate) fn shutdown_token(&self) -> ShutdownToken { + match self { + ShutdownHelper::External(shutdown) => shutdown.clone_shutdown_token(), + ShutdownHelper::Internal(shutdown) => shutdown.clone_shutdown_token(), + } + } + + pub(crate) fn tracker(&self) -> &ShutdownTracker { + match self { + ShutdownHelper::External(shutdown) => shutdown, + ShutdownHelper::Internal(shutdown) => shutdown.shutdown_tracker(), + } + } +} diff --git a/common/client-core/src/client/base_client/mod.rs b/common/client-core/src/client/base_client/mod.rs index 36cb427780a..206b682fd19 100644 --- a/common/client-core/src/client/base_client/mod.rs +++ b/common/client-core/src/client/base_client/mod.rs @@ -4,6 +4,7 @@ use super::mix_traffic::ClientRequestSender; use super::received_buffer::ReceivedBufferMessage; use super::statistics_control::StatisticsControl; +use crate::client::base_client::helpers::ShutdownHelper; use crate::client::base_client::storage::helpers::store_client_keys; use crate::client::base_client::storage::MixnetClientStorage; use crate::client::cover_traffic_stream::LoopCoverTrafficStream; @@ -27,13 +28,13 @@ use crate::client::topology_control::nym_api_provider::NymApiTopologyProvider; use crate::client::topology_control::{ TopologyAccessor, TopologyRefresher, TopologyRefresherConfig, }; +use crate::config; use crate::config::{Config, DebugConfig}; use crate::error::ClientCoreError; use crate::init::{ setup_gateway, types::{GatewaySetup, InitialisationResult}, }; -use crate::{config, spawn_future}; use futures::channel::mpsc; use nym_bandwidth_controller::BandwidthController; use nym_client_core_config_types::{ForgetMe, RememberMe}; @@ -48,12 +49,11 @@ use nym_gateway_client::{ use nym_sphinx::acknowledgements::AckKey; use nym_sphinx::addressing::clients::Recipient; use nym_sphinx::addressing::nodes::NodeIdentity; -use nym_sphinx::params::PacketType; use nym_sphinx::receiver::{ReconstructedMessage, SphinxMessageReceiver}; use nym_statistics_common::clients::ClientStatsSender; use nym_statistics_common::generate_client_stats_id; use nym_task::connections::{ConnectionCommandReceiver, ConnectionCommandSender, LaneQueueLengths}; -use nym_task::{TaskClient, TaskHandle}; +use nym_task::{ShutdownManager, ShutdownTracker}; use nym_topology::provider_trait::TopologyProvider; use nym_topology::HardcodedTopologyProvider; use nym_validator_client::nym_api::NymApiClientExt; @@ -95,7 +95,6 @@ impl ClientInput { } } -#[derive(Clone)] pub struct ClientOutput { pub received_buffer_request_sender: ReceivedBufferRequestSender, } @@ -195,7 +194,7 @@ pub struct BaseClientBuilder { wait_for_gateway: bool, custom_topology_provider: Option>, custom_gateway_transceiver: Option>, - shutdown: Option<TaskClient>, + shutdown: Option<ShutdownTracker>, user_agent: Option, setup_method: GatewaySetup, @@ -281,7 +280,7 @@ where } #[must_use] - pub fn with_shutdown(mut self, shutdown: TaskClient) -> Self { + pub fn with_shutdown(mut self, shutdown: ShutdownTracker) -> Self { self.shutdown = Some(shutdown); self } @@ -325,11 +324,11 @@ where topology_accessor:
TopologyAccessor, mix_tx: BatchMixMessageSender, stats_tx: ClientStatsSender, - task_client: TaskClient, + shutdown_tracker: &ShutdownTracker, ) { info!("Starting loop cover traffic stream..."); - let stream = LoopCoverTrafficStream::new( + let mut stream = LoopCoverTrafficStream::new( ack_key, debug_config.acknowledgements.average_ack_delay, mix_tx, @@ -338,10 +337,9 @@ where debug_config.traffic, debug_config.cover_traffic, stats_tx, - task_client, ); - - stream.start(); + shutdown_tracker + .try_spawn_named_with_shutdown(async move { stream.run().await }, "CoverTrafficStream"); } #[allow(clippy::too_many_arguments)] @@ -357,13 +355,12 @@ where reply_controller_receiver: ReplyControllerReceiver, lane_queue_lengths: LaneQueueLengths, client_connection_rx: ConnectionCommandReceiver, - task_client: TaskClient, - packet_type: PacketType, stats_tx: ClientStatsSender, + shutdown_tracker: &ShutdownTracker, ) { info!("Starting real traffic stream..."); - RealMessagesController::new( + let real_messages_controller = RealMessagesController::new( controller_config, key_rotation_config, ack_receiver, @@ -376,9 +373,63 @@ where lane_queue_lengths, client_connection_rx, stats_tx, - task_client, - ) - .start(packet_type); + shutdown_tracker.clone_shutdown_token(), + ); + + // break out all the subtasks + let (mut out_queue_control, mut reply_controller, ack_controller) = + real_messages_controller.into_tasks(); + let ( + mut ack_listener, + mut input_listener, + mut retransmission_listener, + mut sent_notification_listener, + mut ack_action_controller, + ) = ack_controller.into_tasks(); + + shutdown_tracker.try_spawn_named( + async move { out_queue_control.run().await }, + "RealMessagesController::OutQueueControl", + ); + + let shutdown_token = shutdown_tracker.clone_shutdown_token(); + shutdown_tracker.try_spawn_named( + async move { reply_controller.run(shutdown_token).await }, + "RealMessagesController::ReplyController", + ); + + let shutdown_token = shutdown_tracker.clone_shutdown_token(); + shutdown_tracker.try_spawn_named( + async move { ack_listener.run(shutdown_token).await }, + "AcknowledgementController::AcknowledgementListener", + ); + + let shutdown_token = shutdown_tracker.clone_shutdown_token(); + shutdown_tracker.try_spawn_named( + async move { input_listener.run(shutdown_token).await }, + "AcknowledgementController::InputMessageListener", + ); + + let shutdown_token = shutdown_tracker.clone_shutdown_token(); + shutdown_tracker.try_spawn_named( + async move { retransmission_listener.run(shutdown_token).await }, + "AcknowledgementController::RetransmissionRequestListener", + ); + + shutdown_tracker.try_spawn_named_with_shutdown( + async move { + sent_notification_listener.run().await; + }, + "AcknowledgementController::SentNotificationListener", + ); + + let shutdown_token = shutdown_tracker.clone_shutdown_token(); + shutdown_tracker.try_spawn_named( + async move { ack_action_controller.run(shutdown_token).await }, + "AcknowledgementController::ActionController", + ); + + // .start(packet_type); } // buffer controlling all messages fetched from provider @@ -389,21 +440,29 @@ where mixnet_receiver: MixnetMessageReceiver, reply_key_storage: SentReplyKeys, reply_controller_sender: ReplyControllerSender, - shutdown: TaskClient, metrics_reporter: ClientStatsSender, + shutdown_tracker: &ShutdownTracker, ) { info!("Starting received messages buffer controller..."); - let controller: ReceivedMessagesBufferController = - ReceivedMessagesBufferController::new( - local_encryption_keypair, - 
query_receiver, - mixnet_receiver, - reply_key_storage, - reply_controller_sender, - metrics_reporter, - shutdown, - ); - controller.start() + let controller = ReceivedMessagesBufferController::::new( + local_encryption_keypair, + query_receiver, + mixnet_receiver, + reply_key_storage, + reply_controller_sender, + metrics_reporter, + shutdown_tracker.clone_shutdown_token(), + ); + let (mut msg_receiver, mut req_receiver) = controller.into_tasks(); + + shutdown_tracker.try_spawn_named( + async move { msg_receiver.run().await }, + "ReceivedMessagesBufferController::FragmentedMessageReceiver", + ); + shutdown_tracker.try_spawn_named( + async move { req_receiver.run().await }, + "ReceivedMessagesBufferController::RequestReceiver", + ); } #[allow(clippy::too_many_arguments)] @@ -415,7 +474,7 @@ where packet_router: PacketRouter, stats_reporter: ClientStatsSender, #[cfg(unix)] connection_fd_callback: Option>, - shutdown: TaskClient, + shutdown_tracker: &ShutdownTracker, ) -> Result, ClientCoreError> where ::StorageError: Send + Sync + 'static, @@ -434,7 +493,7 @@ where packet_router, bandwidth_controller, stats_reporter, - shutdown, + shutdown_tracker.clone_shutdown_token(), ) } else { let cfg = GatewayConfig::new( @@ -459,7 +518,7 @@ where stats_reporter, #[cfg(unix)] connection_fd_callback, - shutdown, + shutdown_tracker.clone_shutdown_token(), ) }; @@ -522,7 +581,7 @@ where packet_router: PacketRouter, stats_reporter: ClientStatsSender, #[cfg(unix)] connection_fd_callback: Option>, - mut shutdown: TaskClient, + shutdown_tracker: &ShutdownTracker, ) -> Result, ClientCoreError> where ::StorageError: Send + Sync + 'static, @@ -539,7 +598,6 @@ where Err(ClientCoreError::CustomGatewaySelectionExpected) } else { // and make sure to invalidate the task client, so we wouldn't cause premature shutdown - shutdown.disarm(); custom_gateway_transceiver.set_packet_router(packet_router)?; Ok(custom_gateway_transceiver) }; @@ -555,7 +613,7 @@ where stats_reporter, #[cfg(unix)] connection_fd_callback, - shutdown, + shutdown_tracker, ) .await?; @@ -586,22 +644,20 @@ where topology_accessor: TopologyAccessor, local_gateway: NodeIdentity, wait_for_gateway: bool, - mut task_client: TaskClient, + shutdown_tracker: &ShutdownTracker, ) -> Result<(), ClientCoreError> { let topology_refresher_config = TopologyRefresherConfig::new(topology_config.topology_refresh_rate); if topology_config.disable_refreshing { // if we're not spawning the refresher, don't cause shutdown immediately - info!("The background topology refesher is not going to be started"); - task_client.disarm(); + info!("The background topology refresher is not going to be started"); } let mut topology_refresher = TopologyRefresher::new( topology_refresher_config, topology_accessor, topology_provider, - task_client, ); // before returning, block entire runtime to refresh the current network view so that any // components depending on topology would see a non-empty view @@ -646,7 +702,10 @@ where // don't spawn the refresher if we don't want to be refreshing the topology. 
// only use the initial values obtained info!("Starting topology refresher..."); - topology_refresher.start(); + shutdown_tracker.try_spawn_named_with_shutdown( + async move { topology_refresher.run().await }, + "TopologyRefresher", + ); } Ok(()) @@ -657,7 +716,7 @@ where user_agent: Option, client_stats_id: String, input_sender: Sender, - task_client: TaskClient, + shutdown_tracker: &ShutdownTracker, ) -> ClientStatsSender { info!("Starting statistics control..."); StatisticsControl::create_and_start( @@ -667,18 +726,23 @@ where .unwrap_or("unknown".to_string()), client_stats_id, input_sender.clone(), - task_client, + shutdown_tracker, ) } fn start_mix_traffic_controller( gateway_transceiver: Box, - shutdown: TaskClient, + shutdown_tracker: &ShutdownTracker, ) -> (BatchMixMessageSender, ClientRequestSender) { info!("Starting mix traffic controller..."); - let (mix_traffic_controller, mix_tx, client_tx) = - MixTrafficController::new(gateway_transceiver, shutdown); - mix_traffic_controller.start(); + let (mut mix_traffic_controller, mix_tx, client_tx) = + MixTrafficController::new(gateway_transceiver, shutdown_tracker.clone_shutdown_token()); + + shutdown_tracker.try_spawn_named( + async move { mix_traffic_controller.run().await }, + "MixTrafficController", + ); + (mix_tx, client_tx) } @@ -686,7 +750,7 @@ where async fn setup_persistent_reply_storage( backend: S::ReplyStore, key_rotation_config: KeyRotationConfig, - shutdown: TaskClient, + shutdown_tracker: &ShutdownTracker, ) -> Result where ::StorageError: Sync + Send, @@ -711,13 +775,14 @@ where })?; let store_clone = mem_store.clone(); - spawn_future!( + let shutdown_token = shutdown_tracker.clone_shutdown_token(); + shutdown_tracker.try_spawn_named( async move { persistent_storage - .flush_on_shutdown(store_clone, shutdown) + .flush_on_shutdown(store_clone, shutdown_token) .await }, - "PersistentReplyStorage::flush_on_shutdown" + "PersistentReplyStorage::flush_on_shutdown", ); Ok(mem_store) @@ -809,11 +874,7 @@ where TopologyAccessor::new(self.config.debug.topology.ignore_egress_epoch_role); // Shutdown notifier for signalling tasks to stop - let shutdown = self - .shutdown - .map(Into::::into) - .unwrap_or_default() - .name_if_unnamed("BaseNymClient"); + let shutdown = ShutdownHelper::new(self.shutdown)?; // channels responsible for dealing with reply-related fun let (reply_controller_sender, reply_controller_receiver) = @@ -845,7 +906,7 @@ where self.user_agent.clone(), generate_client_stats_id(*self_address.identity()), input_sender.clone(), - shutdown.fork("statistics_control"), + shutdown.tracker(), ); // needs to be started as the first thing to block if required waiting for the gateway @@ -855,14 +916,14 @@ where shared_topology_accessor.clone(), self_address.gateway(), self.wait_for_gateway, - shutdown.fork("topology_refresher"), + shutdown.tracker(), ) .await?; let gateway_packet_router = PacketRouter::new( ack_sender, mixnet_messages_sender, - shutdown.get_handle().named("gateway-packet-router"), + shutdown.shutdown_token(), ); let gateway_transceiver = Self::setup_gateway_transceiver( @@ -875,7 +936,7 @@ where stats_reporter.clone(), #[cfg(unix)] self.connection_fd_callback, - shutdown.fork("gateway_transceiver"), + shutdown.tracker(), ) .await?; let gateway_ws_fd = gateway_transceiver.ws_fd(); @@ -883,7 +944,7 @@ where let reply_storage = Self::setup_persistent_reply_storage( reply_storage_backend, key_rotation_config, - shutdown.fork("persistent_reply_storage"), + shutdown.tracker(), ) .await?; @@ -893,8 +954,8 @@ where 
mixnet_messages_receiver, reply_storage.key_storage(), reply_controller_sender.clone(), - shutdown.fork("received_messages_buffer"), stats_reporter.clone(), + shutdown.tracker(), ); // The message_sender is the transmitter for any component generating sphinx packets @@ -902,10 +963,8 @@ where // traffic stream. // The MixTrafficController then sends the actual traffic - let (message_sender, client_request_sender) = Self::start_mix_traffic_controller( - gateway_transceiver, - shutdown.fork("mix_traffic_controller"), - ); + let (message_sender, client_request_sender) = + Self::start_mix_traffic_controller(gateway_transceiver, shutdown.tracker()); // Channels that the websocket listener can use to signal downstream to the real traffic // controller that connections are closed. @@ -933,9 +992,8 @@ where reply_controller_receiver, shared_lane_queue_lengths.clone(), client_connection_rx, - shutdown.fork("real_traffic_controller"), - self.config.debug.traffic.packet_type, stats_reporter.clone(), + shutdown.tracker(), ); if !self @@ -951,7 +1009,7 @@ where shared_topology_accessor.clone(), message_sender, stats_reporter.clone(), - shutdown.fork("cover_traffic_stream"), + shutdown.tracker(), ); } @@ -979,7 +1037,7 @@ where gateway_connection: GatewayConnection { gateway_ws_fd }, }, stats_reporter, - task_handle: shutdown, + shutdown_handle: shutdown.into_internal(), client_request_sender, forget_me: self.config.debug.forget_me, remember_me: self.config.debug.remember_me, @@ -995,7 +1053,7 @@ pub struct BaseClient { pub client_state: ClientState, pub stats_reporter: ClientStatsSender, pub client_request_sender: ClientRequestSender, - pub task_handle: TaskHandle, + pub shutdown_handle: Option, pub forget_me: ForgetMe, pub remember_me: RememberMe, } diff --git a/common/client-core/src/client/cover_traffic_stream.rs b/common/client-core/src/client/cover_traffic_stream.rs index d635fc20c60..d10a6b04dcd 100644 --- a/common/client-core/src/client/cover_traffic_stream.rs +++ b/common/client-core/src/client/cover_traffic_stream.rs @@ -3,7 +3,7 @@ use crate::client::mix_traffic::BatchMixMessageSender; use crate::client::topology_control::TopologyAccessor; -use crate::{config, spawn_future}; +use crate::config; use futures::task::{Context, Poll}; use futures::{Future, Stream, StreamExt}; use nym_sphinx::acknowledgements::AckKey; @@ -12,7 +12,6 @@ use nym_sphinx::cover::generate_loop_cover_packet; use nym_sphinx::params::{PacketSize, PacketType}; use nym_sphinx::utils::sample_poisson_duration; use nym_statistics_common::clients::{packet_statistics::PacketStatisticsEvent, ClientStatsSender}; -use nym_task::TaskClient; use rand::{rngs::OsRng, CryptoRng, Rng}; use std::pin::Pin; use std::sync::Arc; @@ -69,8 +68,6 @@ where packet_type: PacketType, stats_tx: ClientStatsSender, - - task_client: TaskClient, } impl Stream for LoopCoverTrafficStream @@ -117,7 +114,6 @@ impl LoopCoverTrafficStream { traffic_config: config::Traffic, cover_config: config::CoverTraffic, stats_tx: ClientStatsSender, - task_client: TaskClient, ) -> Self { let rng = OsRng; @@ -137,7 +133,6 @@ impl LoopCoverTrafficStream { use_legacy_sphinx_format: traffic_config.use_legacy_sphinx_format, packet_type: traffic_config.packet_type, stats_tx, - task_client, } } @@ -235,12 +230,13 @@ impl LoopCoverTrafficStream { tokio::task::yield_now().await; } + // it's fine if cover traffic stream task gets killed whilst processing next message #[allow(clippy::panic)] - pub fn start(mut self) { + pub async fn run(&mut self) { if 
self.cover_traffic.disable_loop_cover_traffic_stream { // we should have never got here in the first place - the task should have never been created to begin with // so panic and review the code that lead to this branch - panic!("attempted to start LoopCoverTrafficStream while config explicitly disabled it.") + panic!("attempted to run LoopCoverTrafficStream while config explicitly disabled it.") } // we should set initial delay only when we actually start the stream @@ -250,32 +246,11 @@ impl LoopCoverTrafficStream { ); self.set_next_delay(sampled); - let mut shutdown = self.task_client.fork("select"); - - spawn_future!( - async move { - debug!("Started LoopCoverTrafficStream with graceful shutdown support"); - - while !shutdown.is_shutdown() { - tokio::select! { - biased; - _ = shutdown.recv() => { - tracing::trace!("LoopCoverTrafficStream: Received shutdown"); - } - next = self.next() => { - if next.is_some() { - self.on_new_message().await; - } else { - tracing::trace!("LoopCoverTrafficStream: Stopping since channel closed"); - break; - } - } - } - } - shutdown.recv_timeout().await; - tracing::debug!("LoopCoverTrafficStream: Exiting"); - }, - "LoopCoverTrafficStream" - ) + while self.next().await.is_some() { + self.on_new_message().await; + } + + // this should never get triggered + error!("cover traffic stream has been exhausted!") } } diff --git a/common/client-core/src/client/mix_traffic/mod.rs b/common/client-core/src/client/mix_traffic/mod.rs index 7b705a7a3f2..33688e90672 100644 --- a/common/client-core/src/client/mix_traffic/mod.rs +++ b/common/client-core/src/client/mix_traffic/mod.rs @@ -2,11 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 use crate::client::mix_traffic::transceiver::GatewayTransceiver; -use crate::error::ClientCoreError; -use crate::spawn_future; use nym_gateway_requests::ClientRequest; use nym_sphinx::forwarding::packet::MixPacket; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use tracing::*; use transceiver::ErasedGatewayError; @@ -34,13 +32,13 @@ pub struct MixTrafficController { // in long run `gateway_client` will be moved away from `MixTrafficController` anyway. consecutive_gateway_failure_count: usize, - task_client: TaskClient, + shutdown_token: ShutdownToken, } impl MixTrafficController { pub fn new( gateway_transceiver: T, - task_client: TaskClient, + shutdown_token: ShutdownToken, ) -> ( MixTrafficController, BatchMixMessageSender, @@ -60,7 +58,7 @@ impl MixTrafficController { mix_rx: message_receiver, client_rx: client_receiver, consecutive_gateway_failure_count: 0, - task_client, + shutdown_token, }, message_sender, client_sender, @@ -69,7 +67,7 @@ impl MixTrafficController { pub fn new_dynamic( gateway_transceiver: Box, - task_client: TaskClient, + shutdown_token: ShutdownToken, ) -> ( MixTrafficController, BatchMixMessageSender, @@ -84,7 +82,7 @@ impl MixTrafficController { mix_rx: message_receiver, client_rx: client_receiver, consecutive_gateway_failure_count: 0, - task_client, + shutdown_token, }, message_sender, client_sender, @@ -107,7 +105,7 @@ impl MixTrafficController { tokio::select! { biased; - _ = self.task_client.recv() => { + _ = self.shutdown_token.cancelled() => { trace!("received shutdown while handling messages"); Ok(()) } @@ -127,7 +125,7 @@ impl MixTrafficController { async fn on_client_request(&mut self, client_request: ClientRequest) { tokio::select! 
{ biased; - _ = self.task_client.recv() => { + _ = self.shutdown_token.cancelled() => { trace!("received shutdown while handling client request"); } result = self.gateway_transceiver.send_client_request(client_request) => { @@ -138,52 +136,44 @@ impl MixTrafficController { } } - pub fn start(mut self) { - spawn_future!( - async move { - debug!("Started MixTrafficController with graceful shutdown support"); - while !self.task_client.is_shutdown() { - tokio::select! { - biased; - _ = self.task_client.recv() => { - tracing::trace!("MixTrafficController: Received shutdown"); - break; - } - mix_packets = self.mix_rx.recv() => match mix_packets { - Some(mix_packets) => { - if let Err(err) = self.on_messages(mix_packets).await { - error!("Failed to send sphinx packet(s) to the gateway: {err}"); - if self.consecutive_gateway_failure_count == MAX_FAILURE_COUNT { - // Disconnect from the gateway. If we should try to re-connect - // is handled at a higher layer. - error!("Failed to send sphinx packet to the gateway {MAX_FAILURE_COUNT} times in a row - assuming the gateway is dead"); - // Do we need to handle the embedded mixnet client case - // separately? - self.task_client.send_we_stopped(Box::new(ClientCoreError::GatewayFailedToForwardMessages)); - break; - } - } - }, - None => { - tracing::trace!("MixTrafficController: Stopping since channel closed"); + pub async fn run(&mut self) { + debug!("Started MixTrafficController with graceful shutdown support"); + loop { + tokio::select! { + biased; + _ = self.shutdown_token.cancelled() => { + trace!("MixTrafficController: Received shutdown"); + break; + } + mix_packets = self.mix_rx.recv() => match mix_packets { + Some(mix_packets) => { + if let Err(err) = self.on_messages(mix_packets).await { + error!("Failed to send sphinx packet(s) to the gateway: {err}"); + if self.consecutive_gateway_failure_count == MAX_FAILURE_COUNT { + // Disconnect from the gateway. If we should try to re-connect + // is handled at a higher layer. + error!("Failed to send sphinx packet to the gateway {MAX_FAILURE_COUNT} times in a row - assuming the gateway is dead"); + // Do we need to handle the embedded mixnet client case + // separately? 
break; } - }, - client_request = self.client_rx.recv() => match client_request { - Some(client_request) => { - self.on_client_request(client_request).await; - }, - None => { - tracing::trace!("MixTrafficController, client request channel closed"); - break - } - }, + } + }, + None => { + trace!("MixTrafficController: Stopping since channel closed"); + break; } - } - self.task_client.recv_timeout().await; - tracing::debug!("MixTrafficController: Exiting"); - }, - "MixTrafficController" - ); + }, + client_request = self.client_rx.recv() => match client_request { + Some(client_request) => { + self.on_client_request(client_request).await; + }, + None => { + trace!("MixTrafficController, client request channel closed"); + break} + }, + } + } + debug!("MixTrafficController: Exiting"); } } diff --git a/common/client-core/src/client/real_messages_control/acknowledgement_control/acknowledgement_listener.rs b/common/client-core/src/client/real_messages_control/acknowledgement_control/acknowledgement_listener.rs index 7c7000c3b3d..2167718e57a 100644 --- a/common/client-core/src/client/real_messages_control/acknowledgement_control/acknowledgement_listener.rs +++ b/common/client-core/src/client/real_messages_control/acknowledgement_control/acknowledgement_listener.rs @@ -10,18 +10,17 @@ use nym_sphinx::{ acknowledgements::{identifier::recover_identifier, AckKey}, chunking::fragment::{FragmentIdentifier, COVER_FRAG_ID}, }; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use std::sync::Arc; use tracing::*; /// Module responsible for listening for any data resembling acknowledgements from the network /// and firing actions to remove them from the 'Pending' state. -pub(super) struct AcknowledgementListener { +pub(crate) struct AcknowledgementListener { ack_key: Arc, ack_receiver: AcknowledgementReceiver, action_sender: AckActionSender, stats_tx: ClientStatsSender, - task_client: TaskClient, } impl AcknowledgementListener { @@ -30,14 +29,12 @@ impl AcknowledgementListener { ack_receiver: AcknowledgementReceiver, action_sender: AckActionSender, stats_tx: ClientStatsSender, - task_client: TaskClient, ) -> Self { AcknowledgementListener { ack_key, ack_receiver, action_sender, stats_tx, - task_client, } } @@ -68,14 +65,9 @@ impl AcknowledgementListener { trace!("Received {frag_id} from the mix network"); self.stats_tx .report(PacketStatisticsEvent::RealAckReceived(ack_content.len()).into()); - if let Err(err) = self + let _ = self .action_sender - .unbounded_send(Action::new_remove(frag_id)) - { - if !self.task_client.is_shutdown_poll() { - error!("Failed to send remove action to action controller: {err}"); - } - } + .unbounded_send(Action::new_remove(frag_id)); } async fn handle_ack_receiver_item(&mut self, item: Vec>) { @@ -85,11 +77,16 @@ impl AcknowledgementListener { } } - pub(super) async fn run(&mut self) { + pub(crate) async fn run(&mut self, shutdown_token: ShutdownToken) { debug!("Started AcknowledgementListener with graceful shutdown support"); - while !self.task_client.is_shutdown() { + loop { tokio::select! 
{ + biased; + _ = shutdown_token.cancelled() => { + tracing::trace!("AcknowledgementListener: Received shutdown"); + break; + } acks = self.ack_receiver.next() => match acks { Some(acks) => self.handle_ack_receiver_item(acks).await, None => { @@ -97,12 +94,9 @@ impl AcknowledgementListener { break; } }, - _ = self.task_client.recv() => { - tracing::trace!("AcknowledgementListener: Received shutdown"); - } + } } - self.task_client.recv_timeout().await; tracing::debug!("AcknowledgementListener: Exiting"); } } diff --git a/common/client-core/src/client/real_messages_control/acknowledgement_control/action_controller.rs b/common/client-core/src/client/real_messages_control/acknowledgement_control/action_controller.rs index 350b400e469..6262a37e234 100644 --- a/common/client-core/src/client/real_messages_control/acknowledgement_control/action_controller.rs +++ b/common/client-core/src/client/real_messages_control/acknowledgement_control/action_controller.rs @@ -8,7 +8,7 @@ use futures::StreamExt; use nym_nonexhaustive_delayqueue::{Expired, NonExhaustiveDelayQueue, QueueKey}; use nym_sphinx::chunking::fragment::FragmentIdentifier; use nym_sphinx::Delay as SphinxDelay; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -82,7 +82,7 @@ impl Config { } } -pub(super) struct ActionController { +pub(crate) struct ActionController { /// Configurable parameters of the `ActionController` config: Config, @@ -102,8 +102,6 @@ pub(super) struct ActionController { /// Channel for notifying `RetransmissionRequestListener` about expired acknowledgements. retransmission_sender: RetransmissionRequestSender, - - task_client: TaskClient, } impl ActionController { @@ -111,7 +109,6 @@ impl ActionController { config: Config, retransmission_sender: RetransmissionRequestSender, incoming_actions: AckActionReceiver, - task_client: TaskClient, ) -> Self { ActionController { config, @@ -119,7 +116,6 @@ impl ActionController { pending_acks_timers: NonExhaustiveDelayQueue::new(), incoming_actions, retransmission_sender, - task_client, } } @@ -226,14 +222,9 @@ impl ActionController { // downgrading an arc and then upgrading vs cloning is difference of 30ns vs 15ns // so it's literally a NO difference while it might prevent us from unnecessarily // resending data (in maybe 1 in 1 million cases, but it's something) - if let Err(err) = self + let _ = self .retransmission_sender - .unbounded_send(Arc::downgrade(pending_ack_data)) - { - if !self.task_client.is_shutdown_poll() { - tracing::error!("Failed to send pending ack for retransmission: {err}"); - } - } + .unbounded_send(Arc::downgrade(pending_ack_data)); } else { // this shouldn't cause any issues but shouldn't have happened to begin with! error!("An already removed pending ack has expired") @@ -251,11 +242,16 @@ impl ActionController { } } - pub(super) async fn run(&mut self) { + pub(crate) async fn run(&mut self, shutdown_token: ShutdownToken) { debug!("Started ActionController with graceful shutdown support"); - while !self.task_client.is_shutdown() { + loop { tokio::select! 
{ + biased; + _ = shutdown_token.cancelled() => { + tracing::trace!("ActionController: Received shutdown"); + break; + } action = self.incoming_actions.next() => match action { Some(action) => self.process_action(action), None => { @@ -272,13 +268,8 @@ impl ActionController { break; } }, - _ = self.task_client.recv() => { - tracing::trace!("ActionController: Received shutdown"); - break; - } } } - self.task_client.recv_timeout().await; tracing::debug!("ActionController: Exiting"); } } diff --git a/common/client-core/src/client/real_messages_control/acknowledgement_control/input_message_listener.rs b/common/client-core/src/client/real_messages_control/acknowledgement_control/input_message_listener.rs index f47f050b3b4..69ca92709fc 100644 --- a/common/client-core/src/client/real_messages_control/acknowledgement_control/input_message_listener.rs +++ b/common/client-core/src/client/real_messages_control/acknowledgement_control/input_message_listener.rs @@ -10,21 +10,20 @@ use nym_sphinx::anonymous_replies::requests::AnonymousSenderTag; use nym_sphinx::forwarding::packet::MixPacket; use nym_sphinx::params::PacketType; use nym_task::connections::TransmissionLane; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use rand::{CryptoRng, Rng}; use tracing::*; /// Module responsible for dealing with the received messages: splitting them, creating acknowledgements, /// putting everything into sphinx packets, etc. /// It also makes an initial sending attempt for said messages. -pub(super) struct InputMessageListener +pub(crate) struct InputMessageListener where R: CryptoRng + Rng, { input_receiver: InputMessageReceiver, message_handler: MessageHandler, reply_controller_sender: ReplyControllerSender, - task_client: TaskClient, } impl InputMessageListener @@ -38,13 +37,11 @@ where input_receiver: InputMessageReceiver, message_handler: MessageHandler, reply_controller_sender: ReplyControllerSender, - task_client: TaskClient, ) -> Self { InputMessageListener { input_receiver, message_handler, reply_controller_sender, - task_client, } } @@ -68,14 +65,9 @@ where max_retransmissions: Option, ) { // offload reply handling to the dedicated task - if let Err(err) = + let _ = self.reply_controller_sender - .send_reply(recipient_tag, data, lane, max_retransmissions) - { - if !self.task_client.is_shutdown_poll() { - error!("failed to send a reply - {err}"); - } - } + .send_reply(recipient_tag, data, lane, max_retransmissions); } async fn handle_plain_message( @@ -221,13 +213,13 @@ where }; } - pub(super) async fn run(&mut self) { + pub(crate) async fn run(&mut self, shutdown_token: ShutdownToken) { debug!("Started InputMessageListener with graceful shutdown support"); - while !self.task_client.is_shutdown() { + loop { tokio::select! 
{ biased; - _ = self.task_client.recv() => { + _ = shutdown_token.cancelled() => { tracing::trace!("InputMessageListener: Received shutdown"); break; } @@ -243,7 +235,6 @@ where } } - self.task_client.recv_timeout().await; tracing::debug!("InputMessageListener: Exiting"); } } diff --git a/common/client-core/src/client/real_messages_control/acknowledgement_control/mod.rs b/common/client-core/src/client/real_messages_control/acknowledgement_control/mod.rs index ea332ebdaf4..fd17348ee95 100644 --- a/common/client-core/src/client/real_messages_control/acknowledgement_control/mod.rs +++ b/common/client-core/src/client/real_messages_control/acknowledgement_control/mod.rs @@ -10,7 +10,6 @@ use self::{ use crate::client::inbound_messages::InputMessageReceiver; use crate::client::real_messages_control::message_handler::MessageHandler; use crate::client::replies::reply_controller::ReplyControllerSender; -use crate::spawn_future; use action_controller::AckActionReceiver; use futures::channel::mpsc; use nym_gateway_client::AcknowledgementReceiver; @@ -23,13 +22,11 @@ use nym_sphinx::{ Delay as SphinxDelay, }; use nym_statistics_common::clients::ClientStatsSender; -use nym_task::TaskClient; use rand::{CryptoRng, Rng}; use std::{ sync::{Arc, Weak}, time::Duration, }; -use tracing::*; pub(crate) use action_controller::{AckActionSender, Action}; @@ -190,6 +187,9 @@ pub(super) struct Config { /// Predefined packet size used for the encapsulated messages. packet_size: PacketSize, + + /// Type of packets used for retransmissions + packet_type: PacketType, } impl Config { @@ -197,12 +197,14 @@ impl Config { maximum_retransmissions: Option, ack_wait_addition: Duration, ack_wait_multiplier: f64, + packet_type: PacketType, ) -> Self { Config { maximum_retransmissions, ack_wait_addition, ack_wait_multiplier, packet_size: Default::default(), + packet_type, } } @@ -212,7 +214,7 @@ impl Config { } } -pub(super) struct AcknowledgementController +pub(crate) struct AcknowledgementController where R: CryptoRng + Rng, { @@ -234,7 +236,6 @@ where message_handler: MessageHandler, reply_controller_sender: ReplyControllerSender, stats_tx: ClientStatsSender, - task_client: TaskClient, ) -> Self { let (retransmission_tx, retransmission_rx) = mpsc::unbounded(); @@ -244,7 +245,6 @@ where action_config, retransmission_tx, connectors.ack_action_receiver, - task_client.fork("action_controller"), ); // will listen for any acks coming from the network @@ -253,7 +253,6 @@ where connectors.ack_receiver, connectors.ack_action_sender.clone(), stats_tx, - task_client.fork("acknowledgement_listener"), ); // will listen for any new messages from the client @@ -261,7 +260,6 @@ where connectors.input_receiver, message_handler.clone(), reply_controller_sender.clone(), - task_client.fork("input_message_listener"), ); // will listen for any ack timeouts and trigger retransmission @@ -271,16 +269,13 @@ where message_handler, retransmission_rx, reply_controller_sender, - task_client.fork("retransmission_request_listener"), + config.packet_type, ); // will listen for events indicating the packet was sent through the network so that // the retransmission timer should be started. 
- let sent_notification_listener = SentNotificationListener::new( - connectors.sent_notifier, - connectors.ack_action_sender, - task_client.with_suffix("sent_notification_listener"), - ); + let sent_notification_listener = + SentNotificationListener::new(connectors.sent_notifier, connectors.ack_action_sender); AcknowledgementController { acknowledgement_listener, @@ -291,51 +286,21 @@ where } } - pub(super) fn start(self, packet_type: PacketType) { - let mut acknowledgement_listener = self.acknowledgement_listener; - let mut input_message_listener = self.input_message_listener; - let mut retransmission_request_listener = self.retransmission_request_listener; - let mut sent_notification_listener = self.sent_notification_listener; - let mut action_controller = self.action_controller; - - spawn_future!( - async move { - acknowledgement_listener.run().await; - debug!("The acknowledgement listener has finished execution!"); - }, - "AcknowledgementController::AcknowledgementListener" - ); - - spawn_future!( - async move { - input_message_listener.run().await; - debug!("The input listener has finished execution!"); - }, - "AcknowledgementController::InputMessageListener" - ); - - spawn_future!( - async move { - retransmission_request_listener.run(packet_type).await; - debug!("The retransmission request listener has finished execution!"); - }, - "AcknowledgementController::RetransmissionRequestListener" - ); - - spawn_future!( - async move { - sent_notification_listener.run().await; - debug!("The sent notification listener has finished execution!"); - }, - "AcknowledgementController::SentNotificationListener" - ); - - spawn_future!( - async move { - action_controller.run().await; - debug!("The controller has finished execution!"); - }, - "AcknowledgementController::ActionController" - ); + pub(crate) fn into_tasks( + self, + ) -> ( + AcknowledgementListener, + InputMessageListener, + RetransmissionRequestListener, + SentNotificationListener, + ActionController, + ) { + ( + self.acknowledgement_listener, + self.input_message_listener, + self.retransmission_request_listener, + self.sent_notification_listener, + self.action_controller, + ) } } diff --git a/common/client-core/src/client/real_messages_control/acknowledgement_control/retransmission_request_listener.rs b/common/client-core/src/client/real_messages_control/acknowledgement_control/retransmission_request_listener.rs index 0b066d0cc0b..597bfecf567 100644 --- a/common/client-core/src/client/real_messages_control/acknowledgement_control/retransmission_request_listener.rs +++ b/common/client-core/src/client/real_messages_control/acknowledgement_control/retransmission_request_listener.rs @@ -13,19 +13,19 @@ use futures::StreamExt; use nym_sphinx::chunking::fragment::Fragment; use nym_sphinx::preparer::PreparedFragment; use nym_sphinx::{addressing::clients::Recipient, params::PacketType}; -use nym_task::{connections::TransmissionLane, TaskClient}; +use nym_task::{connections::TransmissionLane, ShutdownToken}; use rand::{CryptoRng, Rng}; use std::sync::{Arc, Weak}; use tracing::*; // responsible for packet retransmission upon fired timer -pub(super) struct RetransmissionRequestListener { +pub(crate) struct RetransmissionRequestListener { maximum_retransmissions: Option, action_sender: AckActionSender, message_handler: MessageHandler, request_receiver: RetransmissionRequestReceiver, reply_controller_sender: ReplyControllerSender, - task_client: TaskClient, + packet_type: PacketType, } impl RetransmissionRequestListener @@ -38,7 +38,7 @@ where 
message_handler: MessageHandler, request_receiver: RetransmissionRequestReceiver, reply_controller_sender: ReplyControllerSender, - task_client: TaskClient, + packet_type: PacketType, ) -> Self { RetransmissionRequestListener { maximum_retransmissions, @@ -46,7 +46,7 @@ where message_handler, request_receiver, reply_controller_sender, - task_client, + packet_type, } } @@ -67,7 +67,6 @@ where async fn on_retransmission_request( &mut self, weak_timed_out_ack: Weak, - packet_type: PacketType, ) { let timed_out_ack = match weak_timed_out_ack.upgrade() { Some(timed_out_ack) => timed_out_ack, @@ -97,22 +96,18 @@ where } => { // if this is retransmission for reply, offload it to the dedicated task // that deals with all the surbs - if let Err(err) = self.reply_controller_sender.send_retransmission_data( + let _ = self.reply_controller_sender.send_retransmission_data( *recipient_tag, weak_timed_out_ack, *extra_surb_request, - ) { - if !self.task_client.is_shutdown_poll() { - error!("Failed to send retransmission data to the reply controller: {err}"); - } - } + ); return; } PacketDestination::KnownRecipient(recipient) => { self.prepare_normal_retransmission_chunk( **recipient, timed_out_ack.message_chunk.clone(), - packet_type, + self.packet_type, ) .await } @@ -153,14 +148,9 @@ where // is sent to the `OutQueueControl` and has gone through its internal queue // with the additional poisson delay. // And since Actions are executed in order `UpdateTimer` will HAVE TO be executed before `StartTimer` - if let Err(err) = self + let _ = self .action_sender - .unbounded_send(Action::new_update_pending_ack(frag_id, new_delay)) - { - if !self.task_client.is_shutdown_poll() { - error!("Failed to send update pending ack action to the controller: {err}"); - } - } + .unbounded_send(Action::new_update_pending_ack(frag_id, new_delay)); // send to `OutQueueControl` to eventually send to the mix network self.message_handler @@ -174,18 +164,18 @@ where .await } - pub(super) async fn run(&mut self, packet_type: PacketType) { + pub(crate) async fn run(&mut self, shutdown_token: ShutdownToken) { debug!("Started RetransmissionRequestListener with graceful shutdown support"); - while !self.task_client.is_shutdown() { + loop { tokio::select! 
{ biased; - _ = self.task_client.recv() => { + _ = shutdown_token.cancelled() => { tracing::trace!("RetransmissionRequestListener: Received shutdown"); break; } timed_out_ack = self.request_receiver.next() => match timed_out_ack { - Some(timed_out_ack) => self.on_retransmission_request(timed_out_ack, packet_type).await, + Some(timed_out_ack) => self.on_retransmission_request(timed_out_ack).await, None => { tracing::trace!("RetransmissionRequestListener: Stopping since channel closed"); break; @@ -194,7 +184,6 @@ where } } - self.task_client.recv_timeout().await; tracing::debug!("RetransmissionRequestListener: Exiting"); } } diff --git a/common/client-core/src/client/real_messages_control/acknowledgement_control/sent_notification_listener.rs b/common/client-core/src/client/real_messages_control/acknowledgement_control/sent_notification_listener.rs index 9ee9cdf4614..02560805a8e 100644 --- a/common/client-core/src/client/real_messages_control/acknowledgement_control/sent_notification_listener.rs +++ b/common/client-core/src/client/real_messages_control/acknowledgement_control/sent_notification_listener.rs @@ -5,29 +5,25 @@ use super::action_controller::{AckActionSender, Action}; use super::SentPacketNotificationReceiver; use futures::StreamExt; use nym_sphinx::chunking::fragment::{FragmentIdentifier, COVER_FRAG_ID}; -use nym_task::TaskClient; use tracing::*; /// Module responsible for starting up retransmission timers. /// It is required because when we send our packet to the `real traffic stream` controlled /// by a poisson timer, there's no guarantee the message will be sent immediately, so we might /// accidentally fire retransmission way quicker than we should have. -pub(super) struct SentNotificationListener { +pub(crate) struct SentNotificationListener { sent_notifier: SentPacketNotificationReceiver, action_sender: AckActionSender, - task_client: TaskClient, } impl SentNotificationListener { pub(super) fn new( sent_notifier: SentPacketNotificationReceiver, action_sender: AckActionSender, - task_client: TaskClient, ) -> Self { SentNotificationListener { sent_notifier, action_sender, - task_client, } } @@ -36,37 +32,18 @@ impl SentNotificationListener { trace!("sent off a cover message - no need to start retransmission timer!"); return; } - if let Err(err) = self + let _ = self .action_sender - .unbounded_send(Action::new_start_timer(frag_id)) - { - if !self.task_client.is_shutdown_poll() { - error!("Failed to send start timer action to action controller: {err}"); - } - } + .unbounded_send(Action::new_start_timer(frag_id)); } - pub(super) async fn run(&mut self) { + pub(crate) async fn run(&mut self) { debug!("Started SentNotificationListener with graceful shutdown support"); - while !self.task_client.is_shutdown() { - tokio::select! 
{ - frag_id = self.sent_notifier.next() => match frag_id { - Some(frag_id) => { - self.on_sent_message(frag_id).await; - } - None => { - tracing::trace!("SentNotificationListener: Stopping since channel closed"); - break; - } - }, - _ = self.task_client.recv() => { - tracing::trace!("SentNotificationListener: Received shutdown"); - break; - } - } + while let Some(frag_id) = self.sent_notifier.next().await { + self.on_sent_message(frag_id).await; } - assert!(self.task_client.is_shutdown_poll()); + tracing::debug!("SentNotificationListener: Exiting"); } } diff --git a/common/client-core/src/client/real_messages_control/message_handler.rs b/common/client-core/src/client/real_messages_control/message_handler.rs index 1dca40cf7e7..366b2660901 100644 --- a/common/client-core/src/client/real_messages_control/message_handler.rs +++ b/common/client-core/src/client/real_messages_control/message_handler.rs @@ -20,7 +20,7 @@ use nym_sphinx::params::{PacketSize, PacketType}; use nym_sphinx::preparer::{MessagePreparer, PreparedFragment}; use nym_sphinx::Delay; use nym_task::connections::TransmissionLane; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use nym_topology::{NymRouteProvider, NymTopologyError}; use rand::{CryptoRng, Rng}; use std::collections::HashMap; @@ -189,7 +189,7 @@ pub(crate) struct MessageHandler { topology_access: TopologyAccessor, reply_key_storage: SentReplyKeys, tag_storage: UsedSenderTags, - task_client: TaskClient, + shutdown_token: ShutdownToken, } impl MessageHandler @@ -205,7 +205,7 @@ where topology_access: TopologyAccessor, reply_key_storage: SentReplyKeys, tag_storage: UsedSenderTags, - task_client: TaskClient, + shutdown_token: ShutdownToken, ) -> Self where R: Copy, @@ -228,7 +228,7 @@ where topology_access, reply_key_storage, tag_storage, - task_client, + shutdown_token, } } @@ -712,7 +712,7 @@ where .action_sender .unbounded_send(Action::UpdatePendingAck(id, new_delay)) { - if !self.task_client.is_shutdown_poll() { + if !self.shutdown_token.is_cancelled() { error!("Failed to send update action to the controller: {err}"); } } @@ -723,7 +723,7 @@ where .action_sender .unbounded_send(Action::new_insert(pending_acks)) { - if !self.task_client.is_shutdown_poll() { + if !self.shutdown_token.is_cancelled() { error!("Failed to send insert action to the controller: {err}"); } } @@ -737,7 +737,7 @@ where ) { tokio::select! 
{ biased; - _ = self.task_client.recv() => { + _ = self.shutdown_token.cancelled() => { trace!("received shutdown while attempting to forward mixnet messages"); } sending_res = self.real_message_sender.send((messages, transmission_lane)) => { diff --git a/common/client-core/src/client/real_messages_control/mod.rs b/common/client-core/src/client/real_messages_control/mod.rs index b169a9bff8e..9b852535fa9 100644 --- a/common/client-core/src/client/real_messages_control/mod.rs +++ b/common/client-core/src/client/real_messages_control/mod.rs @@ -14,26 +14,21 @@ use crate::client::replies::reply_controller::{ ReplyController, ReplyControllerReceiver, ReplyControllerSender, }; use crate::client::replies::reply_storage::CombinedReplyStorage; -use crate::config; -use crate::{ - client::{ - inbound_messages::InputMessageReceiver, mix_traffic::BatchMixMessageSender, - real_messages_control::acknowledgement_control::AcknowledgementControllerConnectors, - topology_control::TopologyAccessor, - }, - spawn_future, +use crate::client::{ + inbound_messages::InputMessageReceiver, mix_traffic::BatchMixMessageSender, + real_messages_control::acknowledgement_control::AcknowledgementControllerConnectors, + topology_control::TopologyAccessor, }; +use crate::config; use futures::channel::mpsc; use nym_gateway_client::AcknowledgementReceiver; use nym_sphinx::acknowledgements::AckKey; use nym_sphinx::addressing::clients::Recipient; -use nym_sphinx::params::PacketType; use nym_statistics_common::clients::ClientStatsSender; use nym_task::connections::{ConnectionCommandReceiver, LaneQueueLengths}; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use rand::{rngs::OsRng, CryptoRng, Rng}; use std::sync::Arc; -use tracing::*; use crate::client::replies::reply_controller::key_rotation_helpers::KeyRotationConfig; pub(crate) use acknowledgement_control::{AckActionSender, Action}; @@ -69,6 +64,7 @@ impl<'a> From<&'a Config> for acknowledgement_control::Config { cfg.traffic.maximum_number_of_retransmissions, cfg.acks.ack_wait_addition, cfg.acks.ack_wait_multiplier, + cfg.traffic.packet_type, ) .with_custom_packet_size(cfg.traffic.primary_packet_size) } @@ -146,7 +142,7 @@ impl RealMessagesController { lane_queue_lengths: LaneQueueLengths, client_connection_rx: ConnectionCommandReceiver, stats_tx: ClientStatsSender, - task_client: TaskClient, + shutdown_token: ShutdownToken, ) -> Self { let rng = OsRng; @@ -178,7 +174,7 @@ impl RealMessagesController { topology_access.clone(), reply_storage.key_storage(), reply_storage.tags_storage(), - task_client.fork("message_handler"), + shutdown_token.clone(), ); let ack_control = AcknowledgementController::new( @@ -188,7 +184,6 @@ impl RealMessagesController { message_handler.clone(), reply_controller_sender, stats_tx.clone(), - task_client.fork("ack_control"), ); let reply_control = ReplyController::new( @@ -196,7 +191,6 @@ impl RealMessagesController { message_handler, reply_storage, reply_controller_receiver, - task_client.fork("reply_controller"), ); let out_queue_control = OutQueueControl::new( @@ -209,7 +203,7 @@ impl RealMessagesController { lane_queue_lengths, client_connection_rx, stats_tx, - task_client.with_suffix("out_queue_control"), + shutdown_token.clone(), ); RealMessagesController { @@ -219,26 +213,13 @@ impl RealMessagesController { } } - pub fn start(self, packet_type: PacketType) { - let mut out_queue_control = self.out_queue_control; - let ack_control = self.ack_control; - let mut reply_control = self.reply_control; - - spawn_future!( - async move { - 
out_queue_control.run().await; - debug!("The out queue controller has finished execution!"); - }, - "RealMessagesController::OutQueueControl)" - ); - spawn_future!( - async move { - reply_control.run().await; - debug!("The reply controller has finished execution!"); - }, - "RealMessagesController::ReplyController" - ); - - ack_control.start(packet_type); + pub fn into_tasks( + self, + ) -> ( + OutQueueControl, + ReplyController, + AcknowledgementController, + ) { + (self.out_queue_control, self.reply_control, self.ack_control) } } diff --git a/common/client-core/src/client/real_messages_control/real_traffic_stream.rs b/common/client-core/src/client/real_messages_control/real_traffic_stream.rs index 7cbf3b6ad65..92faffdbf23 100644 --- a/common/client-core/src/client/real_messages_control/real_traffic_stream.rs +++ b/common/client-core/src/client/real_messages_control/real_traffic_stream.rs @@ -21,7 +21,7 @@ use nym_statistics_common::clients::{packet_statistics::PacketStatisticsEvent, C use nym_task::connections::{ ConnectionCommand, ConnectionCommandReceiver, ConnectionId, LaneQueueLengths, TransmissionLane, }; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use rand::{CryptoRng, Rng}; use std::pin::Pin; use std::sync::Arc; @@ -119,7 +119,7 @@ where /// Channel used for sending metrics events (specifically `PacketStatistics` events) to the metrics tracker. stats_tx: ClientStatsSender, - task_client: TaskClient, + shutdown_token: ShutdownToken, } #[derive(Debug)] @@ -179,7 +179,7 @@ where lane_queue_lengths: LaneQueueLengths, client_connection_rx: ConnectionCommandReceiver, stats_tx: ClientStatsSender, - task_client: TaskClient, + shutdown_token: ShutdownToken, ) -> Self { OutQueueControl { config, @@ -194,7 +194,7 @@ where client_connection_rx, lane_queue_lengths, stats_tx, - task_client, + shutdown_token, } } @@ -282,7 +282,7 @@ where let sending_res = tokio::select! 
{ biased; - _ = self.task_client.recv() => { + _ = self.shutdown_token.cancelled() => { trace!("received shutdown signal while attempting to send mix message"); return } @@ -293,7 +293,7 @@ where match sending_res { Err(_) => { - if !self.task_client.is_shutdown_poll() { + if !self.shutdown_token.is_cancelled() { tracing::error!( "failed to send mixnet packet due to closed channel (outside of shutdown!)" ); @@ -536,9 +536,7 @@ where } #[cfg(not(target_arch = "wasm32"))] - fn log_status(&self, shutdown: &mut TaskClient) { - use crate::error::ClientCoreStatusMessage; - + fn log_status(&self) { let packets = self.transmission_buffer.total_size(); let lanes = self.transmission_buffer.lanes(); let mult = self.sending_delay_controller.current_multiplier(); @@ -567,32 +565,33 @@ where tracing::debug!("{status_str}"); } - // Send status message to whoever is listening (possibly UI) - if mult == self.sending_delay_controller.max_multiplier() { - shutdown.send_status_msg(Box::new(ClientCoreStatusMessage::GatewayIsVerySlow)); - } else if mult > self.sending_delay_controller.min_multiplier() { - shutdown.send_status_msg(Box::new(ClientCoreStatusMessage::GatewayIsSlow)); - } + // leave the code commented in case somebody wanted to restore this logic with a different channel + // // Send status message to whoever is listening (possibly UI) + // if mult == self.sending_delay_controller.max_multiplier() { + // shutdown.send_status_msg(Box::new(ClientCoreStatusMessage::GatewayIsVerySlow)); + // } else if mult > self.sending_delay_controller.min_multiplier() { + // shutdown.send_status_msg(Box::new(ClientCoreStatusMessage::GatewayIsSlow)); + // } } - pub(super) async fn run(&mut self) { + pub(crate) async fn run(&mut self) { debug!("Started OutQueueControl with graceful shutdown support"); - let mut shutdown = self.task_client.fork("select"); - + // avoid borrow on self + let shutdown_token = self.shutdown_token.clone(); #[cfg(not(target_arch = "wasm32"))] { let mut status_timer = tokio::time::interval(Duration::from_secs(5)); - while !shutdown.is_shutdown() { + loop { tokio::select! { biased; - _ = shutdown.recv() => { + _ = shutdown_token.cancelled() => { tracing::trace!("OutQueueControl: Received shutdown"); break; } _ = status_timer.tick() => { - self.log_status(&mut shutdown); + self.log_status(); } next_message = self.next() => if let Some(next_message) = next_message { self.on_message(next_message).await; @@ -602,16 +601,16 @@ where } } } - shutdown.recv_timeout().await; } #[cfg(target_arch = "wasm32")] { - while !shutdown.is_shutdown() { + loop { tokio::select! 
{ biased; - _ = shutdown.recv() => { + _ = shutdown_token.cancelled() => { tracing::trace!("OutQueueControl: Received shutdown"); + break; } next_message = self.next() => if let Some(next_message) = next_message { self.on_message(next_message).await; diff --git a/common/client-core/src/client/real_messages_control/real_traffic_stream/sending_delay_controller.rs b/common/client-core/src/client/real_messages_control/real_traffic_stream/sending_delay_controller.rs index 7c147bed8d6..fa9898e42bd 100644 --- a/common/client-core/src/client/real_messages_control/real_traffic_stream/sending_delay_controller.rs +++ b/common/client-core/src/client/real_messages_control/real_traffic_stream/sending_delay_controller.rs @@ -83,11 +83,13 @@ impl SendingDelayController { self.current_multiplier } + #[allow(dead_code)] #[cfg(not(target_arch = "wasm32"))] pub(crate) fn min_multiplier(&self) -> u32 { self.lower_bound } + #[allow(dead_code)] #[cfg(not(target_arch = "wasm32"))] pub(crate) fn max_multiplier(&self) -> u32 { self.upper_bound diff --git a/common/client-core/src/client/received_buffer.rs b/common/client-core/src/client/received_buffer.rs index 834d8b9be2c..b31cccb4198 100644 --- a/common/client-core/src/client/received_buffer.rs +++ b/common/client-core/src/client/received_buffer.rs @@ -5,7 +5,6 @@ use crate::client::helpers::get_time_now; use crate::client::replies::{ reply_controller::ReplyControllerSender, reply_storage::SentReplyKeys, }; -use crate::spawn_future; use futures::channel::mpsc; use futures::lock::Mutex; use futures::StreamExt; @@ -20,7 +19,7 @@ use nym_sphinx::message::{NymMessage, PlainMessage}; use nym_sphinx::params::ReplySurbKeyDigestAlgorithm; use nym_sphinx::receiver::{MessageReceiver, MessageRecoveryError, ReconstructedMessage}; use nym_statistics_common::clients::{packet_statistics::PacketStatisticsEvent, ClientStatsSender}; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use std::collections::HashSet; use std::sync::Arc; use std::time::Duration; @@ -172,7 +171,7 @@ struct ReceivedMessagesBuffer { inner: Arc>>, reply_key_storage: SentReplyKeys, reply_controller_sender: ReplyControllerSender, - task_client: TaskClient, + shutdown_token: ShutdownToken, } impl ReceivedMessagesBuffer { @@ -181,7 +180,7 @@ impl ReceivedMessagesBuffer { reply_key_storage: SentReplyKeys, reply_controller_sender: ReplyControllerSender, stats_tx: ClientStatsSender, - task_client: TaskClient, + shutdown_token: ShutdownToken, ) -> Self { ReceivedMessagesBuffer { inner: Arc::new(Mutex::new(ReceivedMessagesBufferInner { @@ -195,7 +194,7 @@ impl ReceivedMessagesBuffer { })), reply_key_storage, reply_controller_sender, - task_client, + shutdown_token, } } @@ -316,7 +315,7 @@ impl ReceivedMessagesBuffer { reply_surbs, from_surb_request, ) { - if !self.task_client.is_shutdown_poll() { + if !self.shutdown_token.is_cancelled() { error!("{err}"); } } @@ -339,7 +338,7 @@ impl ReceivedMessagesBuffer { .reply_controller_sender .send_additional_surbs_request(*recipient, amount) { - if !self.task_client.is_shutdown_poll() { + if !self.shutdown_token.is_cancelled() { error!("{err}"); } } @@ -466,22 +465,22 @@ pub enum ReceivedBufferMessage { ReceiverDisconnect, } -struct RequestReceiver { +pub(crate) struct RequestReceiver { received_buffer: ReceivedMessagesBuffer, query_receiver: ReceivedBufferRequestReceiver, - task_client: TaskClient, + shutdown_token: ShutdownToken, } impl RequestReceiver { fn new( received_buffer: ReceivedMessagesBuffer, query_receiver: ReceivedBufferRequestReceiver, - task_client: 
TaskClient, + shutdown_token: ShutdownToken, ) -> Self { RequestReceiver { received_buffer, query_receiver, - task_client, + shutdown_token, } } @@ -496,66 +495,70 @@ impl RequestReceiver { } } - async fn run(&mut self) { + pub(crate) async fn run(&mut self) { debug!("Started RequestReceiver with graceful shutdown support"); - while !self.task_client.is_shutdown() { + loop { tokio::select! { biased; - _ = self.task_client.recv() => { + _ = self.shutdown_token.cancelled() => { tracing::trace!("RequestReceiver: Received shutdown"); + break; } request = self.query_receiver.next() => { if let Some(message) = request { self.handle_message(message).await } else { tracing::trace!("RequestReceiver: Stopping since channel closed"); + self.shutdown_token.cancelled().await; break; } }, } } - self.task_client.recv().await; tracing::debug!("RequestReceiver: Exiting"); } } -struct FragmentedMessageReceiver { +pub(crate) struct FragmentedMessageReceiver { received_buffer: ReceivedMessagesBuffer, mixnet_packet_receiver: MixnetMessageReceiver, - task_client: TaskClient, + shutdown_token: ShutdownToken, } impl FragmentedMessageReceiver { fn new( received_buffer: ReceivedMessagesBuffer, mixnet_packet_receiver: MixnetMessageReceiver, - task_client: TaskClient, + shutdown_token: ShutdownToken, ) -> Self { FragmentedMessageReceiver { received_buffer, mixnet_packet_receiver, - task_client, + shutdown_token, } } - async fn run(&mut self) -> Result<(), MessageRecoveryError> { + pub(crate) async fn run(&mut self) -> Result<(), MessageRecoveryError> { debug!("Started FragmentedMessageReceiver with graceful shutdown support"); - while !self.task_client.is_shutdown() { + loop { tokio::select! { + biased; + _ = self.shutdown_token.cancelled() => { + tracing::trace!("FragmentedMessageReceiver: Received shutdown"); + break; + } new_messages = self.mixnet_packet_receiver.next() => { if let Some(new_messages) = new_messages { self.received_buffer.handle_new_received(new_messages).await?; } else { tracing::trace!("FragmentedMessageReceiver: Stopping since channel closed"); + self.shutdown_token.cancelled().await; break; } }, - _ = self.task_client.recv_with_delay() => { - tracing::trace!("FragmentedMessageReceiver: Received shutdown"); - } + } } - self.task_client.recv_timeout().await; tracing::debug!("FragmentedMessageReceiver: Exiting"); Ok(()) } @@ -574,48 +577,31 @@ impl ReceivedMessagesBufferControll reply_key_storage: SentReplyKeys, reply_controller_sender: ReplyControllerSender, metrics_reporter: ClientStatsSender, - task_client: TaskClient, + shutdown_token: ShutdownToken, ) -> Self { let received_buffer = ReceivedMessagesBuffer::new( local_encryption_keypair, reply_key_storage, reply_controller_sender, metrics_reporter, - task_client.fork("received_messages_buffer"), + shutdown_token.clone(), ); ReceivedMessagesBufferController { fragmented_message_receiver: FragmentedMessageReceiver::new( received_buffer.clone(), mixnet_packet_receiver, - task_client.fork("fragmented_message_receiver"), + shutdown_token.clone(), ), request_receiver: RequestReceiver::new( received_buffer, query_receiver, - task_client.with_suffix("request_receiver"), + shutdown_token.clone(), ), } } - pub fn start(self) { - let mut fragmented_message_receiver = self.fragmented_message_receiver; - let mut request_receiver = self.request_receiver; - - spawn_future!( - async move { - match fragmented_message_receiver.run().await { - Ok(_) => {} - Err(e) => error!("{e}"), - } - }, - "ReceivedMessagesBufferController::FragmentedMessageReceiver" - ); - 
spawn_future!( - async move { - request_receiver.run().await; - }, - "ReceivedMessagesBufferController::RequestReceiver" - ); + pub(crate) fn into_tasks(self) -> (FragmentedMessageReceiver, RequestReceiver) { + (self.fragmented_message_receiver, self.request_receiver) } } diff --git a/common/client-core/src/client/replies/reply_controller/mod.rs b/common/client-core/src/client/replies/reply_controller/mod.rs index 60004c20118..63ead1c0be5 100644 --- a/common/client-core/src/client/replies/reply_controller/mod.rs +++ b/common/client-core/src/client/replies/reply_controller/mod.rs @@ -7,7 +7,7 @@ use crate::client::replies::reply_controller::key_rotation_helpers::KeyRotationC use crate::client::replies::reply_storage::CombinedReplyStorage; use crate::config; use futures::StreamExt; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use rand::rngs::OsRng; use rand::{CryptoRng, Rng}; use std::time::Duration; @@ -60,9 +60,6 @@ pub struct ReplyController { receiver_controller: ReceiverReplyController, request_receiver: ReplyControllerReceiver, - - // Listen for shutdown signals - task_client: TaskClient, } impl ReplyController { @@ -71,7 +68,6 @@ impl ReplyController { message_handler: MessageHandler, full_reply_storage: CombinedReplyStorage, request_receiver: ReplyControllerReceiver, - task_client: TaskClient, ) -> Self { ReplyController { config, @@ -86,7 +82,6 @@ impl ReplyController { message_handler, ), request_receiver, - task_client, } } } @@ -148,22 +143,21 @@ where self.sender_controller.inspect_and_clear_stale_data(now) } - pub(crate) async fn run(&mut self) { + pub(crate) async fn run(&mut self, shutdown_token: ShutdownToken) { debug!("Started ReplyController with graceful shutdown support"); - let mut shutdown = self.task_client.fork("reply-controller"); - let polling_rate = Duration::from_secs(5); let mut stale_inspection = new_interval_stream(polling_rate); let polling_rate = self.config.key_rotation.epoch_duration / 8; let mut invalidation_inspection = new_interval_stream(polling_rate); - while !shutdown.is_shutdown() { + loop { tokio::select! 
{ biased; - _ = shutdown.recv() => { + _ = shutdown_token.cancelled() => { tracing::trace!("ReplyController: Received shutdown"); + break; }, req = self.request_receiver.next() => match req { Some(req) => self.handle_request(req).await, @@ -181,7 +175,6 @@ where } } } - assert!(shutdown.is_shutdown_poll()); tracing::debug!("ReplyController: Exiting"); } } diff --git a/common/client-core/src/client/statistics_control.rs b/common/client-core/src/client/statistics_control.rs index 6d315f4eb2e..dcfbd2e19c5 100644 --- a/common/client-core/src/client/statistics_control.rs +++ b/common/client-core/src/client/statistics_control.rs @@ -16,21 +16,17 @@ #![warn(clippy::todo)] #![warn(clippy::dbg_macro)] +use crate::client::inbound_messages::{InputMessage, InputMessageSender}; use futures::StreamExt; use nym_client_core_config_types::StatsReporting; use nym_sphinx::addressing::Recipient; use nym_statistics_common::clients::{ ClientStatsController, ClientStatsReceiver, ClientStatsSender, }; -use nym_task::{connections::TransmissionLane, TaskClient}; +use nym_task::{connections::TransmissionLane, ShutdownToken, ShutdownTracker}; use std::time::Duration; -use crate::{ - client::inbound_messages::{InputMessage, InputMessageSender}, - spawn_future, -}; - -/// Time interval between reporting statistics locally (logging/task_client) +/// Time interval between reporting statistics locally (logging/shutdown_token) const LOCAL_REPORT_INTERVAL: Duration = Duration::from_secs(2); /// Interval for taking snapshots of the statistics const SNAPSHOT_INTERVAL: Duration = Duration::from_millis(500); @@ -51,9 +47,6 @@ pub(crate) struct StatisticsControl { /// Config for stats reporting (enabled, address, interval) reporting_config: StatsReporting, - - /// Task client for listening for shutdown - task_client: TaskClient, } impl StatisticsControl { @@ -62,24 +55,20 @@ impl StatisticsControl { client_type: String, client_stats_id: String, report_tx: InputMessageSender, - task_client: TaskClient, + shutdown_token: ShutdownToken, ) -> (Self, ClientStatsSender) { let (stats_tx, stats_rx) = tokio::sync::mpsc::unbounded_channel(); let stats = ClientStatsController::new(client_stats_id, client_type); - let mut task_client_stats_sender = task_client.fork("stats_sender"); - task_client_stats_sender.disarm(); - ( StatisticsControl { stats, stats_rx, report_tx, reporting_config, - task_client, }, - ClientStatsSender::new(Some(stats_tx), task_client_stats_sender), + ClientStatsSender::new(Some(stats_tx), shutdown_token), ) } @@ -99,7 +88,8 @@ impl StatisticsControl { } } - async fn run(&mut self) { + // manually control the shutdown mechanism as we don't want to get interrupted mid-snapshot + pub async fn run(&mut self, shutdown_token: ShutdownToken) { tracing::debug!("Started StatisticsControl with graceful shutdown support"); #[cfg(not(target_arch = "wasm32"))] @@ -129,10 +119,10 @@ impl StatisticsControl { let mut snapshot_interval = gloo_timers::future::IntervalStream::new(SNAPSHOT_INTERVAL.as_millis() as u32); - while !self.task_client.is_shutdown() { + loop { tokio::select! 
{ biased; - _ = self.task_client.recv() => { + _ = shutdown_token.cancelled() => { tracing::trace!("StatisticsControl: Received shutdown"); break; }, @@ -157,37 +147,34 @@ impl StatisticsControl { } _ = local_report_interval.next() => { - self.stats.local_report(&mut self.task_client); + self.stats.local_report(); } } } tracing::debug!("StatisticsControl: Exiting"); } - pub(crate) fn start(mut self) { - spawn_future!( - async move { - self.run().await; - }, - "StatisticsControl" - ) - } - pub(crate) fn create_and_start( reporting_config: StatsReporting, client_type: String, client_stats_id: String, report_tx: InputMessageSender, - task_client: TaskClient, + shutdown_tracker: &ShutdownTracker, ) -> ClientStatsSender { - let (controller, sender) = Self::create( + let (mut controller, sender) = Self::create( reporting_config, client_type, client_stats_id, report_tx, - task_client, + shutdown_tracker.child_shutdown_token(), + ); + let shutdown_token = shutdown_tracker.clone_shutdown_token(); + shutdown_tracker.try_spawn_named( + async move { + controller.run(shutdown_token).await; + }, + "StatisticsControl", ); - controller.start(); sender } } diff --git a/common/client-core/src/client/topology_control/mod.rs b/common/client-core/src/client/topology_control/mod.rs index b3c4d61f3f6..c083aa73788 100644 --- a/common/client-core/src/client/topology_control/mod.rs +++ b/common/client-core/src/client/topology_control/mod.rs @@ -1,11 +1,9 @@ // Copyright 2021-2023 - Nym Technologies SA // SPDX-License-Identifier: Apache-2.0 -use crate::spawn_future; pub(crate) use accessor::{TopologyAccessor, TopologyReadPermit}; use futures::StreamExt; use nym_sphinx::addressing::nodes::NodeIdentity; -use nym_task::TaskClient; use nym_topology::NymTopologyError; use std::time::Duration; use tracing::*; @@ -41,8 +39,6 @@ pub struct TopologyRefresher { refresh_rate: Duration, consecutive_failure_count: usize, - - task_client: TaskClient, } impl TopologyRefresher { @@ -50,14 +46,12 @@ impl TopologyRefresher { cfg: TopologyRefresherConfig, topology_accessor: TopologyAccessor, topology_provider: Box, - task_client: TaskClient, ) -> Self { TopologyRefresher { topology_provider, topology_accessor, refresh_rate: cfg.refresh_rate, consecutive_failure_count: 0, - task_client, } } @@ -144,40 +138,30 @@ impl TopologyRefresher { } } - pub fn start(mut self) { - spawn_future!( - async move { - debug!("Started TopologyRefresher with graceful shutdown support"); - - #[cfg(not(target_arch = "wasm32"))] - let mut interval = tokio_stream::wrappers::IntervalStream::new( - tokio::time::interval(self.refresh_rate), - ); - - #[cfg(target_arch = "wasm32")] - let mut interval = - gloo_timers::future::IntervalStream::new(self.refresh_rate.as_millis() as u32); - - // We already have an initial topology, so no need to refresh it immediately. - // My understanding is that js setInterval does not fire immediately, so it's not - // needed there. - #[cfg(not(target_arch = "wasm32"))] - interval.next().await; - - while !self.task_client.is_shutdown() { - tokio::select! 
{ - _ = interval.next() => { - self.try_refresh().await; - }, - _ = self.task_client.recv() => { - tracing::trace!("TopologyRefresher: Received shutdown"); - }, - } - } - self.task_client.recv_timeout().await; - tracing::debug!("TopologyRefresher: Exiting"); - }, - "TopologyRefresher" - ) + // it's perfectly fine if task is interrupted mid-refresh + // there's no data to persist or send over + pub async fn run(&mut self) { + debug!("Started TopologyRefresher with graceful shutdown support"); + + #[cfg(not(target_arch = "wasm32"))] + let mut interval = + tokio_stream::wrappers::IntervalStream::new(tokio::time::interval(self.refresh_rate)); + + #[cfg(target_arch = "wasm32")] + let mut interval = + gloo_timers::future::IntervalStream::new(self.refresh_rate.as_millis() as u32); + + // We already have an initial topology, so no need to refresh it immediately. + // My understanding is that js setInterval does not fire immediately, so it's not + // needed there. + #[cfg(not(target_arch = "wasm32"))] + interval.next().await; + + while interval.next().await.is_some() { + self.try_refresh().await; + } + + // this should never get triggered + error!("topology refresher interval has been exhausted!") } } diff --git a/common/client-core/src/lib.rs b/common/client-core/src/lib.rs index 8be8b101337..072bac7e9fc 100644 --- a/common/client-core/src/lib.rs +++ b/common/client-core/src/lib.rs @@ -17,7 +17,9 @@ pub use nym_topology::{ HardcodedTopologyProvider, NymRouteProvider, NymTopology, NymTopologyError, TopologyProvider, }; +#[deprecated(note = "use spawn_future from nym_task crate instead")] #[cfg(target_arch = "wasm32")] +#[track_caller] pub fn spawn_future(future: F) where F: Future + 'static, @@ -25,9 +27,7 @@ where wasm_bindgen_futures::spawn_local(future); } -// TODO: expose similar API to the rest of the codebase, -// perhaps with some simple trait for a task to define its name - +#[deprecated(note = "use spawn_future from nym_task crate instead")] #[cfg(not(target_arch = "wasm32"))] #[track_caller] pub fn spawn_future(future: F) @@ -37,35 +37,3 @@ where { tokio::spawn(future); } - -#[cfg(not(target_arch = "wasm32"))] -#[track_caller] -pub fn spawn_named_future(future: F, name: &str) -where - F: Future + Send + 'static, - F::Output: Send + 'static, -{ - cfg_if::cfg_if! {if #[cfg(tokio_unstable)] { - #[allow(clippy::expect_used)] - tokio::task::Builder::new().name(name).spawn(future).expect("failed to spawn future"); - } else { - let _ = name; - tracing::debug!(r#"the underlying binary hasn't been built with `RUSTFLAGS="--cfg tokio_unstable"` - the future naming won't do anything"#); - spawn_future(future); - }} -} - -#[macro_export] -macro_rules! spawn_future { - ($future:expr) => {{ - $crate::spawn_future($future) - }}; - ($future:expr, $name:expr) => {{ - cfg_if::cfg_if! 
{if #[cfg(not(target_arch = "wasm32"))] { - $crate::spawn_named_future($future, $name) - } else { - let _ = $name; - $crate::spawn_future($future) - }} - }}; -} diff --git a/common/client-core/surb-storage/src/lib.rs b/common/client-core/surb-storage/src/lib.rs index b7ac9247069..079c213bbc7 100644 --- a/common/client-core/surb-storage/src/lib.rs +++ b/common/client-core/surb-storage/src/lib.rs @@ -40,7 +40,7 @@ where pub async fn flush_on_shutdown( mut self, mem_state: CombinedReplyStorage, - mut shutdown: nym_task::TaskClient, + shutdown: nym_task::ShutdownToken, ) { use tracing::{debug, error, info}; @@ -50,7 +50,7 @@ where return; } - shutdown.recv().await; + shutdown.cancelled().await; info!("PersistentReplyStorage is flushing all reply-related data to underlying storage"); if let Err(err) = self.backend.flush_surb_storage(&mem_state).await { diff --git a/common/client-libs/gateway-client/src/client/mod.rs b/common/client-libs/gateway-client/src/client/mod.rs index de90b3ad35d..52e5833eb29 100644 --- a/common/client-libs/gateway-client/src/client/mod.rs +++ b/common/client-libs/gateway-client/src/client/mod.rs @@ -12,7 +12,7 @@ use crate::socket_state::{ws_fd, PartiallyDelegatedHandle, SocketState}; use crate::traits::GatewayPacketRouter; use crate::{cleanup_socket_message, try_decrypt_binary_message}; use futures::{SinkExt, StreamExt}; -use nym_bandwidth_controller::{BandwidthController, BandwidthStatusMessage}; +use nym_bandwidth_controller::BandwidthController; use nym_credential_storage::ephemeral_storage::EphemeralStorage as EphemeralCredentialStorage; use nym_credential_storage::storage::Storage as CredentialStorage; use nym_credentials::CredentialSpendingData; @@ -27,7 +27,7 @@ use nym_gateway_requests::{ use nym_sphinx::forwarding::packet::MixPacket; use nym_statistics_common::clients::connection::ConnectionStatsEvent; use nym_statistics_common::clients::ClientStatsSender; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use nym_validator_client::nyxd::contract_traits::DkgQueryClient; use rand::rngs::OsRng; use std::sync::Arc; @@ -109,7 +109,7 @@ pub struct GatewayClient { connection_fd_callback: Option>, /// Listen to shutdown messages and send notifications back to the task manager - task_client: TaskClient, + shutdown_token: ShutdownToken, } impl GatewayClient { @@ -124,7 +124,7 @@ impl GatewayClient { bandwidth_controller: Option>, stats_reporter: ClientStatsSender, #[cfg(unix)] connection_fd_callback: Option>, - task_client: TaskClient, + shutdown_token: ShutdownToken, ) -> Self { GatewayClient { cfg, @@ -141,7 +141,7 @@ impl GatewayClient { negotiated_protocol: None, #[cfg(unix)] connection_fd_callback, - task_client, + shutdown_token, } } @@ -293,7 +293,7 @@ impl GatewayClient { loop { tokio::select! 
{ - _ = self.task_client.recv() => { + _ = self.shutdown_token.cancelled() => { log::trace!("GatewayClient control response: Received shutdown"); log::debug!("GatewayClient control response: Exiting"); break Err(GatewayClientError::ConnectionClosedGatewayShutdown); @@ -514,7 +514,7 @@ impl GatewayClient { self.cfg.bandwidth.require_tickets, derive_aes256_gcm_siv_key, #[cfg(not(target_arch = "wasm32"))] - self.task_client.clone(), + self.shutdown_token.clone(), ) .await .map_err(GatewayClientError::RegistrationFailure), @@ -631,9 +631,6 @@ impl GatewayClient { self.negotiated_protocol = protocol_version; log::debug!("authenticated: {status}, bandwidth remaining: {bandwidth_remaining}"); - self.task_client.send_status_msg(Box::new( - BandwidthStatusMessage::RemainingBandwidth(bandwidth_remaining), - )); Ok(()) } ServerResponse::Error { message } => Err(GatewayClientError::GatewayError(message)), @@ -1069,7 +1066,7 @@ impl GatewayClient { .expect("no shared key present even though we're authenticated!"), ), self.bandwidth.clone(), - self.task_client.clone(), + self.shutdown_token.clone(), ) } _ => unreachable!(), @@ -1143,8 +1140,8 @@ impl GatewayClient { // perfectly fine here, because it's not meant to be used let (ack_tx, _) = mpsc::unbounded(); let (mix_tx, _) = mpsc::unbounded(); - let task_client = TaskClient::dummy(); - let packet_router = PacketRouter::new(ack_tx, mix_tx, task_client.clone()); + let shutdown_token = ShutdownToken::default(); + let packet_router = PacketRouter::new(ack_tx, mix_tx, shutdown_token.clone()); GatewayClient { cfg: GatewayClientConfig::default().with_disabled_credentials_mode(true), @@ -1157,11 +1154,11 @@ impl GatewayClient { connection: SocketState::NotConnected, packet_router, bandwidth_controller: None, - stats_reporter: ClientStatsSender::new(None, task_client.clone()), + stats_reporter: ClientStatsSender::new(None, shutdown_token.clone()), negotiated_protocol: None, #[cfg(unix)] connection_fd_callback, - task_client, + shutdown_token, } } @@ -1170,7 +1167,7 @@ impl GatewayClient { packet_router: PacketRouter, bandwidth_controller: Option>, stats_reporter: ClientStatsSender, - task_client: TaskClient, + shutdown_token: ShutdownToken, ) -> GatewayClient { // invariants that can't be broken // (unless somebody decided to expose some field that wasn't meant to be exposed) @@ -1193,7 +1190,7 @@ impl GatewayClient { negotiated_protocol: self.negotiated_protocol, #[cfg(unix)] connection_fd_callback: self.connection_fd_callback, - task_client, + shutdown_token, } } } diff --git a/common/client-libs/gateway-client/src/packet_router.rs b/common/client-libs/gateway-client/src/packet_router.rs index 36168f4ab89..7fb863947f3 100644 --- a/common/client-libs/gateway-client/src/packet_router.rs +++ b/common/client-libs/gateway-client/src/packet_router.rs @@ -7,7 +7,7 @@ use crate::error::GatewayClientError; use crate::GatewayPacketRouter; use futures::channel::mpsc; -use nym_task::TaskClient; +use nym_task::ShutdownToken; pub type MixnetMessageSender = mpsc::UnboundedSender>>; pub type MixnetMessageReceiver = mpsc::UnboundedReceiver>>; @@ -19,14 +19,14 @@ pub type AcknowledgementReceiver = mpsc::UnboundedReceiver>>; pub struct PacketRouter { ack_sender: AcknowledgementSender, mixnet_message_sender: MixnetMessageSender, - shutdown: TaskClient, + shutdown: ShutdownToken, } impl PacketRouter { pub fn new( ack_sender: AcknowledgementSender, mixnet_message_sender: MixnetMessageSender, - shutdown: TaskClient, + shutdown: ShutdownToken, ) -> Self { PacketRouter { ack_sender, 
@@ -42,7 +42,7 @@ impl PacketRouter { if let Err(err) = self.mixnet_message_sender.unbounded_send(received_messages) { // check if the failure is due to the shutdown being in progress and thus the receiver channel // having already been dropped - if self.shutdown.is_shutdown_poll() || self.shutdown.is_dummy() { + if self.shutdown.is_cancelled() { // This should ideally not happen, but it's ok tracing::warn!("Failed to send mixnet messages due to receiver task shutdown"); return Err(GatewayClientError::ShutdownInProgress); @@ -58,7 +58,7 @@ impl PacketRouter { if let Err(err) = self.ack_sender.unbounded_send(received_acks) { // check if the failure is due to the shutdown being in progress and thus the receiver channel // having already been dropped - if self.shutdown.is_shutdown_poll() || self.shutdown.is_dummy() { + if self.shutdown.is_cancelled() { // This should ideally not happen, but it's ok tracing::warn!("Failed to send acks due to receiver task shutdown"); return Err(GatewayClientError::ShutdownInProgress); @@ -69,10 +69,6 @@ impl PacketRouter { } Ok(()) } - - pub fn disarm(&mut self) { - self.shutdown.disarm(); - } } impl GatewayPacketRouter for PacketRouter { diff --git a/common/client-libs/gateway-client/src/socket_state.rs b/common/client-libs/gateway-client/src/socket_state.rs index 5489ec36287..4f3009e3892 100644 --- a/common/client-libs/gateway-client/src/socket_state.rs +++ b/common/client-libs/gateway-client/src/socket_state.rs @@ -11,7 +11,7 @@ use futures::stream::{SplitSink, SplitStream}; use futures::{SinkExt, StreamExt}; use nym_gateway_requests::shared_key::SharedGatewayKey; use nym_gateway_requests::{SensitiveServerResponse, ServerResponse, SimpleGatewayRequestsError}; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use si_scale::helpers::bibytes2; use std::os::raw::c_int as RawFd; use std::sync::Arc; @@ -87,13 +87,13 @@ impl PartiallyDelegatedRouter { } } - async fn run(mut self, mut split_stream: SplitStream, mut task_client: TaskClient) { + async fn run(mut self, mut split_stream: SplitStream, shutdown_token: ShutdownToken) { let mut chunked_stream = (&mut split_stream).ready_chunks(8); let ret: Result<_, GatewayClientError> = loop { tokio::select! { biased; // received system-wide shutdown - _ = task_client.recv() => { + _ = shutdown_token.cancelled() => { log::trace!("GatewayClient listener: Received shutdown"); log::debug!("GatewayClient listener: Exiting"); return; @@ -118,11 +118,7 @@ impl PartiallyDelegatedRouter { let return_res = match ret { Err(err) => self.stream_return.send(Err(err)), - Ok(_) => { - self.packet_router.disarm(); - task_client.disarm(); - self.stream_return.send(Ok(split_stream)) - } + Ok(_) => self.stream_return.send(Ok(split_stream)), }; if return_res.is_err() { @@ -266,8 +262,8 @@ impl PartiallyDelegatedRouter { Ok(plaintexts) } - fn spawn(self, split_stream: SplitStream, task_client: TaskClient) { - let fut = async move { self.run(split_stream, task_client).await }; + fn spawn(self, split_stream: SplitStream, shutdown_token: ShutdownToken) { + let fut = async move { self.run(split_stream, shutdown_token).await }; #[cfg(target_arch = "wasm32")] wasm_bindgen_futures::spawn_local(fut); @@ -283,7 +279,7 @@ impl PartiallyDelegatedHandle { packet_router: PacketRouter, shared_key: Arc, client_bandwidth: ClientBandwidth, - shutdown: TaskClient, + shutdown: ShutdownToken, ) -> Self { // when called for, it NEEDS TO yield back the stream so that we could merge it and // read control request responses. 
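The gateway-client hunks above follow the same migration seen throughout this patch: select branches that used to await `task_client.recv()` now await `shutdown_token.cancelled()` (with `biased;` so shutdown is checked first) and break out of the loop. Below is a minimal, self-contained sketch of that pattern written directly against tokio-util's `CancellationToken`, which the new `ShutdownToken` dereferences to elsewhere in this patch; the `worker` function and its channel are illustrative only and not part of this diff.

use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;

// a long-running task that processes messages until told to shut down,
// mirroring the `loop { tokio::select! { biased; ... } }` shape used in the patch
async fn worker(mut rx: mpsc::UnboundedReceiver<u64>, shutdown: CancellationToken) {
    loop {
        tokio::select! {
            biased;
            _ = shutdown.cancelled() => {
                println!("worker: received shutdown");
                break;
            }
            msg = rx.recv() => match msg {
                Some(msg) => println!("worker: processing {msg}"),
                None => {
                    println!("worker: stopping since channel closed");
                    break;
                }
            },
        }
    }
    println!("worker: exiting");
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    let root = CancellationToken::new();
    let handle = tokio::spawn(worker(rx, root.child_token()));

    tx.send(1).expect("worker still alive");
    root.cancel(); // cancelling the root token makes the worker exit cleanly
    handle.await.expect("worker panicked");
}
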
diff --git a/common/credential-verification/src/ecash/credential_sender.rs b/common/credential-verification/src/ecash/credential_sender.rs index 4d4a0b8fe61..9169efc61bd 100644 --- a/common/credential-verification/src/ecash/credential_sender.rs +++ b/common/credential-verification/src/ecash/credential_sender.rs @@ -126,7 +126,7 @@ pub struct CredentialHandlerConfig { pub maximum_time_between_redemption: Duration, } -pub(crate) struct CredentialHandler { +pub struct CredentialHandler { config: CredentialHandlerConfig, multisig_threshold: f32, ticket_receiver: UnboundedReceiver, @@ -907,7 +907,7 @@ impl CredentialHandler { Ok(()) } - async fn run(mut self, mut shutdown: nym_task::TaskClient) { + pub async fn run(mut self, shutdown: nym_task::ShutdownToken) { info!("Starting Ecash CredentialSender"); // attempt to clear any pending operations @@ -919,11 +919,12 @@ impl CredentialHandler { let start = Instant::now() + self.config.pending_poller; let mut resolver_interval = interval_at(start, self.config.pending_poller); - while !shutdown.is_shutdown() { + loop { tokio::select! { biased; - _ = shutdown.recv() => { + _ = shutdown.cancelled() => { trace!("client_handling::credentialSender : received shutdown"); + break }, Some(ticket) = self.ticket_receiver.next() => { let (queued_up, _) = self.ticket_receiver.size_hint(); @@ -946,8 +947,4 @@ impl CredentialHandler { } } } - - pub(crate) fn start(self, shutdown: nym_task::TaskClient) { - tokio::spawn(async move { self.run(shutdown).await }); - } } diff --git a/common/credential-verification/src/ecash/mod.rs b/common/credential-verification/src/ecash/mod.rs index 5586ef5ab1e..71d11989882 100644 --- a/common/credential-verification/src/ecash/mod.rs +++ b/common/credential-verification/src/ecash/mod.rs @@ -82,9 +82,8 @@ impl EcashManager { credential_handler_cfg: CredentialHandlerConfig, nyxd_client: DirectSigningHttpRpcNyxdClient, pk_bytes: [u8; 32], - shutdown: nym_task::TaskClient, storage: GatewayStorage, - ) -> Result { + ) -> Result<(Self, CredentialHandler), Error> { let shared_state = SharedState::new(nyxd_client, Box::new(storage)).await?; let (cred_sender, cred_receiver) = mpsc::unbounded(); @@ -92,14 +91,16 @@ impl EcashManager { let cs = CredentialHandler::new(credential_handler_cfg, cred_receiver, shared_state.clone()) .await?; - cs.start(shutdown); - - Ok(EcashManager { - shared_state, - pk_bytes, - pay_infos: Default::default(), - cred_sender, - }) + + Ok(( + EcashManager { + shared_state, + pk_bytes, + pay_infos: Default::default(), + cred_sender, + }, + cs, + )) } pub async fn verify_pay_info(&self, pay_info: NymPayInfo) -> Result { diff --git a/common/gateway-requests/src/registration/handshake/mod.rs b/common/gateway-requests/src/registration/handshake/mod.rs index f6ef8a4de41..2daed81b48e 100644 --- a/common/gateway-requests/src/registration/handshake/mod.rs +++ b/common/gateway-requests/src/registration/handshake/mod.rs @@ -14,7 +14,7 @@ use std::task::{Context, Poll}; use tungstenite::{Error as WsError, Message as WsMessage}; #[cfg(not(target_arch = "wasm32"))] -use nym_task::TaskClient; +use nym_task::ShutdownToken; pub(crate) type WsItem = Result; @@ -52,7 +52,7 @@ pub fn client_handshake<'a, S, R>( gateway_pubkey: ed25519::PublicKey, expects_credential_usage: bool, derive_aes256_gcm_siv_key: bool, - #[cfg(not(target_arch = "wasm32"))] shutdown: TaskClient, + #[cfg(not(target_arch = "wasm32"))] shutdown_token: ShutdownToken, ) -> GatewayHandshake<'a> where S: Stream + Sink + Unpin + Send + 'a, @@ -64,7 +64,7 @@ where 
identity, Some(gateway_pubkey), #[cfg(not(target_arch = "wasm32"))] - shutdown, + shutdown_token, ) .with_credential_usage(expects_credential_usage) .with_aes256_gcm_siv_key(derive_aes256_gcm_siv_key); @@ -80,13 +80,13 @@ pub fn gateway_handshake<'a, S, R>( ws_stream: &'a mut S, identity: &'a ed25519::KeyPair, received_init_payload: Vec, - shutdown: TaskClient, + shutdown_token: ShutdownToken, ) -> GatewayHandshake<'a> where S: Stream + Sink + Unpin + Send + 'a, R: CryptoRng + RngCore + Send, { - let state = State::new(rng, ws_stream, identity, None, shutdown); + let state = State::new(rng, ws_stream, identity, None, shutdown_token); GatewayHandshake { handshake_future: Box::pin(state.perform_gateway_handshake(received_init_payload)), } @@ -149,7 +149,7 @@ mod tests { *gateway_keys.public_key(), false, true, - TaskClient::dummy(), + ShutdownToken::default(), ); let client_fut = handshake_client.spawn_timeboxed(); @@ -176,7 +176,7 @@ mod tests { gateway_ws, gateway_keys, init_msg, - TaskClient::dummy(), + ShutdownToken::default(), ); let gateway_fut = handshake_gateway.spawn_timeboxed(); diff --git a/common/gateway-requests/src/registration/handshake/state.rs b/common/gateway-requests/src/registration/handshake/state.rs index 62e9eba9e8c..3d7b29e3f95 100644 --- a/common/gateway-requests/src/registration/handshake/state.rs +++ b/common/gateway-requests/src/registration/handshake/state.rs @@ -24,7 +24,7 @@ use tracing::log::*; use tungstenite::Message as WsMessage; #[cfg(not(target_arch = "wasm32"))] -use nym_task::TaskClient; +use nym_task::ShutdownToken; #[cfg(not(target_arch = "wasm32"))] use tokio::time::timeout; @@ -63,7 +63,7 @@ pub(crate) struct State<'a, S, R> { // channel to receive shutdown signal #[cfg(not(target_arch = "wasm32"))] - shutdown: TaskClient, + shutdown_token: ShutdownToken, } impl<'a, S, R> State<'a, S, R> { @@ -72,7 +72,7 @@ impl<'a, S, R> State<'a, S, R> { ws_stream: &'a mut S, identity: &'a ed25519::KeyPair, remote_pubkey: Option, - #[cfg(not(target_arch = "wasm32"))] shutdown: TaskClient, + #[cfg(not(target_arch = "wasm32"))] shutdown_token: ShutdownToken, ) -> Self where R: CryptoRng + RngCore, @@ -89,7 +89,7 @@ impl<'a, S, R> State<'a, S, R> { expects_credential_usage: false, derive_aes256_gcm_siv_key: false, #[cfg(not(target_arch = "wasm32"))] - shutdown, + shutdown_token, } } @@ -306,7 +306,7 @@ impl<'a, S, R> State<'a, S, R> { loop { tokio::select! { biased; - _ = self.shutdown.recv() => return Err(HandshakeError::ReceivedShutdown), + _ = self.shutdown_token.cancelled() => return Err(HandshakeError::ReceivedShutdown), msg = self.ws_stream.next() => { let Some(ret) = Self::on_wg_msg(msg)? 
else { continue; diff --git a/common/node-tester-utils/src/receiver.rs b/common/node-tester-utils/src/receiver.rs index fd764d56dc9..f927e23632a 100644 --- a/common/node-tester-utils/src/receiver.rs +++ b/common/node-tester-utils/src/receiver.rs @@ -9,7 +9,7 @@ use futures::StreamExt; use nym_crypto::asymmetric::x25519; use nym_sphinx::acknowledgements::AckKey; use nym_sphinx::receiver::{MessageReceiver, SphinxMessageReceiver}; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use serde::de::DeserializeOwned; use std::sync::Arc; @@ -24,7 +24,7 @@ pub struct SimpleMessageReceiver acks_receiver: mpsc::UnboundedReceiver>>, received_sender: ReceivedSender, - shutdown: TaskClient, + shutdown: ShutdownToken, } impl SimpleMessageReceiver { @@ -34,7 +34,7 @@ impl SimpleMessageReceiver { mixnet_message_receiver: mpsc::UnboundedReceiver>>, acks_receiver: mpsc::UnboundedReceiver>>, received_sender: ReceivedSender, - shutdown: TaskClient, + shutdown: ShutdownToken, ) -> Self { Self::new( local_encryption_keypair, @@ -54,7 +54,7 @@ impl SimpleMessageReceiver { mixnet_message_receiver: mpsc::UnboundedReceiver>>, acks_receiver: mpsc::UnboundedReceiver>>, received_sender: ReceivedSender, - shutdown: TaskClient, + shutdown: ShutdownToken, ) -> Self { SimpleMessageReceiver { message_processor: TestPacketProcessor::new(local_encryption_keypair, ack_key), @@ -91,11 +91,12 @@ impl SimpleMessageReceiver { where T: DeserializeOwned, { - while !self.shutdown.is_shutdown() { + loop { tokio::select! { biased; - _ = self.shutdown.recv() => { - log_info!("SimpleMessageReceiver: received shutdown") + _ = self.shutdown.cancelled() => { + log_info!("SimpleMessageReceiver: received shutdown"); + break } mixnet_messages = self.mixnet_message_receiver.next() => { let Some(mixnet_messages) = mixnet_messages else { diff --git a/common/socks5-client-core/src/lib.rs b/common/socks5-client-core/src/lib.rs index 3c7a4b658d9..034aa5f4f06 100644 --- a/common/socks5-client-core/src/lib.rs +++ b/common/socks5-client-core/src/lib.rs @@ -23,9 +23,7 @@ use nym_client_core::init::types::GatewaySetup; use nym_credential_storage::storage::Storage as CredentialStorage; use nym_sphinx::addressing::clients::Recipient; use nym_sphinx::params::PacketType; -use nym_task::{TaskClient, TaskHandle, TaskStatus}; - -use anyhow::anyhow; +use nym_task::{ShutdownManager, ShutdownTracker}; use nym_validator_client::UserAgent; use std::error::Error; use std::path::PathBuf; @@ -46,7 +44,7 @@ pub enum Socks5ControlMessage { pub struct StartedSocks5Client { /// Handle for managing graceful shutdown of this client. If dropped, the client will be stopped. - pub shutdown_handle: TaskHandle, + pub shutdown_handle: ShutdownManager, /// Address of the started client pub address: Recipient, @@ -65,6 +63,8 @@ pub struct NymClient { /// Optional path to a .json file containing standalone network details. 
custom_mixnet: Option, + + shutdown_manager: ShutdownManager, } impl NymClient @@ -92,6 +92,7 @@ where setup_method: GatewaySetup::MustLoad { gateway_id: None }, user_agent, custom_mixnet, + shutdown_manager: Default::default(), } } @@ -108,7 +109,7 @@ where client_output: ClientOutput, client_status: ClientState, self_address: Recipient, - shutdown: TaskClient, + shutdown: ShutdownTracker, packet_type: PacketType, ) { info!("Starting socks5 listener..."); @@ -148,51 +149,39 @@ where socks5_config.send_anonymously, socks5_config.socks5_debug, ), - shutdown.clone(), - packet_type, - ); - nym_task::spawn_with_report_error( - async move { - sphinx_socks - .serve( - input_sender, - received_buffer_request_sender, - connection_command_sender, - ) - .await - }, shutdown, + packet_type, ); + nym_task::spawn_future(async move { + sphinx_socks + .serve( + input_sender, + received_buffer_request_sender, + connection_command_sender, + ) + .await + }); } /// blocking version of `start` method. Will run forever (or until SIGINT is sent) pub async fn run_forever(self) -> Result<(), Box> { - let started = self.start().await?; + let mut started = self.start().await?; - let res = started.shutdown_handle.wait_for_shutdown().await; + started.shutdown_handle.run_until_shutdown().await; log::info!("Stopping nym-socks5-client"); - res + Ok(()) } // Variant of `run_forever` that listens for remote control messages pub async fn run_and_listen( self, mut receiver: Socks5ControlMessageReceiver, - sender: nym_task::StatusSender, ) -> Result<(), Box> { // Start the main task let started = self.start().await?; - let mut shutdown = started - .shutdown_handle - .try_into_task_manager() - .ok_or(anyhow!( - "attempted to use `run_and_listen` without owning shutdown handle" - ))?; - - // Listen to status messages from task, that we forward back to the caller - shutdown - .start_status_listener(sender, TaskStatus::Ready) - .await; + let mut task_manager = started.shutdown_handle; + + let mut shutdown_signals = task_manager.detach_shutdown_signals(); let res = tokio::select! { biased; @@ -207,22 +196,20 @@ where } } Ok(()) - } - Some(msg) = shutdown.wait_for_error() => { - log::info!("Task error: {msg:?}"); - Err(msg) - } - _ = tokio::signal::ctrl_c() => { - log::info!("Received SIGINT"); + }, + _ = shutdown_signals.wait_for_signal() => { + log::info!("Received shutdown signal"); Ok(()) }, }; - log::info!("Sending shutdown"); - shutdown.signal_shutdown().ok(); + if !task_manager.is_cancelled() { + log::info!("Sending shutdown"); + task_manager.send_cancellation(); + } log::info!("Waiting for tasks to finish... 
(Press ctrl-c to force)"); - shutdown.wait_for_shutdown().await; + task_manager.perform_shutdown().await; log::info!("Stopping nym-socks5-client"); res @@ -238,6 +225,7 @@ where let mut base_builder = BaseClientBuilder::new(self.config.base(), self.storage, dkg_query_client) + .with_shutdown(self.shutdown_manager.shutdown_tracker_owned()) .with_gateway_setup(self.setup_method) .with_user_agent(self.user_agent); @@ -261,7 +249,7 @@ where client_output, client_state, self_address, - started_client.task_handle.get_handle(), + self.shutdown_manager.shutdown_tracker_owned(), packet_type, ); @@ -269,7 +257,7 @@ where info!("The address of this client is: {self_address}"); Ok(StartedSocks5Client { - shutdown_handle: started_client.task_handle, + shutdown_handle: self.shutdown_manager, address: self_address, }) } diff --git a/common/socks5-client-core/src/socks/client.rs b/common/socks5-client-core/src/socks/client.rs index 5d79d663805..ffdca82a70f 100644 --- a/common/socks5-client-core/src/socks/client.rs +++ b/common/socks5-client-core/src/socks/client.rs @@ -21,7 +21,7 @@ use nym_sphinx::addressing::clients::Recipient; use nym_sphinx::params::PacketSize; use nym_sphinx::params::PacketType; use nym_task::connections::{LaneQueueLengths, TransmissionLane}; -use nym_task::TaskClient; +use nym_task::ShutdownTracker; use pin_project::pin_project; use rand::RngCore; use std::io; @@ -185,7 +185,7 @@ pub(crate) struct SocksClient { self_address: Recipient, started_proxy: bool, lane_queue_lengths: LaneQueueLengths, - shutdown_listener: TaskClient, + shutdown_listener: ShutdownTracker, packet_type: Option, } @@ -214,12 +214,9 @@ impl SocksClient { controller_sender: ControllerSender, self_address: &Recipient, lane_queue_lengths: LaneQueueLengths, - mut shutdown_listener: TaskClient, + shutdown_listener: ShutdownTracker, packet_type: Option, ) -> Self { - // If this task fails and exits, we don't want to send shutdown signal - shutdown_listener.disarm(); - let connection_id = Self::generate_random(); SocksClient { @@ -294,7 +291,6 @@ impl SocksClient { .shutdown() .await .map_err(|source| SocksProxyError::SocketShutdownFailure { source })?; - self.shutdown_listener.disarm(); Ok(()) } diff --git a/common/socks5-client-core/src/socks/mixnet_responses.rs b/common/socks5-client-core/src/socks/mixnet_responses.rs index f74681c8def..0a9e5e154d6 100644 --- a/common/socks5-client-core/src/socks/mixnet_responses.rs +++ b/common/socks5-client-core/src/socks/mixnet_responses.rs @@ -13,13 +13,13 @@ use nym_service_providers_common::interface::{ControlResponse, ResponseContent}; use nym_socks5_proxy_helpers::connection_controller::{ControllerCommand, ControllerSender}; use nym_socks5_requests::{Socks5ProviderResponse, Socks5Response, Socks5ResponseContent}; use nym_sphinx::receiver::ReconstructedMessage; -use nym_task::TaskClient; +use nym_task::ShutdownToken; pub(crate) struct MixnetResponseListener { buffer_requester: ReceivedBufferRequestSender, mix_response_receiver: ReconstructedMessagesReceiver, controller_sender: ControllerSender, - shutdown: TaskClient, + shutdown: ShutdownToken, } impl Drop for MixnetResponseListener { @@ -28,7 +28,7 @@ impl Drop for MixnetResponseListener { .buffer_requester .unbounded_send(ReceivedBufferMessage::ReceiverDisconnect) { - if self.shutdown.is_shutdown_poll() { + if self.shutdown.is_cancelled() { log::debug!("The buffer request failed: {err}"); } else { log::error!("The buffer request failed: {err}"); @@ -41,7 +41,7 @@ impl MixnetResponseListener { pub(crate) fn new( 
buffer_requester: ReceivedBufferRequestSender, controller_sender: ControllerSender, - shutdown: TaskClient, + shutdown: ShutdownToken, ) -> Self { let (mix_response_sender, mix_response_receiver) = mpsc::unbounded(); buffer_requester @@ -130,13 +130,18 @@ impl MixnetResponseListener { } pub(crate) async fn run(&mut self) { - while !self.shutdown.is_shutdown() { + loop { tokio::select! { + biased; + _ = self.shutdown.cancelled() => { + log::trace!("MixnetResponseListener: Received shutdown"); + break; + } received_responses = self.mix_response_receiver.next() => { if let Some(received_responses) = received_responses { for reconstructed_message in received_responses { if let Err(err) = self.on_message(reconstructed_message) { - self.shutdown.send_status_msg(Box::new(err)); + debug!("message handling error: {err}") } } } else { @@ -144,12 +149,8 @@ impl MixnetResponseListener { break; } }, - _ = self.shutdown.recv() => { - log::trace!("MixnetResponseListener: Received shutdown"); - } } } - self.shutdown.recv_timeout().await; log::debug!("MixnetResponseListener: Exiting"); } } diff --git a/common/socks5-client-core/src/socks/server.rs b/common/socks5-client-core/src/socks/server.rs index 1a3fc22f3f7..278435a9df5 100644 --- a/common/socks5-client-core/src/socks/server.rs +++ b/common/socks5-client-core/src/socks/server.rs @@ -12,7 +12,7 @@ use nym_socks5_proxy_helpers::connection_controller::Controller; use nym_sphinx::addressing::clients::Recipient; use nym_sphinx::params::PacketType; use nym_task::connections::{ConnectionCommandSender, LaneQueueLengths}; -use nym_task::TaskClient; +use nym_task::ShutdownTracker; use std::net::SocketAddr; use tap::TapFallible; use tokio::net::TcpListener; @@ -25,7 +25,7 @@ pub struct NymSocksServer { self_address: Recipient, client_config: client::Config, lane_queue_lengths: LaneQueueLengths, - shutdown: TaskClient, + shutdown: ShutdownTracker, packet_type: PacketType, } @@ -39,7 +39,7 @@ impl NymSocksServer { self_address: Recipient, lane_queue_lengths: LaneQueueLengths, client_config: client::Config, - shutdown: TaskClient, + shutdown: ShutdownTracker, packet_type: PacketType, ) -> Self { info!("Listening on {bind_address}"); @@ -72,7 +72,7 @@ impl NymSocksServer { let (mut active_streams_controller, controller_sender) = Controller::new( client_connection_tx, //BroadcastActiveConnections::Off, - self.shutdown.clone(), + self.shutdown.clone_shutdown_token(), ); tokio::spawn(async move { active_streams_controller.run().await; @@ -82,20 +82,30 @@ impl NymSocksServer { let mut mixnet_response_listener = MixnetResponseListener::new( buffer_requester, controller_sender.clone(), - self.shutdown.clone(), + self.shutdown.clone_shutdown_token(), + ); + self.shutdown.try_spawn_named( + async move { + mixnet_response_listener.run().await; + }, + "Socks5MixnetListener", ); - tokio::spawn(async move { - mixnet_response_listener.run().await; - }); // TODO:, if required, there should be another task here responsible for control requests. // it should get `input_sender` to send actual requests into the mixnet // and some channel that connects it from `MixnetResponseListener` to receive // any control responses + let shutdown = self.shutdown.clone_shutdown_token(); loop { tokio::select! 
{ - Ok((stream, _remote)) = listener.accept() => { + biased; + _ = shutdown.cancelled() => { + log::trace!("NymSocksServer: Received shutdown"); + log::debug!("NymSocksServer: Exiting"); + return Ok(()); + } + Ok((stream, remote)) = listener.accept() => { let mut client = SocksClient::new( self.client_config, stream, @@ -109,23 +119,20 @@ impl NymSocksServer { Some(self.packet_type) ); - tokio::spawn(async move { - if let Err(err) = client.run().await { - error!("Error! {err}"); - if client.send_error(err).await.is_err() { - warn!("Failed to send error code"); + self.shutdown.try_spawn_named( + async move { + if let Err(err) = client.run().await { + error!("Error! {err}"); + if client.send_error(err).await.is_err() { + warn!("Failed to send error code"); + }; + if client.shutdown().await.is_err() { + warn!("Failed to shutdown TcpStream"); + }; }; - if client.shutdown().await.is_err() { - warn!("Failed to shutdown TcpStream"); - }; - } - }); + }, &format!("Socks5Client::{remote}") + ); }, - _ = self.shutdown.recv() => { - log::trace!("NymSocksServer: Received shutdown"); - log::debug!("NymSocksServer: Exiting"); - return Ok(()); - } } } } diff --git a/common/socks5/proxy-helpers/src/connection_controller.rs b/common/socks5/proxy-helpers/src/connection_controller.rs index 5c2e2f22c19..f8b7ca783b3 100644 --- a/common/socks5/proxy-helpers/src/connection_controller.rs +++ b/common/socks5/proxy-helpers/src/connection_controller.rs @@ -7,7 +7,7 @@ use log::*; use nym_ordered_buffer::{OrderedMessageBuffer, ReadContiguousData}; use nym_socks5_requests::{ConnectionId, SocketData}; use nym_task::connections::{ConnectionCommand, ConnectionCommandSender}; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use std::collections::{HashMap, HashSet}; /// A generic message produced after reading from a socket/connection. @@ -101,13 +101,13 @@ pub struct Controller { // un-order messages. 
Note we don't ever expect to have more than 1-2 messages per connection here pending_messages: HashMap>, - shutdown: TaskClient, + shutdown: ShutdownToken, } impl Controller { pub fn new( client_connection_tx: ConnectionCommandSender, - shutdown: TaskClient, + shutdown: ShutdownToken, ) -> (Self, ControllerSender) { let (sender, receiver) = mpsc::unbounded(); ( @@ -155,7 +155,7 @@ impl Controller { .client_connection_tx .unbounded_send(ConnectionCommand::Close(conn_id)) { - if self.shutdown.is_shutdown_poll() { + if self.shutdown.is_cancelled() { log::debug!("Failed to send: {err}"); } else { log::error!("Failed to send: {err}"); @@ -230,7 +230,6 @@ impl Controller { }, } } - self.shutdown.recv_timeout().await; log::debug!("SOCKS5 Controller: Exiting"); } } diff --git a/common/socks5/proxy-helpers/src/proxy_runner/inbound.rs b/common/socks5/proxy-helpers/src/proxy_runner/inbound.rs index 6cdf0efa9cd..2e3087bbc02 100644 --- a/common/socks5/proxy-helpers/src/proxy_runner/inbound.rs +++ b/common/socks5/proxy-helpers/src/proxy_runner/inbound.rs @@ -11,7 +11,7 @@ use log::*; use nym_socks5_requests::{ConnectionId, SocketData}; use nym_task::connections::LaneQueueLengths; use nym_task::connections::TransmissionLane; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use std::sync::Arc; use std::time::Duration; use tokio::select; @@ -81,7 +81,7 @@ pub(super) async fn run_inbound( available_plaintext_per_mix_packet: usize, shutdown_notify: Arc, lane_queue_lengths: Option, - mut shutdown_listener: TaskClient, + shutdown_listener: ShutdownToken, ) -> OwnedReadHalf where F: Fn(SocketData) -> S + Send + 'static, @@ -129,7 +129,7 @@ where message_sender.send_empty_close().await; break; } - _ = shutdown_listener.recv() => { + _ = shutdown_listener.cancelled() => { log::trace!("ProxyRunner inbound: Received shutdown"); break; } @@ -171,6 +171,5 @@ where trace!("{connection_id} - inbound closed"); shutdown_notify.notify_one(); - shutdown_listener.disarm(); reader } diff --git a/common/socks5/proxy-helpers/src/proxy_runner/mod.rs b/common/socks5/proxy-helpers/src/proxy_runner/mod.rs index 970378d811c..75a51e366a5 100644 --- a/common/socks5/proxy-helpers/src/proxy_runner/mod.rs +++ b/common/socks5/proxy-helpers/src/proxy_runner/mod.rs @@ -5,7 +5,7 @@ use crate::connection_controller::ConnectionReceiver; use crate::ordered_sender::OrderedMessageSender; use nym_socks5_requests::{ConnectionId, SocketData}; use nym_task::connections::LaneQueueLengths; -use nym_task::TaskClient; +use nym_task::ShutdownTracker; use std::fmt::Debug; use std::{sync::Arc, time::Duration}; use tokio::{net::TcpStream, sync::Notify}; @@ -57,7 +57,8 @@ pub struct ProxyRunner { available_plaintext_per_mix_packet: usize, // Listens to shutdown commands from higher up - shutdown_listener: TaskClient, + // and spawn new tracked tasks + shutdown_tracker: ShutdownTracker, } impl ProxyRunner @@ -74,7 +75,7 @@ where available_plaintext_per_mix_packet: usize, connection_id: ConnectionId, lane_queue_lengths: Option, - shutdown_listener: TaskClient, + shutdown_tracker: ShutdownTracker, ) -> Self { ProxyRunner { mix_receiver: Some(mix_receiver), @@ -85,7 +86,7 @@ where connection_id, lane_queue_lengths, available_plaintext_per_mix_packet, - shutdown_listener, + shutdown_tracker, } } @@ -113,7 +114,7 @@ where self.available_plaintext_per_mix_packet, Arc::clone(&shutdown_notify), self.lane_queue_lengths.clone(), - self.shutdown_listener.clone(), + self.shutdown_tracker.clone_shutdown_token(), ); let outbound_future = outbound::run_outbound( @@ 
-123,14 +124,26 @@ where self.mix_receiver.take().unwrap(), self.connection_id, shutdown_notify, - self.shutdown_listener.clone(), + self.shutdown_tracker.clone_shutdown_token(), ); // TODO: this shouldn't really have to spawn tasks inside "library" code, but // if we used join directly, stuff would have been executed on the same thread // (it's not bad, but an unnecessary slowdown) - let handle_inbound = tokio::spawn(inbound_future); - let handle_outbound = tokio::spawn(outbound_future); + let handle_inbound = self.shutdown_tracker.try_spawn_named( + inbound_future, + &format!( + "Socks5Inbound::{}::{}", + self.remote_source_address, self.connection_id + ), + ); + let handle_outbound = self.shutdown_tracker.try_spawn_named( + outbound_future, + &format!( + "Socks5Outbound::{}::{}", + self.remote_source_address, self.connection_id + ), + ); let (inbound_result, outbound_result) = futures::future::join(handle_inbound, handle_outbound).await; @@ -148,7 +161,6 @@ where } pub fn into_inner(mut self) -> (TcpStream, ConnectionReceiver) { - self.shutdown_listener.disarm(); ( self.socket.take().unwrap(), self.mix_receiver.take().unwrap(), diff --git a/common/socks5/proxy-helpers/src/proxy_runner/outbound.rs b/common/socks5/proxy-helpers/src/proxy_runner/outbound.rs index dddae8d894f..dda46c54426 100644 --- a/common/socks5/proxy-helpers/src/proxy_runner/outbound.rs +++ b/common/socks5/proxy-helpers/src/proxy_runner/outbound.rs @@ -7,7 +7,7 @@ use futures::FutureExt; use futures::StreamExt; use log::*; use nym_socks5_requests::ConnectionId; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use std::{sync::Arc, time::Duration}; use tokio::io::AsyncWriteExt; use tokio::select; @@ -51,7 +51,7 @@ pub(super) async fn run_outbound( mut mix_receiver: ConnectionReceiver, connection_id: ConnectionId, shutdown_notify: Arc, - mut shutdown_listener: TaskClient, + shutdown_listener: ShutdownToken, ) -> (OwnedWriteHalf, ConnectionReceiver) { let shutdown_future = shutdown_notify.notified().then(|_| sleep(SHUTDOWN_TIMEOUT)); tokio::pin!(shutdown_future); @@ -60,6 +60,11 @@ pub(super) async fn run_outbound( loop { select! 
{ + biased; + _ = shutdown_listener.cancelled() => { + log::trace!("ProxyRunner outbound: Received shutdown"); + break; + } connection_message = mix_receiver.next() => { if let Some(connection_message) = connection_message { if deal_with_message(connection_message, &mut writer, &local_destination_address, &remote_source_address, connection_id).await { @@ -80,16 +85,11 @@ pub(super) async fn run_outbound( debug!("closing outbound proxy after inbound was closed {SHUTDOWN_TIMEOUT:?} ago"); break; } - _ = shutdown_listener.recv() => { - log::trace!("ProxyRunner outbound: Received shutdown"); - break; - } } } trace!("{connection_id} - outbound closed"); shutdown_notify.notify_one(); - shutdown_listener.disarm(); (writer, mix_receiver) } diff --git a/common/statistics/src/clients/mod.rs b/common/statistics/src/clients/mod.rs index 2ca1fa006de..3c2badcdfe3 100644 --- a/common/statistics/src/clients/mod.rs +++ b/common/statistics/src/clients/mod.rs @@ -3,7 +3,7 @@ use crate::report::client::{ClientStatsReport, OsInformation}; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use time::{OffsetDateTime, Time}; use tokio::sync::mpsc::UnboundedSender; @@ -25,18 +25,18 @@ pub type ClientStatsReceiver = tokio::sync::mpsc::UnboundedReceiver>, - task_client: TaskClient, + shutdown_token: ShutdownToken, } impl ClientStatsSender { /// Create a new statistics Sender pub fn new( stats_tx: Option>, - task_client: TaskClient, + shutdown_token: ShutdownToken, ) -> Self { ClientStatsSender { stats_tx, - task_client, + shutdown_token, } } @@ -44,7 +44,7 @@ impl ClientStatsSender { pub fn report(&self, event: ClientStatsEvents) { if let Some(tx) = &self.stats_tx { if let Err(err) = tx.send(event) { - if !self.task_client.is_shutdown_poll() { + if !self.shutdown_token.is_cancelled() { log::error!("Failed to send stats event: {err}"); } } @@ -137,8 +137,8 @@ impl ClientStatsController { self.packet_stats.snapshot(); } - pub fn local_report(&mut self, task_client: &mut TaskClient) { - self.packet_stats.local_report(task_client); + pub fn local_report(&mut self) { + self.packet_stats.local_report(); self.gateway_conn_stats.local_report(); self.nym_api_stats.local_report(); } diff --git a/common/statistics/src/clients/packet_statistics.rs b/common/statistics/src/clients/packet_statistics.rs index b2c6f56d5c5..81de7e35f60 100644 --- a/common/statistics/src/clients/packet_statistics.rs +++ b/common/statistics/src/clients/packet_statistics.rs @@ -449,15 +449,16 @@ impl PacketStatisticsControl { self.stats.clone() } - pub(crate) fn local_report(&mut self, task_client: &mut nym_task::TaskClient) { - let rates = self.report_rates(); + pub(crate) fn local_report(&mut self) { + let _rates = self.report_rates(); self.check_for_notable_events(); self.report_counters(); - // Report our current bandwidth used to e.g a GUI client - if let Some(rates) = rates { - task_client.send_status_msg(Box::new(MixnetBandwidthStatisticsEvent::new(rates))); - } + // leave the code commented in case somebody wanted to restore this logic with a different channel + // // Report our current bandwidth used to e.g a GUI client + // if let Some(rates) = rates { + // task_client.send_status_msg(Box::new(MixnetBandwidthStatisticsEvent::new(rates))); + // } } // Add the current stats to the history, and remove old ones. 
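
The hunks above all apply the same migration: loops that used to block on `TaskClient::recv()` now select on `ShutdownToken::cancelled()` (with `biased;` so an already-cancelled token wins immediately), the `disarm()` calls disappear, and bare `tokio::spawn` calls become named spawns on a `ShutdownTracker`. The following is a minimal, self-contained sketch of that pattern, not part of the patch; `run_worker`, `do_work`, and the task name are invented for illustration, and it assumes the `nym-task` and `tokio` APIs shown elsewhere in this diff.

```rust
use std::time::Duration;

use nym_task::{ShutdownToken, ShutdownTracker};
use tokio::time::sleep;

// Stand-in for the real inbound/outbound processing step.
async fn do_work() {
    sleep(Duration::from_millis(100)).await;
}

async fn run_worker(shutdown_token: ShutdownToken) {
    loop {
        tokio::select! {
            biased;
            // Replaces the old `shutdown_listener.recv()` branch.
            _ = shutdown_token.cancelled() => break,
            _ = do_work() => {}
        }
    }
    // Unlike an armed TaskClient, dropping a ShutdownToken signals nothing,
    // so no `disarm()` is needed on the way out.
}

#[tokio::main]
async fn main() {
    // The tracker replaces bare `tokio::spawn` so the task is both named
    // and awaited during graceful shutdown.
    let tracker = ShutdownTracker::default();
    tracker.try_spawn_named(
        run_worker(tracker.child_shutdown_token()),
        "example-worker",
    );

    // Simulate shutdown: close the tracker, cancel the root token,
    // then wait for the tracked task to finish.
    tracker.close_tracker();
    tracker.clone_shutdown_token().cancel();
    tracker.wait_for_tracker().await;
}
```
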
diff --git a/common/task/Cargo.toml b/common/task/Cargo.toml index 6047373c535..d96fcbe0686 100644 --- a/common/task/Cargo.toml +++ b/common/task/Cargo.toml @@ -30,5 +30,13 @@ workspace = true workspace = true features = ["tokio"] +[features] +tokio-tracing = ["tokio/tracing"] + [dev-dependencies] +anyhow = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "net", "signal", "test-util", "macros"] } +nym-test-utils = { path = "../test-utils" } + +[lints] +workspace = true \ No newline at end of file diff --git a/common/task/src/cancellation.rs b/common/task/src/cancellation.rs deleted file mode 100644 index ba60f79308e..00000000000 --- a/common/task/src/cancellation.rs +++ /dev/null @@ -1,414 +0,0 @@ -// Copyright 2025 - Nym Technologies SA -// SPDX-License-Identifier: Apache-2.0 - -use crate::{TaskClient, TaskManager}; -use futures::stream::FuturesUnordered; -use futures::StreamExt; -use std::future::Future; -use std::mem; -use std::ops::Deref; -use std::pin::Pin; -use std::time::Duration; -use tokio::task::JoinSet; -use tokio::time::sleep; -use tokio_util::sync::{CancellationToken, DropGuard}; -use tokio_util::task::TaskTracker; -use tracing::{debug, info, trace}; - -#[cfg(unix)] -use tokio::signal::unix::{signal, SignalKind}; - -pub const DEFAULT_MAX_SHUTDOWN_DURATION: Duration = Duration::from_secs(5); - -pub fn token_name(name: &Option) -> String { - name.clone().unwrap_or_else(|| "unknown".to_string()) -} - -// a wrapper around tokio's CancellationToken that adds optional `name` information to more easily -// track down sources of shutdown -#[derive(Debug, Default)] -pub struct ShutdownToken { - name: Option, - inner: CancellationToken, -} - -impl Clone for ShutdownToken { - fn clone(&self) -> Self { - // make sure to not accidentally overflow the stack if we keep cloning the handle - let name = if let Some(name) = &self.name { - if name != Self::OVERFLOW_NAME && name.len() < Self::MAX_NAME_LENGTH { - Some(format!("{name}-child")) - } else { - Some(Self::OVERFLOW_NAME.to_string()) - } - } else { - None - }; - - ShutdownToken { - name, - inner: self.inner.clone(), - } - } -} - -impl Deref for ShutdownToken { - type Target = CancellationToken; - - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl ShutdownToken { - const MAX_NAME_LENGTH: usize = 128; - const OVERFLOW_NAME: &'static str = "reached maximum ShutdownToken children name depth"; - - pub fn new(name: impl Into) -> Self { - ShutdownToken { - name: Some(name.into()), - inner: CancellationToken::new(), - } - } - - pub fn ephemeral() -> Self { - ShutdownToken::new("ephemeral-token") - } - - // Creates a ShutdownToken which will get cancelled whenever the current token gets cancelled. - // Unlike a cloned/forked ShutdownToken, cancelling a child token does not cancel the parent token. - #[must_use] - pub fn child_token>(&self, child_suffix: S) -> Self { - let suffix = child_suffix.into(); - let child_name = if let Some(base) = &self.name { - format!("{base}-{suffix}") - } else { - format!("unknown-{suffix}") - }; - - ShutdownToken { - name: Some(child_name), - inner: self.inner.child_token(), - } - } - - // Creates a clone of the ShutdownToken which will get cancelled whenever the current token gets cancelled, and vice versa. 
- #[must_use] - pub fn clone_with_suffix>(&self, child_suffix: S) -> Self { - let mut child = self.clone(); - let suffix = child_suffix.into(); - let child_name = if let Some(base) = &self.name { - format!("{base}-{suffix}") - } else { - format!("unknown-{suffix}") - }; - - child.name = Some(child_name); - child - } - - // exposed method with the old name for easier migration - // it will eventually be removed so please try to use `.clone_with_suffix` instead - #[must_use] - #[deprecated(note = "use .clone_with_suffix instead")] - pub fn fork>(&self, child_suffix: S) -> Self { - self.clone_with_suffix(child_suffix) - } - - // exposed method with the old name for easier migration - // it will eventually be removed so please try to use `.clone().named(name)` instead - #[must_use] - #[deprecated(note = "use .clone().named(name) instead")] - pub fn fork_named>(&self, name: S) -> Self { - self.clone().named(name) - } - - #[must_use] - pub fn named>(mut self, name: S) -> Self { - self.name = Some(name.into()); - self - } - - #[must_use] - pub fn add_suffix>(self, suffix: S) -> Self { - let suffix = suffix.into(); - let name = if let Some(base) = &self.name { - format!("{base}-{suffix}") - } else { - format!("unknown-{suffix}") - }; - self.named(name) - } - - // Returned guard will cancel this token (and all its children) on drop unless disarmed. - pub fn drop_guard(self) -> ShutdownDropGuard { - ShutdownDropGuard { - name: self.name, - inner: self.inner.drop_guard(), - } - } - - pub fn name(&self) -> String { - token_name(&self.name) - } - - pub async fn run_until_cancelled(&self, fut: F) -> Option - where - F: Future, - { - let res = self.inner.run_until_cancelled(fut).await; - trace!("'{}' got cancelled", self.name()); - res - } -} - -pub struct ShutdownDropGuard { - name: Option, - inner: DropGuard, -} - -impl Deref for ShutdownDropGuard { - type Target = DropGuard; - - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl ShutdownDropGuard { - pub fn disarm(self) -> ShutdownToken { - ShutdownToken { - name: self.name, - inner: self.inner.disarm(), - } - } - - pub fn name(&self) -> String { - token_name(&self.name) - } -} - -#[derive(Default)] -pub struct ShutdownSignals(JoinSet<()>); - -impl ShutdownSignals { - pub async fn wait_for_signal(&mut self) { - self.0.join_next().await; - } -} - -pub struct ShutdownManager { - pub root_token: ShutdownToken, - - legacy_task_manager: Option, - - shutdown_signals: ShutdownSignals, - - // the reason I'm not using a `JoinSet` is because it forces us to use futures with the same `::Output` type - tracker: TaskTracker, - - max_shutdown_duration: Duration, -} - -impl Deref for ShutdownManager { - type Target = TaskTracker; - - fn deref(&self) -> &Self::Target { - &self.tracker - } -} - -impl ShutdownManager { - pub fn new(root_token_name: impl Into) -> Self { - let manager = ShutdownManager { - root_token: ShutdownToken::new(root_token_name), - legacy_task_manager: None, - shutdown_signals: Default::default(), - tracker: Default::default(), - max_shutdown_duration: Duration::from_secs(10), - }; - - // we need to add an explicit watcher for the cancellation token being cancelled - // so that we could cancel all legacy tasks - let cancel_watcher = manager.root_token.clone(); - manager.with_shutdown(async move { cancel_watcher.cancelled().await }) - } - - pub fn empty_mock() -> Self { - ShutdownManager { - root_token: ShutdownToken::ephemeral(), - legacy_task_manager: None, - shutdown_signals: Default::default(), - tracker: Default::default(), - 
max_shutdown_duration: Default::default(), - } - } - - pub fn with_legacy_task_manager(mut self) -> Self { - let mut legacy_manager = - TaskManager::default().named(format!("{}-legacy", self.root_token.name())); - let mut legacy_error_rx = legacy_manager.task_return_error_rx(); - let mut legacy_drop_rx = legacy_manager.task_drop_rx(); - - self.legacy_task_manager = Some(legacy_manager); - - // add a task that listens for legacy task clients being dropped to trigger cancellation - self.with_shutdown(async move { - tokio::select! { - _ = legacy_error_rx.recv() => (), - _ = legacy_drop_rx.recv() => (), - } - - info!("received legacy shutdown signal"); - }) - } - - #[cfg(not(target_arch = "wasm32"))] - pub fn with_default_shutdown_signals(self) -> std::io::Result { - cfg_if::cfg_if! { - if #[cfg(unix)] { - self.with_interrupt_signal() - .with_terminate_signal()? - .with_quit_signal() - } else { - Ok(self.with_interrupt_signal()) - } - } - } - - #[must_use] - #[track_caller] - pub fn with_shutdown(mut self, shutdown: F) -> Self - where - F: Future, - F: Send + 'static, - { - let shutdown_token = self.root_token.clone(); - self.shutdown_signals.0.spawn(async move { - shutdown.await; - - info!("sending cancellation after receiving shutdown signal"); - shutdown_token.cancel(); - }); - self - } - - #[cfg(unix)] - #[track_caller] - pub fn with_shutdown_signal(self, signal_kind: SignalKind) -> std::io::Result { - let mut sig = signal(signal_kind)?; - Ok(self.with_shutdown(async move { - sig.recv().await; - })) - } - - #[cfg(not(target_arch = "wasm32"))] - #[track_caller] - pub fn with_interrupt_signal(self) -> Self { - self.with_shutdown(async move { - let _ = tokio::signal::ctrl_c().await; - }) - } - - #[cfg(unix)] - #[track_caller] - pub fn with_terminate_signal(self) -> std::io::Result { - self.with_shutdown_signal(SignalKind::terminate()) - } - - #[cfg(unix)] - #[track_caller] - pub fn with_quit_signal(self) -> std::io::Result { - self.with_shutdown_signal(SignalKind::quit()) - } - - #[must_use] - pub fn with_shutdown_duration(mut self, duration: Duration) -> Self { - self.max_shutdown_duration = duration; - self - } - - pub fn child_token>(&self, child_suffix: S) -> ShutdownToken { - self.root_token.child_token(child_suffix) - } - - pub fn clone_token>(&self, child_suffix: S) -> ShutdownToken { - self.root_token.clone_with_suffix(child_suffix) - } - - #[must_use] - pub fn subscribe_legacy>(&self, child_suffix: S) -> TaskClient { - // alternatively we could have set self.legacy_task_manager = Some(TaskManager::default()); - // on demand if it wasn't unavailable, but then we'd have to use mutable reference - #[allow(clippy::expect_used)] - self.legacy_task_manager - .as_ref() - .expect("did not enable legacy shutdown support") - .subscribe_named(child_suffix) - } - - async fn finish_shutdown(mut self) { - let mut wait_futures = FuturesUnordered::>>>::new(); - - // force shutdown via ctrl-c - wait_futures.push(Box::pin(async move { - #[cfg(not(target_arch = "wasm32"))] - let interrupt_future = tokio::signal::ctrl_c(); - - #[cfg(target_arch = "wasm32")] - let interrupt_future = futures::future::pending::<()>(); - - let _ = interrupt_future.await; - info!("received interrupt - forcing shutdown"); - })); - - // timeout - wait_futures.push(Box::pin(async move { - sleep(self.max_shutdown_duration).await; - info!("timeout reached, forcing shutdown"); - })); - - // graceful - wait_futures.push(Box::pin(async move { - self.tracker.wait().await; - debug!("migrated tasks successfully shutdown"); - if let 
Some(legacy) = self.legacy_task_manager.as_mut() { - legacy.wait_for_graceful_shutdown().await; - debug!("legacy tasks successfully shutdown"); - } - - info!("all registered tasks successfully shutdown") - })); - - wait_futures.next().await; - } - - pub fn detach_shutdown_signals(&mut self) -> ShutdownSignals { - mem::take(&mut self.shutdown_signals) - } - - pub fn replace_shutdown_signals(&mut self, signals: ShutdownSignals) { - self.shutdown_signals = signals; - } - - // cancellation safe - pub async fn wait_for_shutdown_signal(&mut self) { - self.shutdown_signals.0.join_next().await; - } - - pub async fn perform_shutdown(mut self) { - if let Some(legacy_manager) = self.legacy_task_manager.as_mut() { - info!("attempting to shutdown legacy tasks"); - let _ = legacy_manager.signal_shutdown(); - } - - info!("waiting for tasks to finish... (press ctrl-c to force)"); - self.finish_shutdown().await; - } - - pub async fn run_until_shutdown(mut self) { - self.wait_for_shutdown_signal().await; - - self.perform_shutdown().await; - } -} diff --git a/common/task/src/cancellation/manager.rs b/common/task/src/cancellation/manager.rs new file mode 100644 index 00000000000..f2cf5a9d028 --- /dev/null +++ b/common/task/src/cancellation/manager.rs @@ -0,0 +1,744 @@ +// Copyright 2025 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + +use crate::cancellation::tracker::{Cancelled, ShutdownTracker}; +use crate::spawn::JoinHandle; +use crate::ShutdownToken; +use futures::stream::FuturesUnordered; +use futures::StreamExt; +use log::error; +use std::future::Future; +use std::mem; +use std::pin::Pin; +use std::time::Duration; +use tracing::info; + +#[cfg(not(target_arch = "wasm32"))] +use tokio::time::sleep; + +#[cfg(target_arch = "wasm32")] +use wasmtimer::tokio::sleep; + +#[cfg(unix)] +use tokio::signal::unix::{signal, SignalKind}; +use tokio::task::JoinSet; + +/// A top level structure responsible for controlling process shutdown by listening to +/// the underlying registered signals and issuing cancellation to tasks derived from its root cancellation token. +#[allow(deprecated)] +pub struct ShutdownManager { + /// Optional reference to the legacy [TaskManager](crate::TaskManager) to allow easier + /// transition to the new system. + pub(crate) legacy_task_manager: Option, + + /// Registered [ShutdownSignals](ShutdownSignals) that will trigger process shutdown if detected. + pub(crate) shutdown_signals: ShutdownSignals, + + /// Combined [TaskTracker](tokio_util::task::TaskTracker) and [ShutdownToken](ShutdownToken) + /// for spawning and tracking tasks associated with this ShutdownManager. + pub(crate) tracker: ShutdownTracker, + + /// The maximum shutdown duration when tracked tasks could gracefully exit + /// before forcing the shutdown. + pub(crate) max_shutdown_duration: Duration, +} + +/// Wrapper behind futures that upon completion will trigger binary shutdown. 
+#[derive(Default)] +pub struct ShutdownSignals(JoinSet<()>); + +impl ShutdownSignals { + /// Wait for any of the registered signals to be ready + pub async fn wait_for_signal(&mut self) { + self.0.join_next().await; + } +} + +// note: default implementation will ONLY listen for SIGINT and will ignore SIGTERM and SIGQUIT +// this is due to result type when registering the signal +#[cfg(not(target_arch = "wasm32"))] +impl Default for ShutdownManager { + fn default() -> Self { + ShutdownManager::new_without_signals() + .with_interrupt_signal() + .with_cancel_on_panic() + } +} + +#[cfg(not(target_arch = "wasm32"))] +impl ShutdownManager { + /// Create new instance of ShutdownManager with the most sensible defaults, so that: + /// - shutdown will be triggered upon either SIGINT, SIGTERM (unix only) or SIGQUIT (unix only) being sent + /// - shutdown will be triggered upon any task panicking + pub fn build_new_default() -> std::io::Result { + Ok(ShutdownManager::new_without_signals() + .with_default_shutdown_signals()? + .with_cancel_on_panic()) + } + + /// Register a new shutdown signal that upon completion will trigger system shutdown. + #[must_use] + #[track_caller] + pub fn with_shutdown(mut self, shutdown: F) -> Self + where + F: Future, + F: Send + 'static, + { + let shutdown_token = self.tracker.clone_shutdown_token(); + self.shutdown_signals.0.spawn(async move { + shutdown.await; + + info!("sending cancellation after receiving shutdown signal"); + shutdown_token.cancel(); + }); + self + } + + /// Include support for the legacy [TaskManager](TaskManager) to this instance of the ShutdownManager. + /// This will allow issuing [TaskClient](TaskClient) for tasks that still require them. + #[allow(deprecated)] + pub fn with_legacy_task_manager(mut self) -> Self { + let mut legacy_manager = crate::TaskManager::default().named("legacy-task-manager"); + let mut legacy_error_rx = legacy_manager.task_return_error_rx(); + let mut legacy_drop_rx = legacy_manager.task_drop_rx(); + + self.legacy_task_manager = Some(legacy_manager); + + // add a task that listens for legacy task clients being dropped to trigger cancellation + self.with_shutdown(async move { + tokio::select! { + _ = legacy_error_rx.recv() => (), + _ = legacy_drop_rx.recv() => (), + } + + info!("received legacy shutdown signal"); + }) + } + + /// Add the specified signal to the currently registered shutdown signals that will trigger + /// cancellation of all registered tasks. + #[cfg(unix)] + #[track_caller] + pub fn with_shutdown_signal(self, signal_kind: SignalKind) -> std::io::Result { + let mut sig = signal(signal_kind)?; + Ok(self.with_shutdown(async move { + sig.recv().await; + })) + } + + /// Add the SIGTERM signal to the currently registered shutdown signals that will trigger + /// cancellation of all registered tasks. + #[cfg(unix)] + #[track_caller] + pub fn with_terminate_signal(self) -> std::io::Result { + self.with_shutdown_signal(SignalKind::terminate()) + } + + /// Add the SIGQUIT signal to the currently registered shutdown signals that will trigger + /// cancellation of all registered tasks. + #[cfg(unix)] + #[track_caller] + pub fn with_quit_signal(self) -> std::io::Result { + self.with_shutdown_signal(SignalKind::quit()) + } + + /// Add default signals to the set of the currently registered shutdown signals that will trigger + /// cancellation of all registered tasks. 
+ /// This includes SIGINT, SIGTERM and SIGQUIT for unix-based platforms and SIGINT for other targets (such as windows)/ + pub fn with_default_shutdown_signals(self) -> std::io::Result { + cfg_if::cfg_if! { + if #[cfg(unix)] { + self.with_interrupt_signal() + .with_terminate_signal()? + .with_quit_signal() + } else { + Ok(self.with_interrupt_signal()) + } + } + } + + /// Add the SIGINT (ctrl-c) signal to the currently registered shutdown signals that will trigger + /// cancellation of all registered tasks. + #[track_caller] + pub fn with_interrupt_signal(self) -> Self { + self.with_shutdown(async move { + let _ = tokio::signal::ctrl_c().await; + }) + } + + /// Spawn the provided future on the current Tokio runtime, and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn(&self, task: F) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.tracker.spawn(task) + } + + /// Spawn the provided future on the current Tokio runtime, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + /// Furthermore, attach a name to the spawned task to more easily track it within a [tokio console](https://github.com/tokio-rs/console) + /// + /// Note that is no different from [spawn](Self::spawn) if the underlying binary + /// has not been built with `RUSTFLAGS="--cfg tokio_unstable"` and `--features="tokio-tracing"` + #[track_caller] + pub fn try_spawn_named(&self, task: F, name: &str) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.tracker.try_spawn_named(task, name) + } + + /// Spawn the provided future on the provided Tokio runtime, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn_on(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.tracker.spawn_on(task, handle) + } + + /// Spawn the provided future on the current [LocalSet](tokio::task::LocalSet), + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn_local(&self, task: F) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + self.tracker.spawn_local(task) + } + + /// Spawn the provided blocking task on the current Tokio runtime, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn_blocking(&self, task: F) -> JoinHandle + where + F: FnOnce() -> T, + F: Send + 'static, + T: Send + 'static, + { + self.tracker.spawn_blocking(task) + } + + /// Spawn the provided blocking task on the provided Tokio runtime, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn_blocking_on(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle + where + F: FnOnce() -> T, + F: Send + 'static, + T: Send + 'static, + { + self.tracker.spawn_blocking_on(task, handle) + } + + /// Spawn the provided future on the current Tokio runtime + /// that will get cancelled once a global shutdown signal is detected, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). 
+ /// + /// Note that to fully use the naming feature, such as tracking within a [tokio console](https://github.com/tokio-rs/console), + /// the underlying binary has to be built with `RUSTFLAGS="--cfg tokio_unstable"` and `--features="tokio-tracing"` + #[track_caller] + pub fn try_spawn_named_with_shutdown( + &self, + task: F, + name: &str, + ) -> JoinHandle> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.tracker.try_spawn_named_with_shutdown(task, name) + } + + /// Spawn the provided future on the current Tokio runtime + /// that will get cancelled once a global shutdown signal is detected, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn_with_shutdown(&self, task: F) -> JoinHandle> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.tracker.spawn_with_shutdown(task) + } +} + +#[cfg(target_arch = "wasm32")] +impl ShutdownManager { + /// Run the provided future on the current thread, and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn(&self, task: F) -> JoinHandle + where + F: Future + 'static, + { + self.tracker.spawn(task) + } + + /// Run the provided future on the current thread, and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + /// It has exactly the same behaviour as [spawn](Self::spawn) and it only exists to provide + /// the same interface as non-wasm32 targets. + #[track_caller] + pub fn try_spawn_named(&self, task: F, name: &str) -> JoinHandle + where + F: Future + 'static, + { + self.tracker.try_spawn_named(task, name) + } + + /// Run the provided future on the current thread + /// that will get cancelled once a global shutdown signal is detected, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + /// It has exactly the same behaviour as [spawn_with_shutdown](Self::spawn_with_shutdown) and it only exists to provide + /// the same interface as non-wasm32 targets. + #[track_caller] + pub fn try_spawn_named_with_shutdown( + &self, + task: F, + name: &str, + ) -> JoinHandle> + where + F: Future + Send + 'static, + { + self.tracker.try_spawn_named_with_shutdown(task, name) + } + + /// Run the provided future on the current thread + /// that will get cancelled once a global shutdown signal is detected, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn_with_shutdown(&self, task: F) -> JoinHandle> + where + F: Future + Send + 'static, + { + self.tracker.spawn_with_shutdown(task) + } +} + +impl ShutdownManager { + /// Create new instance of ShutdownManager without any external shutdown signals registered, + /// meaning it will only attempt to wait for all tasks spawned on its tracker to gracefully finish execution. + pub fn new_without_signals() -> Self { + Self::new_from_external_shutdown_token(ShutdownToken::new()) + } + + /// Create new instance of the ShutdownManager using an external shutdown token. + /// + /// Note: it will not listen to any external shutdown signals! + /// You might want further customise it with [shutdown signals](Self::with_shutdown) + /// (or just use [the default set](Self::with_default_shutdown_signals). + /// Similarly, you might want to include [cancellation on panic](Self::with_cancel_on_panic) + /// to make sure everything gets cancelled if one of the tasks panics. 
+ pub fn new_from_external_shutdown_token(shutdown_token: ShutdownToken) -> Self { + let manager = ShutdownManager { + legacy_task_manager: None, + shutdown_signals: Default::default(), + tracker: ShutdownTracker::new_from_external_shutdown_token(shutdown_token), + max_shutdown_duration: Duration::from_secs(10), + }; + + // we need to add an explicit watcher for the cancellation token being cancelled + // so that we could cancel all legacy tasks + cfg_if::cfg_if! {if #[cfg(not(target_arch = "wasm32"))] { + let cancel_watcher = manager.tracker.clone_shutdown_token(); + manager.with_shutdown(async move { cancel_watcher.cancelled().await }) + } else { + manager + }} + } + + /// Create an empty testing mock of the ShutdownManager with no signals registered. + pub fn empty_mock() -> Self { + ShutdownManager { + legacy_task_manager: None, + shutdown_signals: Default::default(), + tracker: Default::default(), + max_shutdown_duration: Default::default(), + } + } + + /// Add additional panic hook such that upon triggering, the root [ShutdownToken](ShutdownToken) gets cancelled. + /// Note: an unfortunate limitation of this is that graceful shutdown will no longer be possible + /// since that task that has panicked will not exit and thus all shutdowns will have to be either forced + /// or will have to time out. + #[must_use] + pub fn with_cancel_on_panic(self) -> Self { + let current_hook = std::panic::take_hook(); + + let shutdown_token = self.clone_shutdown_token(); + std::panic::set_hook(Box::new(move |panic_info| { + // 1. call existing hook + current_hook(panic_info); + + let location = panic_info + .location() + .map(|l| l.to_string()) + .unwrap_or_else(|| "".to_string()); + + let payload = if let Some(payload) = panic_info.payload().downcast_ref::<&str>() { + payload + } else { + "" + }; + + // 2. issue cancellation + error!("panicked at {location}: {payload}. issuing global cancellation"); + shutdown_token.cancel(); + })); + self + } + + /// Change the maximum shutdown duration when tracked tasks could gracefully exit + /// before forcing the shutdown. + #[must_use] + pub fn with_shutdown_duration(mut self, duration: Duration) -> Self { + self.max_shutdown_duration = duration; + self + } + + /// Returns true if the root [ShutdownToken](ShutdownToken) has been cancelled. + pub fn is_cancelled(&self) -> bool { + self.tracker.root_cancellation_token.is_cancelled() + } + + /// Get a reference to the used [ShutdownTracker](ShutdownTracker) + pub fn shutdown_tracker(&self) -> &ShutdownTracker { + &self.tracker + } + + /// Get a cloned instance of the used [ShutdownTracker](ShutdownTracker) + pub fn shutdown_tracker_owned(&self) -> ShutdownTracker { + self.tracker.clone() + } + + /// Waits until the underlying [TaskTracker](tokio_util::task::TaskTracker) is both closed and empty. + /// + /// If the underlying [TaskTracker](tokio_util::task::TaskTracker) is already closed and empty when this method is called, then it + /// returns immediately. + pub async fn wait_for_tracker(&self) { + self.tracker.wait_for_tracker().await; + } + + /// Close the underlying [TaskTracker](tokio_util::task::TaskTracker). + /// + /// This allows [`wait_for_tracker`] futures to complete. It does not prevent you from spawning new tasks. + /// + /// Returns `true` if this closed the underlying [TaskTracker](tokio_util::task::TaskTracker), or `false` if it was already closed. 
+ /// + /// [`wait_for_tracker`]: ShutdownTracker::wait_for_tracker + pub fn close_tracker(&self) -> bool { + self.tracker.close_tracker() + } + + /// Reopen the underlying [TaskTracker](tokio_util::task::TaskTracker). + /// + /// This prevents [`wait_for_tracker`] futures from completing even if the underlying [TaskTracker](tokio_util::task::TaskTracker) is empty. + /// + /// Returns `true` if this reopened the underlying [TaskTracker](tokio_util::task::TaskTracker), or `false` if it was already open. + /// + /// [`wait_for_tracker`]: ShutdownTracker::wait_for_tracker + pub fn reopen_tracker(&self) -> bool { + self.tracker.reopen_tracker() + } + + /// Returns `true` if the underlying [TaskTracker](tokio_util::task::TaskTracker) is [closed](Self::close_tracker). + pub fn is_tracker_closed(&self) -> bool { + self.tracker.is_tracker_closed() + } + + /// Returns the number of tasks tracked by the underlying [TaskTracker](tokio_util::task::TaskTracker). + pub fn tracked_tasks(&self) -> usize { + self.tracker.tracked_tasks() + } + + /// Returns `true` if there are no tasks in the underlying [TaskTracker](tokio_util::task::TaskTracker). + pub fn is_tracker_empty(&self) -> bool { + self.tracker.is_tracker_empty() + } + + /// Obtain a [ShutdownToken](crate::cancellation::ShutdownToken) that is a child of the root token + pub fn child_shutdown_token(&self) -> ShutdownToken { + self.tracker.root_cancellation_token.child_token() + } + + /// Obtain a [ShutdownToken](crate::cancellation::ShutdownToken) on the same hierarchical structure as the root token + pub fn clone_shutdown_token(&self) -> ShutdownToken { + self.tracker.root_cancellation_token.clone() + } + + /// Attempt to create a handle to a legacy [TaskClient] to support tasks that hasn't migrated + /// from the legacy [TaskManager]. + /// Note. 
To use this method [ShutdownManager] must be built with `.with_legacy_task_manager()` + #[must_use] + #[deprecated] + #[allow(deprecated)] + pub fn subscribe_legacy>(&self, child_suffix: S) -> crate::TaskClient { + // alternatively we could have set self.legacy_task_manager = Some(TaskManager::default()); + // on demand if it wasn't unavailable, but then we'd have to use mutable reference + #[allow(clippy::expect_used)] + self.legacy_task_manager + .as_ref() + .expect("did not enable legacy shutdown support") + .subscribe_named(child_suffix) + } + + /// Finalise the shutdown procedure by waiting until either: + /// - all tracked tasks have terminated + /// - timeout has been reached + /// - shutdown has been forced (by sending SIGINT) + async fn finish_shutdown(&mut self) { + let mut wait_futures = FuturesUnordered:: + Send>>>::new(); + + // force shutdown via ctrl-c + wait_futures.push(Box::pin(async move { + #[cfg(not(target_arch = "wasm32"))] + let interrupt_future = tokio::signal::ctrl_c(); + + #[cfg(target_arch = "wasm32")] + let interrupt_future = futures::future::pending::<()>(); + + let _ = interrupt_future.await; + info!("received interrupt - forcing shutdown"); + })); + + // timeout + let max_shutdown = self.max_shutdown_duration; + wait_futures.push(Box::pin(async move { + sleep(max_shutdown).await; + info!("timeout reached - forcing shutdown"); + })); + + // graceful + let tracker = self.tracker.clone(); + wait_futures.push(Box::pin(async move { + tracker.wait_for_tracker().await; + info!("all tracked tasks successfully shutdown"); + if let Some(legacy) = self.legacy_task_manager.as_mut() { + legacy.wait_for_graceful_shutdown().await; + info!("all legacy tasks successfully shutdown"); + } + + info!("all registered tasks successfully shutdown") + })); + + wait_futures.next().await; + } + + /// Remove the current set of [ShutdownSignals] from this instance of + /// [ShutdownManager] replacing it with an empty set. + /// + /// This is potentially useful if one wishes to start listening for the signals + /// before the whole process has been fully set up. + pub fn detach_shutdown_signals(&mut self) -> ShutdownSignals { + mem::take(&mut self.shutdown_signals) + } + + /// Replace the current set of [ShutdownSignals] used for determining + /// whether the underlying process should be stopped. + pub fn replace_shutdown_signals(&mut self, signals: ShutdownSignals) { + self.shutdown_signals = signals; + } + + /// Send cancellation signal to all registered tasks by cancelling the root token + /// and sending shutdown signal, if applicable, on the legacy [TaskManager] + pub fn send_cancellation(&self) { + if let Some(legacy_manager) = self.legacy_task_manager.as_ref() { + info!("attempting to shutdown legacy tasks"); + let _ = legacy_manager.signal_shutdown(); + } + self.tracker.root_cancellation_token.cancel(); + } + + /// Wait until receiving one of the registered shutdown signals + /// this method is cancellation safe + pub async fn wait_for_shutdown_signal(&mut self) { + #[cfg(not(target_arch = "wasm32"))] + self.shutdown_signals.0.join_next().await; + + #[cfg(target_arch = "wasm32")] + self.tracker.root_cancellation_token.cancelled().await; + } + + /// Perform system shutdown by sending relevant signals and waiting until either: + /// - all tracked tasks have terminated + /// - timeout has been reached + /// - shutdown has been forced (by sending SIGINT) + pub async fn perform_shutdown(&mut self) { + self.send_cancellation(); + + info!("waiting for tasks to finish... 
(press ctrl-c to force)"); + self.finish_shutdown().await; + } + + /// Wait until a shutdown signal has been received and trigger system shutdown. + pub async fn run_until_shutdown(&mut self) { + self.close_tracker(); + self.wait_for_shutdown_signal().await; + + self.perform_shutdown().await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use nym_test_utils::traits::{ElapsedExt, Timeboxed}; + use std::sync::atomic::AtomicBool; + use std::sync::Arc; + + #[tokio::test] + async fn shutdown_with_no_tracked_tasks_and_signals() -> anyhow::Result<()> { + let mut manager = ShutdownManager::new_without_signals(); + let res = manager.run_until_shutdown().timeboxed().await; + assert!(res.has_elapsed()); + + let mut manager = ShutdownManager::new_without_signals(); + let shutdown = manager.clone_shutdown_token(); + shutdown.cancel(); + let res = manager.run_until_shutdown().timeboxed().await; + assert!(!res.has_elapsed()); + + Ok(()) + } + + #[tokio::test] + async fn shutdown_signal() -> anyhow::Result<()> { + let timeout_shutdown = sleep(Duration::from_millis(100)); + let mut manager = ShutdownManager::new_without_signals().with_shutdown(timeout_shutdown); + + // execution finishes after the sleep gets finishes + let res = manager + .run_until_shutdown() + .execute_with_deadline(Duration::from_millis(200)) + .await; + assert!(!res.has_elapsed()); + + Ok(()) + } + + #[tokio::test] + async fn panic_hook() -> anyhow::Result<()> { + let mut manager = ShutdownManager::new_without_signals().with_cancel_on_panic(); + manager.spawn_with_shutdown(async move { + sleep(Duration::from_millis(10000)).await; + }); + manager.spawn_with_shutdown(async move { + sleep(Duration::from_millis(10)).await; + panic!("panicking"); + }); + + // execution finishes after the panic gets triggered + let res = manager + .run_until_shutdown() + .execute_with_deadline(Duration::from_millis(200)) + .await; + assert!(!res.has_elapsed()); + + Ok(()) + } + + #[tokio::test] + async fn task_cancellation() -> anyhow::Result<()> { + let timeout_shutdown = sleep(Duration::from_millis(100)); + let mut manager = ShutdownManager::new_without_signals().with_shutdown(timeout_shutdown); + + let cancelled1 = Arc::new(AtomicBool::new(false)); + let cancelled1_clone = cancelled1.clone(); + let cancelled2 = Arc::new(AtomicBool::new(false)); + let cancelled2_clone = cancelled2.clone(); + + let shutdown = manager.clone_shutdown_token(); + manager.spawn(async move { + shutdown.cancelled().await; + cancelled1_clone.store(true, std::sync::atomic::Ordering::Relaxed); + }); + + let shutdown = manager.clone_shutdown_token(); + manager.spawn(async move { + shutdown.cancelled().await; + cancelled2_clone.store(true, std::sync::atomic::Ordering::Relaxed); + }); + + let res = manager + .run_until_shutdown() + .execute_with_deadline(Duration::from_millis(200)) + .await; + + assert!(!res.has_elapsed()); + assert!(cancelled1.load(std::sync::atomic::Ordering::Relaxed)); + assert!(cancelled2.load(std::sync::atomic::Ordering::Relaxed)); + Ok(()) + } + + #[tokio::test] + async fn cancellation_within_task() -> anyhow::Result<()> { + let mut manager = ShutdownManager::new_without_signals(); + + let cancelled1 = Arc::new(AtomicBool::new(false)); + let cancelled1_clone = cancelled1.clone(); + + let shutdown = manager.clone_shutdown_token(); + manager.spawn(async move { + shutdown.cancelled().await; + cancelled1_clone.store(true, std::sync::atomic::Ordering::Relaxed); + }); + + let shutdown = manager.clone_shutdown_token(); + manager.spawn(async move { + 
sleep(Duration::from_millis(10)).await; + shutdown.cancel(); + }); + + let res = manager + .run_until_shutdown() + .execute_with_deadline(Duration::from_millis(200)) + .await; + + assert!(!res.has_elapsed()); + assert!(cancelled1.load(std::sync::atomic::Ordering::Relaxed)); + Ok(()) + } + + #[tokio::test] + async fn shutdown_timeout() -> anyhow::Result<()> { + let timeout_shutdown = sleep(Duration::from_millis(50)); + let mut manager = ShutdownManager::new_without_signals() + .with_shutdown(timeout_shutdown) + .with_shutdown_duration(Duration::from_millis(1000)); + + // ignore shutdown signals + manager.spawn(async move { + sleep(Duration::from_millis(1000)).await; + }); + + let res = manager + .run_until_shutdown() + .execute_with_deadline(Duration::from_millis(200)) + .await; + + assert!(res.has_elapsed()); + + let timeout_shutdown = sleep(Duration::from_millis(50)); + let mut manager = ShutdownManager::new_without_signals() + .with_shutdown(timeout_shutdown) + .with_shutdown_duration(Duration::from_millis(100)); + + // ignore shutdown signals + manager.spawn(async move { + sleep(Duration::from_millis(1000)).await; + }); + + let res = manager + .run_until_shutdown() + .execute_with_deadline(Duration::from_millis(200)) + .await; + + assert!(!res.has_elapsed()); + Ok(()) + } +} diff --git a/common/task/src/cancellation/mod.rs b/common/task/src/cancellation/mod.rs new file mode 100644 index 00000000000..fc9c5a648b6 --- /dev/null +++ b/common/task/src/cancellation/mod.rs @@ -0,0 +1,54 @@ +// Copyright 2025 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + +//! A [CancellationToken](tokio_util::sync::CancellationToken)-backed shutdown mechanism for Nym binaries. +//! +//! It allows creation of a centralised manager for keeping track of all signals that are meant +//! to trigger exit of all associated tasks and sending cancellation to the aforementioned futures. +//! +//! # Default usage +//! +//! ```no_run +//! use std::time::Duration; +//! use tokio::time::sleep; +//! use nym_task::{ShutdownManager, ShutdownToken}; +//! +//! async fn my_task() { +//! loop { +//! sleep(Duration::from_secs(5)).await +//! // do some periodic work that can be easily interrupted +//! } +//! } +//! +//! async fn important_work_that_cant_be_interrupted() {} +//! +//! async fn my_managed_task(shutdown_token: ShutdownToken) { +//! tokio::select! { +//! _ = shutdown_token.cancelled() => {} +//! _ = important_work_that_cant_be_interrupted() => {} +//! } +//! } +//! #[tokio::main] +//! async fn main() { +//! let mut shutdown_manager = ShutdownManager::build_new_default().expect("failed to register default shutdown signals"); +//! +//! let shutdown_token = shutdown_manager.child_shutdown_token(); +//! shutdown_manager.try_spawn_named(async move { my_managed_task(shutdown_token).await }, "important-managed-task"); +//! shutdown_manager.try_spawn_named_with_shutdown(my_task(), "another-task"); +//! +//! // wait for shutdown signal +//! shutdown_manager.run_until_shutdown().await; +//! } +//! 
``` + +use std::time::Duration; + +pub mod manager; +pub mod token; +pub mod tracker; + +pub use manager::ShutdownManager; +pub use token::{ShutdownDropGuard, ShutdownToken}; +pub use tracker::ShutdownTracker; + +pub const DEFAULT_MAX_SHUTDOWN_DURATION: Duration = Duration::from_secs(5); diff --git a/common/task/src/cancellation/token.rs b/common/task/src/cancellation/token.rs new file mode 100644 index 00000000000..ec553ed2167 --- /dev/null +++ b/common/task/src/cancellation/token.rs @@ -0,0 +1,150 @@ +// Copyright 2025 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + +use crate::event::SentStatus; +use std::future::Future; +use tokio_util::sync::{ + CancellationToken, DropGuard, WaitForCancellationFuture, WaitForCancellationFutureOwned, +}; +use tracing::warn; + +/// A wrapped [CancellationToken](tokio_util::sync::CancellationToken) that is used for +/// signalling and listening for cancellation requests. +// We don't use CancellationToken in case we wanted to include additional fields/methods +// down the line. +#[derive(Debug, Clone, Default)] +pub struct ShutdownToken { + inner: CancellationToken, +} + +impl ShutdownToken { + /// A drop in no-op replacement for `send_status_msg` for easier migration from [TaskClient](crate::TaskClient). + #[deprecated] + #[track_caller] + pub fn send_status_msg(&self, status: SentStatus) { + let caller = std::panic::Location::caller(); + warn!("{caller} attempted to send {status} - there are no more listeners of those"); + } + + /// Creates a new ShutdownToken in the non-cancelled state. + pub fn new() -> Self { + ShutdownToken { + inner: CancellationToken::new(), + } + } + + /// Gets reference to the underlying [CancellationToken](tokio_util::sync::CancellationToken). + pub fn inner(&self) -> &CancellationToken { + &self.inner + } + + /// Creates a `ShutdownToken` which will get cancelled whenever the + /// current token gets cancelled. Unlike a cloned `ShutdownToken`, + /// cancelling a child token does not cancel the parent token. + /// + /// If the current token is already cancelled, the child token will get + /// returned in cancelled state. + pub fn child_token(&self) -> ShutdownToken { + ShutdownToken { + inner: self.inner.child_token(), + } + } + + /// Cancel the underlying [CancellationToken](tokio_util::sync::CancellationToken) and all child tokens which had been + /// derived from it. + /// + /// This will wake up all tasks which are waiting for cancellation. + pub fn cancel(&self) { + self.inner.cancel(); + } + + /// Returns `true` if the underlying [CancellationToken](tokio_util::sync::CancellationToken) is cancelled. + pub fn is_cancelled(&self) -> bool { + self.inner.is_cancelled() + } + + /// Returns a `Future` that gets fulfilled when cancellation is requested. + /// + /// The future will complete immediately if the token is already cancelled + /// when this method is called. + /// + /// # Cancel safety + /// + /// This method is cancel safe. + pub fn cancelled(&self) -> WaitForCancellationFuture<'_> { + self.inner.cancelled() + } + + /// Returns a `Future` that gets fulfilled when cancellation is requested. + /// + /// The future will complete immediately if the token is already cancelled + /// when this method is called. + /// + /// The function takes self by value and returns a future that owns the + /// token. + /// + /// # Cancel safety + /// + /// This method is cancel safe. 
+ pub fn cancelled_owned(self) -> WaitForCancellationFutureOwned { + self.inner.cancelled_owned() + } + + /// Creates a `ShutdownDropGuard` for this token. + /// + /// Returned guard will cancel this token (and all its children) on drop + /// unless disarmed. + pub fn drop_guard(self) -> ShutdownDropGuard { + ShutdownDropGuard { + inner: self.inner.drop_guard(), + } + } + + /// Runs a future to completion and returns its result wrapped inside an `Option` + /// unless the `ShutdownToken` is cancelled. In that case the function returns + /// `None` and the future gets dropped. + /// + /// # Cancel safety + /// + /// This method is only cancel safe if `fut` is cancel safe. + pub async fn run_until_cancelled(&self, fut: F) -> Option + where + F: Future, + { + self.inner.run_until_cancelled(fut).await + } + + /// Runs a future to completion and returns its result wrapped inside an `Option` + /// unless the `ShutdownToken` is cancelled. In that case the function returns + /// `None` and the future gets dropped. + /// + /// The function takes self by value and returns a future that owns the token. + /// + /// # Cancel safety + /// + /// This method is only cancel safe if `fut` is cancel safe. + pub async fn run_until_cancelled_owned(self, fut: F) -> Option + where + F: Future, + { + self.inner.run_until_cancelled_owned(fut).await + } +} + +/// A wrapper for [DropGuard](tokio_util::sync::DropGuard) that wraps around a cancellation token +/// which automatically cancels it on drop. +/// It is created using `drop_guard` method on the `ShutdownToken`. +pub struct ShutdownDropGuard { + inner: DropGuard, +} + +impl ShutdownDropGuard { + /// Returns stored [ShutdownToken](ShutdownToken) and removes this drop guard instance + /// (i.e. it will no longer cancel token). Other guards for this token + /// are not affected. + pub fn disarm(self) -> ShutdownToken { + ShutdownToken { + inner: self.inner.disarm(), + } + } +} diff --git a/common/task/src/cancellation/tracker.rs b/common/task/src/cancellation/tracker.rs new file mode 100644 index 00000000000..c48b9c71ff3 --- /dev/null +++ b/common/task/src/cancellation/tracker.rs @@ -0,0 +1,317 @@ +// Copyright 2025 - Nym Technologies SA +// SPDX-License-Identifier: Apache-2.0 + +use crate::cancellation::token::ShutdownToken; +use crate::spawn::{spawn_named_future, JoinHandle}; +use crate::spawn_future; +use std::future::Future; +use thiserror::Error; +use tokio_util::task::TaskTracker; +use tracing::{debug, trace}; + +#[derive(Debug, Error)] +#[error("task got cancelled")] +pub struct Cancelled; + +/// Extracted [TaskTracker](tokio_util::task::TaskTracker) and [ShutdownToken](ShutdownToken) to more easily allow tracking nested tasks +/// without having to pass whole [ShutdownManager](ShutdownManager) around. +#[derive(Clone, Default, Debug)] +pub struct ShutdownTracker { + /// The root [ShutdownToken](ShutdownToken) that will trigger all derived tasks + /// to receive cancellation signal. + pub(crate) root_cancellation_token: ShutdownToken, + + // Note: the reason we're not using a `JoinSet` is + // because it forces us to use futures with the same `::Output` type, + // which is not really a desirable property in this instance. + /// Tracker used for keeping track of all registered tasks + /// so that they could be stopped gracefully before ending the process. 
+ pub(crate) tracker: TaskTracker, +} + +#[cfg(not(target_arch = "wasm32"))] +impl ShutdownTracker { + /// Spawn the provided future on the current Tokio runtime, and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn(&self, task: F) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let tracked = self.tracker.track_future(task); + spawn_future(tracked) + } + + /// Spawn the provided future on the current Tokio runtime, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + /// Furthermore, attach a name to the spawned task to more easily track it within a [tokio console](https://github.com/tokio-rs/console) + /// + /// Note that is no different from [spawn](Self::spawn) if the underlying binary + /// has not been built with `RUSTFLAGS="--cfg tokio_unstable"` and `--features="tokio-tracing"` + #[track_caller] + pub fn try_spawn_named(&self, task: F, name: &str) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + trace!("attempting to spawn task {name}"); + let tracked = self.tracker.track_future(task); + spawn_named_future(tracked, name) + } + + /// Spawn the provided future on the provided Tokio runtime, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn_on(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.tracker.spawn_on(task, handle) + } + + /// Spawn the provided future on the current [LocalSet](tokio::task::LocalSet), + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn_local(&self, task: F) -> JoinHandle + where + F: Future + 'static, + F::Output: 'static, + { + self.tracker.spawn_local(task) + } + + /// Spawn the provided blocking task on the current Tokio runtime, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn_blocking(&self, task: F) -> JoinHandle + where + F: FnOnce() -> T, + F: Send + 'static, + T: Send + 'static, + { + self.tracker.spawn_blocking(task) + } + + /// Spawn the provided blocking task on the provided Tokio runtime, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn_blocking_on(&self, task: F, handle: &tokio::runtime::Handle) -> JoinHandle + where + F: FnOnce() -> T, + F: Send + 'static, + T: Send + 'static, + { + self.tracker.spawn_blocking_on(task, handle) + } + + /// Spawn the provided future on the current Tokio runtime + /// that will get cancelled once a global shutdown signal is detected, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). 
+ /// + /// Note that to fully use the naming feature, such as tracking within a [tokio console](https://github.com/tokio-rs/console), + /// the underlying binary has to be built with `RUSTFLAGS="--cfg tokio_unstable"` and `--features="tokio-tracing"` + #[track_caller] + pub fn try_spawn_named_with_shutdown( + &self, + task: F, + name: &str, + ) -> JoinHandle> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + trace!("attempting to spawn task {name} (with top-level cancellation)"); + + let caller = std::panic::Location::caller(); + let shutdown_token = self.clone_shutdown_token(); + let name_owned = name.to_string(); + let tracked = self.tracker.track_future(async move { + match shutdown_token.run_until_cancelled_owned(task).await { + Some(result) => { + debug!("{name_owned} @ {caller}: task has finished execution"); + Ok(result) + } + None => { + trace!("{name_owned} @ {caller}: shutdown signal received, shutting down"); + Err(Cancelled) + } + } + }); + spawn_named_future(tracked, name) + } + + /// Spawn the provided future on the current Tokio runtime + /// that will get cancelled once a global shutdown signal is detected, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn_with_shutdown(&self, task: F) -> JoinHandle> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let caller = std::panic::Location::caller(); + let shutdown_token = self.clone_shutdown_token(); + self.tracker.spawn(async move { + match shutdown_token.run_until_cancelled_owned(task).await { + Some(result) => { + debug!("{caller}: task has finished execution"); + Ok(result) + } + None => { + trace!("{caller}: shutdown signal received, shutting down"); + Err(Cancelled) + } + } + }) + } +} + +#[cfg(target_arch = "wasm32")] +impl ShutdownTracker { + /// Run the provided future on the current thread, and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn(&self, task: F) -> JoinHandle + where + F: Future + 'static, + { + let tracked = self.tracker.track_future(task); + spawn_future(tracked) + } + + /// Run the provided future on the current thread, and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + /// It has exactly the same behaviour as [spawn](Self::spawn) and it only exists to provide + /// the same interface as non-wasm32 targets. + #[track_caller] + pub fn try_spawn_named(&self, task: F, name: &str) -> JoinHandle + where + F: Future + 'static, + { + let tracked = self.tracker.track_future(task); + spawn_named_future(tracked, name) + } + + /// Run the provided future on the current thread + /// that will get cancelled once a global shutdown signal is detected, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + /// It has exactly the same behaviour as [spawn_with_shutdown](Self::spawn_with_shutdown) and it only exists to provide + /// the same interface as non-wasm32 targets. 
+ #[track_caller] + pub fn try_spawn_named_with_shutdown( + &self, + task: F, + name: &str, + ) -> JoinHandle> + where + F: Future + 'static, + { + let caller = std::panic::Location::caller(); + let shutdown_token = self.clone_shutdown_token(); + let tracked = self.tracker.track_future(async move { + match shutdown_token.run_until_cancelled_owned(task).await { + Some(result) => { + debug!("{caller}: task has finished execution"); + Ok(result) + } + None => { + trace!("{caller}: shutdown signal received, shutting down"); + Err(Cancelled) + } + } + }); + spawn_named_future(tracked, name) + } + + /// Run the provided future on the current thread + /// that will get cancelled once a global shutdown signal is detected, + /// and track it in the underlying [TaskTracker](tokio_util::task::TaskTracker). + #[track_caller] + pub fn spawn_with_shutdown(&self, task: F) -> JoinHandle> + where + F: Future + 'static, + { + let caller = std::panic::Location::caller(); + let shutdown_token = self.clone_shutdown_token(); + let tracked = self.tracker.track_future(async move { + match shutdown_token.run_until_cancelled_owned(task).await { + Some(result) => { + debug!("{caller}: task has finished execution"); + Ok(result) + } + None => { + trace!("{caller}: shutdown signal received, shutting down"); + Err(Cancelled) + } + } + }); + spawn_future(tracked) + } +} + +impl ShutdownTracker { + /// Create new instance of the ShutdownTracker using an external shutdown token. + /// This could be useful in situations where shutdown is being managed by an external entity + /// that is not [ShutdownManager](ShutdownManager), but interface requires providing a ShutdownTracker, + /// such as client-core tasks + pub fn new_from_external_shutdown_token(shutdown_token: ShutdownToken) -> Self { + ShutdownTracker { + root_cancellation_token: shutdown_token, + tracker: Default::default(), + } + } + + /// Waits until the underlying [TaskTracker](tokio_util::task::TaskTracker) is both closed and empty. + /// + /// If the underlying [TaskTracker](tokio_util::task::TaskTracker) is already closed and empty when this method is called, then it + /// returns immediately. + pub async fn wait_for_tracker(&self) { + self.tracker.wait().await; + } + + /// Close the underlying [TaskTracker](tokio_util::task::TaskTracker). + /// + /// This allows [`wait_for_tracker`] futures to complete. It does not prevent you from spawning new tasks. + /// + /// Returns `true` if this closed the underlying [TaskTracker](tokio_util::task::TaskTracker), or `false` if it was already closed. + /// + /// [`wait_for_tracker`]: Self::wait_for_tracker + pub fn close_tracker(&self) -> bool { + self.tracker.close() + } + + /// Reopen the underlying [TaskTracker](tokio_util::task::TaskTracker). + /// + /// This prevents [`wait_for_tracker`] futures from completing even if the underlying [TaskTracker](tokio_util::task::TaskTracker) is empty. + /// + /// Returns `true` if this reopened the underlying [TaskTracker](tokio_util::task::TaskTracker), or `false` if it was already open. + /// + /// [`wait_for_tracker`]: Self::wait_for_tracker + pub fn reopen_tracker(&self) -> bool { + self.tracker.reopen() + } + + /// Returns `true` if the underlying [TaskTracker](tokio_util::task::TaskTracker) is [closed](Self::close_tracker). + pub fn is_tracker_closed(&self) -> bool { + self.tracker.is_closed() + } + + /// Returns the number of tasks tracked by the underlying [TaskTracker](tokio_util::task::TaskTracker). 
+ pub fn tracked_tasks(&self) -> usize { + self.tracker.len() + } + + /// Returns `true` if there are no tasks in the underlying [TaskTracker](tokio_util::task::TaskTracker). + pub fn is_tracker_empty(&self) -> bool { + self.tracker.is_empty() + } + + /// Obtain a [ShutdownToken](crate::cancellation::ShutdownToken) that is a child of the root token + pub fn child_shutdown_token(&self) -> ShutdownToken { + self.root_cancellation_token.child_token() + } + + /// Obtain a [ShutdownToken](crate::cancellation::ShutdownToken) on the same hierarchical structure as the root token + pub fn clone_shutdown_token(&self) -> ShutdownToken { + self.root_cancellation_token.clone() + } +} diff --git a/common/task/src/connections.rs b/common/task/src/connections.rs index 35f448c6221..b0e8fc6ecca 100644 --- a/common/task/src/connections.rs +++ b/common/task/src/connections.rs @@ -2,12 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 use futures::channel::mpsc; -use std::{ - collections::HashMap, - time::{Duration, Instant}, -}; +use std::collections::HashMap; -const LANE_CONSIDERED_CLEAR: usize = 10; +// const LANE_CONSIDERED_CLEAR: usize = 10; pub type ConnectionId = u64; @@ -83,21 +80,21 @@ impl LaneQueueLengths { } } - pub async fn wait_until_clear(&self, lane: &TransmissionLane, timeout: Option) { - let total_time_waited = Instant::now(); - loop { - let lane_length = self.get(lane).unwrap_or_default(); - if lane_length < LANE_CONSIDERED_CLEAR { - break; - } - if timeout.is_some_and(|timeout| total_time_waited.elapsed() > timeout) { - log::warn!("Timeout reached while waiting for queue to clear"); - break; - } - log::trace!("Waiting for queue to clear ({lane_length} items left)"); - tokio::time::sleep(Duration::from_millis(100)).await; - } - } + // pub async fn wait_until_clear(&self, lane: &TransmissionLane, timeout: Option) { + // let total_time_waited = Instant::now(); + // loop { + // let lane_length = self.get(lane).unwrap_or_default(); + // if lane_length < LANE_CONSIDERED_CLEAR { + // break; + // } + // if timeout.is_some_and(|timeout| total_time_waited.elapsed() > timeout) { + // log::warn!("Timeout reached while waiting for queue to clear"); + // break; + // } + // log::trace!("Waiting for queue to clear ({lane_length} items left)"); + // tokio::time::sleep(Duration::from_millis(100)).await; + // } + // } } impl Default for LaneQueueLengths { diff --git a/common/task/src/lib.rs b/common/task/src/lib.rs index edbf0afef47..7be94f6c466 100644 --- a/common/task/src/lib.rs +++ b/common/task/src/lib.rs @@ -9,10 +9,11 @@ pub mod manager; pub mod signal; pub mod spawn; -pub use cancellation::{ShutdownDropGuard, ShutdownManager, ShutdownToken}; +pub use cancellation::{ShutdownDropGuard, ShutdownManager, ShutdownToken, ShutdownTracker}; pub use event::{StatusReceiver, StatusSender, TaskStatus, TaskStatusEvent}; -pub use manager::{TaskClient, TaskHandle, TaskManager}; -pub use spawn::{spawn, spawn_with_report_error}; +#[allow(deprecated)] +pub use manager::{TaskClient, TaskManager}; +pub use spawn::spawn_future; pub use tokio_util::task::TaskTracker; #[cfg(not(target_arch = "wasm32"))] diff --git a/common/task/src/manager.rs b/common/task/src/manager.rs index d5af78fcc05..5a6d22bfb61 100644 --- a/common/task/src/manager.rs +++ b/common/task/src/manager.rs @@ -44,6 +44,7 @@ enum TaskError { /// Listens to status and error messages from tasks, as well as notifying them to gracefully /// shutdown. Keeps track of if task stop unexpectedly, such as in a panic. 
+#[deprecated(note = "use ShutdownManager instead")] #[derive(Debug)] pub struct TaskManager { // optional name assigned to the task manager that all subscribed task clients will inherit @@ -72,6 +73,7 @@ pub struct TaskManager { task_status_rx: Option, } +#[allow(deprecated)] impl Default for TaskManager { fn default() -> Self { let (notify_tx, notify_rx) = watch::channel(()); @@ -95,6 +97,8 @@ impl Default for TaskManager { } } +#[allow(deprecated)] +#[allow(clippy::expect_used)] impl TaskManager { pub fn new(shutdown_timer_secs: u64) -> Self { Self { @@ -168,7 +172,7 @@ impl TaskManager { if let Some(mut task_status_rx) = self.task_status_rx.take() { log::info!("Starting status message listener"); - crate::spawn::spawn(async move { + crate::spawn::spawn_future(async move { loop { if let Some(msg) = task_status_rx.next().await { log::trace!("Got msg: {msg}"); @@ -186,12 +190,14 @@ impl TaskManager { } // used for compatibility with the ShutdownManager + #[cfg(not(target_arch = "wasm32"))] pub(crate) fn task_return_error_rx(&mut self) -> ErrorReceiver { self.task_return_error_rx .take() .expect("unable to get error channel: attempt to wait twice?") } + #[cfg(not(target_arch = "wasm32"))] pub(crate) fn task_drop_rx(&mut self) -> ErrorReceiver { self.task_drop_rx .take() @@ -259,6 +265,7 @@ impl TaskManager { /// Listen for shutdown notifications, and can send error and status messages back to the /// `TaskManager` #[derive(Debug)] +#[deprecated(note = "use ShutdownToken instead")] pub struct TaskClient { // optional name assigned to the shutdown handle name: Option, @@ -286,6 +293,7 @@ pub struct TaskClient { mode: ClientOperatingMode, } +#[allow(deprecated)] impl Clone for TaskClient { fn clone(&self) -> Self { // make sure to not accidentally overflow the stack if we keep cloning the handle @@ -313,6 +321,7 @@ impl Clone for TaskClient { } } +#[allow(deprecated)] impl TaskClient { const MAX_NAME_LENGTH: usize = 128; const OVERFLOW_NAME: &'static str = "reached maximum TaskClient children name depth"; @@ -433,6 +442,8 @@ impl TaskClient { .await } + // legacy code + #[allow(clippy::panic)] pub async fn recv_timeout(&mut self) { if self.mode.is_dummy() { return pending().await; @@ -505,6 +516,7 @@ impl TaskClient { } } +#[allow(deprecated)] impl Drop for TaskClient { fn drop(&mut self) { if !self.mode.should_signal_on_drop() { @@ -572,6 +584,8 @@ impl ClientOperatingMode { } } +#[deprecated] +#[allow(deprecated)] #[derive(Debug)] pub enum TaskHandle { /// Full [`TaskManager`] that was created by the underlying task. 
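// --- Editor's note: illustrative migration sketch, not part of the patch ---
// Rough shape of how code written against the now-deprecated TaskManager/TaskClient
// pair maps onto ShutdownManager/ShutdownToken. The calls used here (`Default::default()`,
// `child_shutdown_token()`, `run_until_cancelled()`, `run_until_shutdown()`) mirror usages
// elsewhere in this diff; `background_job` is a hypothetical task body.

async fn background_job(shutdown_token: nym_task::ShutdownToken) {
    // previously: a `while !task_client.is_shutdown()` loop polling `task_client.recv()`
    let _ = shutdown_token
        .run_until_cancelled(async { /* do the actual work here */ })
        .await;
}

#[tokio::main]
async fn main() {
    let mut shutdown_manager = nym_task::ShutdownManager::default();
    tokio::spawn(background_job(shutdown_manager.child_shutdown_token()));
    // wait for SIGINT/SIGTERM (or a programmatic cancellation), then for tracked tasks to finish
    shutdown_manager.run_until_shutdown().await;
}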
@@ -581,24 +595,28 @@ pub enum TaskHandle { External(TaskClient), } +#[allow(deprecated)] impl From for TaskHandle { fn from(value: TaskManager) -> Self { TaskHandle::Internal(value) } } +#[allow(deprecated)] impl From for TaskHandle { fn from(value: TaskClient) -> Self { TaskHandle::External(value) } } +#[allow(deprecated)] impl Default for TaskHandle { fn default() -> Self { TaskHandle::Internal(TaskManager::default()) } } +#[allow(deprecated)] impl TaskHandle { #[must_use] pub fn name_if_unnamed>(self, name: S) -> Self { @@ -666,6 +684,7 @@ mod tests { use super::*; #[tokio::test] + #[allow(deprecated)] async fn signal_shutdown() { let shutdown = TaskManager::default(); let mut listener = shutdown.subscribe(); diff --git a/common/task/src/signal.rs b/common/task/src/signal.rs index ebaab7b20fe..a8d7f6865dd 100644 --- a/common/task/src/signal.rs +++ b/common/task/src/signal.rs @@ -1,6 +1,7 @@ -use crate::{manager::SentError, TaskManager}; +use crate::manager::SentError; #[cfg(unix)] +#[allow(clippy::expect_used)] pub async fn wait_for_signal() { use tokio::signal::unix::{signal, SignalKind}; let mut sigterm = signal(SignalKind::terminate()).expect("Failed to setup SIGTERM channel"); @@ -28,8 +29,10 @@ pub async fn wait_for_signal() { } } +#[allow(deprecated)] #[cfg(unix)] -pub async fn wait_for_signal_and_error(shutdown: &mut TaskManager) -> Result<(), SentError> { +#[allow(clippy::expect_used)] +pub async fn wait_for_signal_and_error(shutdown: &mut crate::TaskManager) -> Result<(), SentError> { use tokio::signal::unix::{signal, SignalKind}; let mut sigterm = signal(SignalKind::terminate()).expect("Failed to setup SIGTERM channel"); @@ -55,8 +58,9 @@ pub async fn wait_for_signal_and_error(shutdown: &mut TaskManager) -> Result<(), } } +#[allow(deprecated)] #[cfg(not(unix))] -pub async fn wait_for_signal_and_error(shutdown: &mut TaskManager) -> Result<(), SentError> { +pub async fn wait_for_signal_and_error(shutdown: &mut crate::TaskManager) -> Result<(), SentError> { tokio::select! 
{ _ = tokio::signal::ctrl_c() => { log::info!("Received SIGINT"); diff --git a/common/task/src/spawn.rs b/common/task/src/spawn.rs index e0fe98c5dea..8929bf2a2b7 100644 --- a/common/task/src/spawn.rs +++ b/common/task/src/spawn.rs @@ -1,35 +1,79 @@ -use crate::TaskClient; use std::future::Future; +#[cfg(not(target_arch = "wasm32"))] +pub type JoinHandle<T> = tokio::task::JoinHandle<T>; + +// no JoinHandle equivalent in wasm + +#[cfg(target_arch = "wasm32")] +#[derive(Clone, Copy)] +pub struct FakeJoinHandle<T> { + _p: std::marker::PhantomData<T>, +} #[cfg(target_arch = "wasm32")] -pub fn spawn(future: F) +pub type JoinHandle<T> = FakeJoinHandle<T>; + +#[cfg(target_arch = "wasm32")] +#[track_caller] +pub fn spawn_future<F>(future: F) -> JoinHandle<F::Output> where - F: Future + 'static, + F: Future + 'static, { - wasm_bindgen_futures::spawn_local(future); + wasm_bindgen_futures::spawn_local(async move { + // make sure the future outputs `()` + future.await; + }); + FakeJoinHandle { + _p: std::marker::PhantomData, + } } +// Note: prefer spawning tasks directly on the ShutdownManager #[cfg(not(target_arch = "wasm32"))] #[track_caller] -pub fn spawn(future: F) +pub fn spawn_future<F>(future: F) -> JoinHandle<F::Output> where F: Future + Send + 'static, F::Output: Send + 'static, { - tokio::spawn(future); + tokio::spawn(future) } +// Note: prefer spawning tasks directly on the ShutdownManager +#[cfg(not(target_arch = "wasm32"))] #[track_caller] -pub fn spawn_with_report_error(future: F, mut shutdown: TaskClient) +pub fn spawn_named_future<F>(future: F, name: &str) -> JoinHandle<F::Output> where - F: Future> + Send + 'static, - T: 'static, - E: std::error::Error + Send + Sync + 'static, + F: Future + Send + 'static, + F::Output: Send + 'static, { - let future_that_sends = async move { - if let Err(err) = future.await { - shutdown.send_we_stopped(Box::new(err)); - } - }; - spawn(future_that_sends); + cfg_if::cfg_if! {if #[cfg(all(tokio_unstable, feature="tokio-tracing"))] { + #[allow(clippy::expect_used)] + tokio::task::Builder::new().name(name).spawn(future).expect("failed to spawn future") + } else { + let _ = name; + tracing::debug!(r#"the underlying binary hasn't been built with `RUSTFLAGS="--cfg tokio_unstable"` - the future naming won't do anything"#); + spawn_future(future) + }} +} + +#[cfg(target_arch = "wasm32")] +#[track_caller] +pub fn spawn_named_future<F>(future: F, name: &str) -> JoinHandle<F::Output> +where + F: Future + 'static, +{ + // not supported in wasm + let _ = name; + spawn_future(future) +} + +#[macro_export] +macro_rules!
spawn_future { + ($future:expr) => {{ + $crate::spawn_future($future) + }}; + ($future:expr, $name:expr) => {{ + $crate::spawn_named_future($future, $name) + }}; } diff --git a/common/test-utils/src/traits.rs b/common/test-utils/src/traits.rs index 9b95dbd6bf5..e989be6089d 100644 --- a/common/test-utils/src/traits.rs +++ b/common/test-utils/src/traits.rs @@ -32,6 +32,16 @@ pub trait Timeboxed: IntoFuture + Sized { impl Timeboxed for T where T: IntoFuture + Sized {} +pub trait ElapsedExt { + fn has_elapsed(&self) -> bool; +} + +impl ElapsedExt for Result { + fn has_elapsed(&self) -> bool { + self.is_err() + } +} + // those are internal testing traits so we're not concerned about auto traits #[allow(async_fn_in_trait)] pub trait Spawnable: Future + Sized + Send + 'static { diff --git a/common/verloc/src/measurements/listener.rs b/common/verloc/src/measurements/listener.rs index cc74cf89844..8c817ce2293 100644 --- a/common/verloc/src/measurements/listener.rs +++ b/common/verloc/src/measurements/listener.rs @@ -51,24 +51,26 @@ impl PacketListener { info!("Started listening for echo packets on {}", self.address); - while !self.shutdown_token.is_cancelled() { + loop { // cloning the arc as each accepted socket is handled in separate task let connection_handler = Arc::clone(&self.connection_handler); tokio::select! { + biased; + _ = self.shutdown_token.cancelled() => { + trace!("PacketListener: Received shutdown"); + break; + } socket = listener.accept() => { match socket { Ok((socket, remote_addr)) => { debug!("New verloc connection from {remote_addr}"); - let cancel = self.shutdown_token.child_token(format!("handler_{remote_addr}")); - tokio::spawn(async move { cancel.run_until_cancelled(connection_handler.handle_connection(socket, remote_addr)).await }); + let cancel = self.shutdown_token.child_token(); + tokio::spawn(cancel.run_until_cancelled_owned(connection_handler.handle_connection(socket, remote_addr))); } Err(err) => warn!("Failed to accept incoming connection - {err}"), } }, - _ = self.shutdown_token.cancelled() => { - trace!("PacketListener: Received shutdown"); - } } } } diff --git a/common/verloc/src/measurements/measurer.rs b/common/verloc/src/measurements/measurer.rs index 5831551363d..2f5b1364858 100644 --- a/common/verloc/src/measurements/measurer.rs +++ b/common/verloc/src/measurements/measurer.rs @@ -40,12 +40,12 @@ impl VerlocMeasurer { config.packet_timeout, config.connection_timeout, config.delay_between_packets, - shutdown_token.clone_with_suffix("packet_sender"), + shutdown_token.clone(), )), packet_listener: Arc::new(PacketListener::new( config.listening_address, Arc::clone(&identity), - shutdown_token.clone_with_suffix("packet_listener"), + shutdown_token.clone(), )), shutdown_token, config, @@ -92,8 +92,13 @@ impl VerlocMeasurer { .collect::>(); // exhaust the results - while !self.shutdown_token.is_cancelled() { + loop { tokio::select! 
{ + biased; + _ = self.shutdown_token.cancelled() => { + trace!("Shutdown received while measuring"); + return MeasurementOutcome::Shutdown; + } measurement_result = measurement_chunk.next() => { let Some(result) = measurement_result else { // if the stream has finished, it means we got everything we could have gotten @@ -117,10 +122,6 @@ impl VerlocMeasurer { }; chunk_results.push(VerlocNodeResult::new(identity, measurement_result)); }, - _ = self.shutdown_token.cancelled() => { - trace!("Shutdown received while measuring"); - return MeasurementOutcome::Shutdown; - } } } @@ -208,6 +209,7 @@ impl VerlocMeasurer { _ = sleep(self.config.testing_interval) => {}, _ = self.shutdown_token.cancelled() => { trace!("Shutdown received while sleeping"); + break; } } } diff --git a/common/wireguard-private-metadata/server/src/http/mod.rs b/common/wireguard-private-metadata/server/src/http/mod.rs index e3d502d8b5d..4093dece8b9 100644 --- a/common/wireguard-private-metadata/server/src/http/mod.rs +++ b/common/wireguard-private-metadata/server/src/http/mod.rs @@ -1,12 +1,9 @@ // Copyright 2025 - Nym Technologies SA // SPDX-License-Identifier: Apache-2.0 +use nym_wireguard::WgApiWrapper; use std::sync::Arc; - use tokio::task::JoinHandle; -use tokio_util::sync::CancellationToken; - -use nym_wireguard::WgApiWrapper; pub(crate) mod openapi; pub(crate) mod router; @@ -20,7 +17,6 @@ pub(crate) mod state; /// AFTER you have shut down BG tasks (or past their grace period). #[allow(unused)] pub struct ShutdownHandles { - axum_shutdown_button: CancellationToken, /// Tokio JoinHandle for axum server's task axum_join_handle: AxumJoinHandle, /// Wireguard API for kernel interactions @@ -30,13 +26,8 @@ pub struct ShutdownHandles { impl ShutdownHandles { /// Cancellation token is given to Axum server constructor. When the token /// receives a shutdown signal, Axum server will shut down gracefully. 
- pub fn new( - axum_join_handle: AxumJoinHandle, - wg_api: Arc, - axum_shutdown_button: CancellationToken, - ) -> Self { + pub fn new(axum_join_handle: AxumJoinHandle, wg_api: Arc) -> Self { Self { - axum_shutdown_button, axum_join_handle, wg_api, } diff --git a/common/wireguard-private-metadata/server/src/http/router.rs b/common/wireguard-private-metadata/server/src/http/router.rs index c5936f1f4cf..ac4250e431a 100644 --- a/common/wireguard-private-metadata/server/src/http/router.rs +++ b/common/wireguard-private-metadata/server/src/http/router.rs @@ -7,8 +7,8 @@ use axum::routing::get; use axum::Router; use core::net::SocketAddr; use nym_http_api_common::middleware::logging::log_request_info; +use std::future::Future; use tokio::net::TcpListener; -use tokio_util::sync::WaitForCancellationFutureOwned; use tower_http::cors::CorsLayer; use utoipa::OpenApi; use utoipa_swagger_ui::SwaggerUi; @@ -88,14 +88,17 @@ pub struct ApiHttpServer { } impl ApiHttpServer { - pub async fn run(self, receiver: WaitForCancellationFutureOwned) -> Result<(), std::io::Error> { + pub async fn run(self, signal: F) -> Result<(), std::io::Error> + where + F: Future + Send + 'static, + { // into_make_service_with_connect_info allows us to see client ip address axum::serve( self.listener, self.router .into_make_service_with_connect_info::(), ) - .with_graceful_shutdown(receiver) + .with_graceful_shutdown(signal) .await } } diff --git a/common/wireguard/src/lib.rs b/common/wireguard/src/lib.rs index 9fe3371fdea..5bbce1b4b85 100644 --- a/common/wireguard/src/lib.rs +++ b/common/wireguard/src/lib.rs @@ -163,7 +163,7 @@ pub async fn start_wireguard( ecash_manager: Arc, metrics: nym_node_metrics::NymNodeMetrics, peers: Vec, - task_client: nym_task::TaskClient, + shutdown_token: nym_task::ShutdownToken, wireguard_data: WireguardData, ) -> Result, Box> { use base64::{prelude::BASE64_STANDARD, Engine}; @@ -250,7 +250,7 @@ pub async fn start_wireguard( peer_bandwidth_managers, wireguard_data.inner.peer_tx.clone(), wireguard_data.peer_rx, - task_client, + shutdown_token, ); tokio::spawn(async move { controller.run().await }); diff --git a/common/wireguard/src/peer_controller.rs b/common/wireguard/src/peer_controller.rs index 7da4a2336ed..4b294703fb2 100644 --- a/common/wireguard/src/peer_controller.rs +++ b/common/wireguard/src/peer_controller.rs @@ -84,7 +84,7 @@ pub struct PeerController { host_information: Arc>, bw_storage_managers: HashMap, timeout_check_interval: IntervalStream, - task_client: nym_task::TaskClient, + shutdown_token: nym_task::ShutdownToken, } impl PeerController { @@ -97,11 +97,10 @@ impl PeerController { bw_storage_managers: HashMap, request_tx: mpsc::Sender, request_rx: mpsc::Receiver, - task_client: nym_task::TaskClient, + shutdown_token: nym_task::ShutdownToken, ) -> Self { - let timeout_check_interval = tokio_stream::wrappers::IntervalStream::new( - tokio::time::interval(DEFAULT_PEER_TIMEOUT_CHECK), - ); + let timeout_check_interval = + IntervalStream::new(tokio::time::interval(DEFAULT_PEER_TIMEOUT_CHECK)); let host_information = Arc::new(RwLock::new(initial_host_information)); for (public_key, (bandwidth_storage_manager, peer)) in bw_storage_managers.iter() { let cached_peer_manager = CachedPeerManager::new(peer); @@ -111,7 +110,7 @@ impl PeerController { cached_peer_manager, bandwidth_storage_manager.clone(), request_tx.clone(), - &task_client, + &shutdown_token, ); let public_key = public_key.clone(); tokio::spawn(async move { @@ -132,7 +131,7 @@ impl PeerController { request_tx, request_rx, 
timeout_check_interval, - task_client, + shutdown_token, metrics, } } @@ -191,7 +190,7 @@ impl PeerController { cached_peer_manager, bandwidth_storage_manager.clone(), self.request_tx.clone(), - &self.task_client, + &self.shutdown_token, ); self.bw_storage_managers .insert(peer.public_key.clone(), bandwidth_storage_manager); @@ -383,7 +382,7 @@ impl PeerController { *self.host_information.write().await = host; } - _ = self.task_client.recv() => { + _ = self.shutdown_token.cancelled() => { log::trace!("PeerController handler: Received shutdown"); break; } @@ -513,7 +512,7 @@ pub fn start_controller( request_rx: mpsc::Receiver, ) -> ( Arc>, - nym_task::TaskManager, + nym_task::ShutdownManager, ) { use std::sync::Arc; @@ -524,7 +523,7 @@ pub fn start_controller( Box::new(storage.clone()), )); let wg_api = Arc::new(MockWgApi::default()); - let task_manager = nym_task::TaskManager::default(); + let shutdown_manager = nym_task::ShutdownManager::empty_mock(); let mut peer_controller = PeerController::new( ecash_manager, Default::default(), @@ -533,17 +532,17 @@ pub fn start_controller( Default::default(), request_tx, request_rx, - task_manager.subscribe(), + shutdown_manager.child_shutdown_token(), ); tokio::spawn(async move { peer_controller.run().await }); - (storage, task_manager) + (storage, shutdown_manager) } #[cfg(feature = "mock")] -pub async fn stop_controller(mut task_manager: nym_task::TaskManager) { - task_manager.signal_shutdown().unwrap(); - task_manager.wait_for_shutdown().await; +pub async fn stop_controller(mut shutdown_manager: nym_task::ShutdownManager) { + shutdown_manager.send_cancellation(); + shutdown_manager.run_until_shutdown().await; } #[cfg(test)] @@ -553,7 +552,7 @@ mod tests { #[tokio::test] async fn start_and_stop() { let (request_tx, request_rx) = mpsc::channel(1); - let (_, task_manager) = start_controller(request_tx.clone(), request_rx); - stop_controller(task_manager).await; + let (_, shutdown_manager) = start_controller(request_tx.clone(), request_rx); + stop_controller(shutdown_manager).await; } } diff --git a/common/wireguard/src/peer_handle.rs b/common/wireguard/src/peer_handle.rs index 9eda055b2ec..80ec587a192 100644 --- a/common/wireguard/src/peer_handle.rs +++ b/common/wireguard/src/peer_handle.rs @@ -7,7 +7,7 @@ use crate::peer_storage_manager::{CachedPeerManager, PeerInformation}; use defguard_wireguard_rs::{host::Host, key::Key, net::IpAddrMask}; use futures::channel::oneshot; use nym_credential_verification::bandwidth_storage_manager::BandwidthStorageManager; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use nym_wireguard_types::DEFAULT_PEER_TIMEOUT_CHECK; use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; @@ -43,7 +43,7 @@ pub struct PeerHandle { bandwidth_storage_manager: SharedBandwidthStorageManager, request_tx: mpsc::Sender, timeout_check_interval: IntervalStream, - task_client: TaskClient, + shutdown_token: ShutdownToken, } impl PeerHandle { @@ -53,13 +53,12 @@ impl PeerHandle { cached_peer: CachedPeerManager, bandwidth_storage_manager: SharedBandwidthStorageManager, request_tx: mpsc::Sender, - task_client: &TaskClient, + shutdown_token: &ShutdownToken, ) -> Self { let timeout_check_interval = tokio_stream::wrappers::IntervalStream::new( tokio::time::interval(DEFAULT_PEER_TIMEOUT_CHECK), ); - let mut task_client = task_client.fork(format!("peer_{public_key}")); - task_client.disarm(); + let shutdown_token = shutdown_token.clone(); PeerHandle { public_key, host_information, @@ -67,7 +66,7 @@ impl PeerHandle { 
bandwidth_storage_manager, request_tx, timeout_check_interval, - task_client, + shutdown_token, } } @@ -181,8 +180,18 @@ impl PeerHandle { } pub async fn run(&mut self) { - while !self.task_client.is_shutdown() { + loop { tokio::select! { + biased; + _ = self.shutdown_token.cancelled() => { + log::trace!("PeerHandle: Received shutdown"); + if let Err(e) = self.bandwidth_storage_manager.inner().write().await.sync_storage_bandwidth().await { + log::error!("Storage sync failed - {e}, unaccounted bandwidth might have been consumed"); + } + + log::trace!("PeerHandle: Finished shutdown"); + break; + } _ = self.timeout_check_interval.next() => { match self.continue_checking().await { Ok(true) => continue, @@ -201,15 +210,6 @@ impl PeerHandle { }, } } - - _ = self.task_client.recv() => { - log::trace!("PeerHandle: Received shutdown"); - if let Err(e) = self.bandwidth_storage_manager.inner().write().await.sync_storage_bandwidth().await { - log::error!("Storage sync failed - {e}, unaccounted bandwidth might have been consumed"); - } - - log::trace!("PeerHandle: Finished shutdown"); - } } } } diff --git a/contracts/Cargo.lock b/contracts/Cargo.lock index 67097b0840f..86fd21a403f 100644 --- a/contracts/Cargo.lock +++ b/contracts/Cargo.lock @@ -1151,6 +1151,7 @@ dependencies = [ name = "nym-crypto" version = "0.4.0" dependencies = [ + "base64 0.22.1", "bs58", "ed25519-dalek", "nym-pemstore", diff --git a/gateway/src/node/client_handling/embedded_clients/mod.rs b/gateway/src/node/client_handling/embedded_clients/mod.rs index 95af759ed91..9823a9494be 100644 --- a/gateway/src/node/client_handling/embedded_clients/mod.rs +++ b/gateway/src/node/client_handling/embedded_clients/mod.rs @@ -8,8 +8,7 @@ use futures::StreamExt; use nym_network_requester::{GatewayPacketRouter, PacketRouter}; use nym_sphinx::addressing::clients::Recipient; use nym_sphinx::DestinationAddressBytes; -use nym_task::TaskClient; -use tokio::task::JoinHandle; +use nym_task::ShutdownToken; use tracing::{debug, error, trace}; #[derive(Debug)] @@ -53,10 +52,6 @@ impl MessageRouter { } } - pub(crate) fn start_with_shutdown(self, shutdown: TaskClient) -> JoinHandle<()> { - tokio::spawn(self.run_with_shutdown(shutdown)) - } - fn handle_received_messages(&self, messages: Vec>) { if let Err(err) = self.packet_router.route_received(messages) { // TODO: what should we do here? I don't think this could/should ever fail. @@ -65,10 +60,15 @@ impl MessageRouter { } } - pub(crate) async fn run_with_shutdown(mut self, mut shutdown: TaskClient) { + pub(crate) async fn run_with_shutdown(mut self, shutdown: ShutdownToken) { debug!("Started embedded client message router with graceful shutdown support"); - while !shutdown.is_shutdown() { + loop { tokio::select! 
{ + biased; + _ = shutdown.cancelled() => { + trace!("embedded_clients::MessageRouter: Received shutdown"); + break; + } messages = self.mix_receiver.next() => match messages { Some(messages) => self.handle_received_messages(messages), None => { @@ -76,11 +76,6 @@ impl MessageRouter { break; } }, - _ = shutdown.recv_with_delay() => { - trace!("embedded_clients::MessageRouter: Received shutdown"); - debug_assert!(shutdown.is_shutdown()); - break - } } } diff --git a/gateway/src/node/client_handling/websocket/connection_handler/authenticated.rs b/gateway/src/node/client_handling/websocket/connection_handler/authenticated.rs index a152092ef74..36b80293320 100644 --- a/gateway/src/node/client_handling/websocket/connection_handler/authenticated.rs +++ b/gateway/src/node/client_handling/websocket/connection_handler/authenticated.rs @@ -29,7 +29,6 @@ use nym_gateway_storage::traits::SharedKeyGatewayStorage; use nym_node_metrics::events::MetricsEvent; use nym_sphinx::forwarding::packet::MixPacket; use nym_statistics_common::{gateways::GatewaySessionEvent, types::SessionType}; -use nym_task::TaskClient; use nym_validator_client::coconut::EcashApiError; use rand::{random, CryptoRng, Rng}; use std::{process, time::Duration}; @@ -583,7 +582,7 @@ impl AuthenticatedHandler { /// Simultaneously listens for incoming client requests, which realistically should only be /// binary requests to forward sphinx packets or increase bandwidth /// and for sphinx packets received from the mix network that should be sent back to the client. - pub(crate) async fn listen_for_requests(mut self, mut shutdown: TaskClient) + pub(crate) async fn listen_for_requests(mut self) where R: Rng + CryptoRng, S: AsyncRead + AsyncWrite + Unpin, @@ -593,11 +592,8 @@ impl AuthenticatedHandler { // Ping timeout future used to check if the client responded to our ping request let mut ping_timeout: OptionFuture<_> = None.into(); - while !shutdown.is_shutdown() { + loop { tokio::select! 
{ - _ = shutdown.recv() => { - trace!("client_handling::AuthenticatedHandler: received shutdown"); - }, // Received a request to ping the client to check if it's still active tx = self.is_active_request_receiver.next() => { match tx { diff --git a/gateway/src/node/client_handling/websocket/connection_handler/fresh.rs b/gateway/src/node/client_handling/websocket/connection_handler/fresh.rs index 2edbf16b4cc..ab308e7f8e0 100644 --- a/gateway/src/node/client_handling/websocket/connection_handler/fresh.rs +++ b/gateway/src/node/client_handling/websocket/connection_handler/fresh.rs @@ -32,7 +32,7 @@ use nym_gateway_storage::traits::InboxGatewayStorage; use nym_gateway_storage::traits::SharedKeyGatewayStorage; use nym_node_metrics::events::MetricsEvent; use nym_sphinx::DestinationAddressBytes; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use rand::CryptoRng; use std::net::SocketAddr; use std::time::Duration; @@ -127,7 +127,7 @@ pub(crate) struct FreshHandler { pub(crate) shared_state: CommonHandlerState, pub(crate) socket_connection: SocketStream, pub(crate) peer_address: SocketAddr, - pub(crate) shutdown: TaskClient, + pub(crate) shutdown: ShutdownToken, // currently unused (but populated) pub(crate) negotiated_protocol: Option, @@ -145,7 +145,7 @@ impl FreshHandler { conn: S, shared_state: CommonHandlerState, peer_address: SocketAddr, - shutdown: TaskClient, + shutdown: ShutdownToken, ) -> Self { FreshHandler { rng, @@ -917,60 +917,49 @@ impl FreshHandler { pub(crate) async fn handle_until_authenticated_or_failure( mut self, - shutdown: &mut TaskClient, ) -> Option> where S: AsyncRead + AsyncWrite + Unpin + Send, R: CryptoRng + RngCore + Send, { - while !shutdown.is_shutdown() { - let req = tokio::select! { - biased; - _ = shutdown.recv() => { - return None - }, - req = self.wait_for_initial_message() => req, - }; - - let initial_request = match req { - Ok(req) => req, - Err(err) => { - self.send_and_forget_error_response(err).await; - return None; - } - }; - - // see if we managed to register the client through this request - let maybe_auth_res = match self.handle_initial_client_request(initial_request).await { - Ok(maybe_auth_res) => maybe_auth_res, - Err(err) => { - debug!("initial client request handling error: {err}"); - self.send_and_forget_error_response(err).await; - return None; - } - }; - - if let Some(registration_details) = maybe_auth_res { - let (mix_sender, mix_receiver) = mpsc::unbounded(); - // Channel for handlers to ask other handlers if they are still active. 
- let (is_active_request_sender, is_active_request_receiver) = mpsc::unbounded(); - self.shared_state.active_clients_store.insert_remote( - registration_details.address, - mix_sender, - is_active_request_sender, - registration_details.session_request_timestamp, - ); + let initial_request = match self.wait_for_initial_message().await { + Ok(req) => req, + Err(err) => { + self.send_and_forget_error_response(err).await; + return None; + } + }; - return AuthenticatedHandler::upgrade( - self, - registration_details, - mix_receiver, - is_active_request_receiver, - ) - .await - .inspect_err(|err| error!("failed to upgrade client handler: {err}")) - .ok(); + // see if we managed to register the client through this request + let maybe_auth_res = match self.handle_initial_client_request(initial_request).await { + Ok(maybe_auth_res) => maybe_auth_res, + Err(err) => { + debug!("initial client request handling error: {err}"); + self.send_and_forget_error_response(err).await; + return None; } + }; + + if let Some(registration_details) = maybe_auth_res { + let (mix_sender, mix_receiver) = mpsc::unbounded(); + // Channel for handlers to ask other handlers if they are still active. + let (is_active_request_sender, is_active_request_receiver) = mpsc::unbounded(); + self.shared_state.active_clients_store.insert_remote( + registration_details.address, + mix_sender, + is_active_request_sender, + registration_details.session_request_timestamp, + ); + + return AuthenticatedHandler::upgrade( + self, + registration_details, + mix_receiver, + is_active_request_receiver, + ) + .await + .inspect_err(|err| error!("failed to upgrade client handler: {err}")) + .ok(); } None @@ -1031,6 +1020,15 @@ impl FreshHandler { S: AsyncRead + AsyncWrite + Unpin + Send, R: CryptoRng + RngCore + Send, { - super::handle_connection(self).await + let remote = self.peer_address; + let shutdown = self.shutdown.clone(); + tokio::select!
{ + _ = shutdown.cancelled() => { + trace!("received cancellation") + } + _ = super::handle_connection(self) => { + debug!("finished connection handler for {remote}") + } + } } } diff --git a/gateway/src/node/client_handling/websocket/connection_handler/mod.rs b/gateway/src/node/client_handling/websocket/connection_handler/mod.rs index 0ce52eec005..2fd97600a70 100644 --- a/gateway/src/node/client_handling/websocket/connection_handler/mod.rs +++ b/gateway/src/node/client_handling/websocket/connection_handler/mod.rs @@ -98,15 +98,6 @@ where R: Rng + CryptoRng + Send, S: AsyncRead + AsyncWrite + Unpin + Send, { - // don't accept any new requests if we have already received shutdown - if handle.shutdown.is_shutdown_poll() { - debug!("stopping the handle as we have received a shutdown"); - return; - } - - // If the connection handler abruptly stops, we shouldn't signal global shutdown - handle.shutdown.disarm(); - match tokio::time::timeout( WEBSOCKET_HANDSHAKE_TIMEOUT, handle.perform_websocket_handshake(), @@ -126,13 +117,8 @@ where trace!("managed to perform websocket handshake!"); - let mut shutdown = handle.shutdown.clone(); - - if let Some(auth_handle) = handle - .handle_until_authenticated_or_failure(&mut shutdown) - .await - { - auth_handle.listen_for_requests(shutdown).await + if let Some(auth_handle) = handle.handle_until_authenticated_or_failure().await { + auth_handle.listen_for_requests().await } trace!("the handler is done!"); diff --git a/gateway/src/node/client_handling/websocket/listener.rs b/gateway/src/node/client_handling/websocket/listener.rs index 94e11122efb..b2d234d4a7d 100644 --- a/gateway/src/node/client_handling/websocket/listener.rs +++ b/gateway/src/node/client_handling/websocket/listener.rs @@ -3,19 +3,18 @@ use crate::node::client_handling::websocket::common_state::CommonHandlerState; use crate::node::client_handling::websocket::connection_handler::FreshHandler; -use nym_task::TaskClient; +use nym_task::ShutdownTracker; use rand::rngs::OsRng; use std::net::SocketAddr; use std::{io, process}; use tokio::net::TcpStream; -use tokio::task::JoinHandle; use tracing::*; pub struct Listener { address: SocketAddr, maximum_open_connections: usize, shared_state: CommonHandlerState, - shutdown: TaskClient, + shutdown: ShutdownTracker, } impl Listener { @@ -23,7 +22,7 @@ impl Listener { address: SocketAddr, maximum_open_connections: usize, shared_state: CommonHandlerState, - shutdown: TaskClient, + shutdown: ShutdownTracker, ) -> Self { Listener { address, @@ -45,15 +44,12 @@ impl Listener { socket: TcpStream, remote_address: SocketAddr, ) -> FreshHandler { - let shutdown = self - .shutdown - .fork(format!("websocket_handler_{remote_address}")); FreshHandler::new( OsRng, socket, self.shared_state.clone(), remote_address, - shutdown, + self.shutdown.clone_shutdown_token(), ) } @@ -88,16 +84,19 @@ impl Listener { .new_ingress_websocket_client(); // 4. spawn the task handling the client connection - tokio::spawn(async move { - // TODO: refactor it similarly to the mixnet listener on the nym-node - let metrics_ref = handle.shared_state.metrics.clone(); - - // 4.1. handle all client requests until connection gets terminated - handle.start_handling().await; - - // 4.2. decrement the connection counter - metrics_ref.network.disconnected_ingress_websocket_client(); - }); + self.shutdown.try_spawn_named( + async move { + // TODO: refactor it similarly to the mixnet listener on the nym-node + let metrics_ref = handle.shared_state.metrics.clone(); + + // 4.1. 
handle all client requests until connection gets terminated + handle.start_handling().await; + + // 4.2. decrement the connection counter + metrics_ref.network.disconnected_ingress_websocket_client(); + }, + &format!("Websocket::{remote_address}"), + ); } Err(err) => warn!("failed to accept client connection: {err}"), } @@ -105,7 +104,7 @@ impl Listener { // TODO: change the signature to pub(crate) async fn run(&self, handler: Handler) - pub(crate) async fn run(&mut self) { + pub async fn run(&mut self) { info!("Starting websocket listener at {}", self.address); let tcp_listener = match tokio::net::TcpListener::bind(self.address).await { Ok(listener) => listener, @@ -115,21 +114,18 @@ impl Listener { } }; - while !self.shutdown.is_shutdown() { + let shutdown_token = self.shutdown.clone_shutdown_token(); + loop { tokio::select! { biased; - _ = self.shutdown.recv() => { + _ = shutdown_token.cancelled() => { trace!("client_handling::Listener: received shutdown"); + break } connection = tcp_listener.accept() => { self.try_handle_accepted_connection(connection) } - } } } - - pub fn start(mut self) -> JoinHandle<()> { - tokio::spawn(async move { self.run().await }) - } } diff --git a/gateway/src/node/internal_service_providers/authenticator/mixnet_client.rs b/gateway/src/node/internal_service_providers/authenticator/mixnet_client.rs index 65d9b58e7a4..136e1d27087 100644 --- a/gateway/src/node/internal_service_providers/authenticator/mixnet_client.rs +++ b/gateway/src/node/internal_service_providers/authenticator/mixnet_client.rs @@ -3,7 +3,7 @@ use nym_client_core::{config::disk_persistence::CommonClientPaths, TopologyProvider}; use nym_sdk::{GatewayTransceiver, NymNetworkDetails}; -use nym_task::TaskClient; +use nym_task::ShutdownTracker; use crate::node::internal_service_providers::authenticator::{ config::BaseClientConfig, error::AuthenticatorError, @@ -15,7 +15,7 @@ use crate::node::internal_service_providers::authenticator::{ // TODO: refactor this function and its arguments pub async fn create_mixnet_client( config: &BaseClientConfig, - shutdown: TaskClient, + shutdown: ShutdownTracker, custom_transceiver: Option>, custom_topology_provider: Option>, wait_for_gateway: bool, diff --git a/gateway/src/node/internal_service_providers/authenticator/mixnet_listener.rs b/gateway/src/node/internal_service_providers/authenticator/mixnet_listener.rs index c389055696d..9a6c9d3a50e 100644 --- a/gateway/src/node/internal_service_providers/authenticator/mixnet_listener.rs +++ b/gateway/src/node/internal_service_providers/authenticator/mixnet_listener.rs @@ -39,7 +39,7 @@ use nym_sdk::mixnet::{ }; use nym_service_provider_requests_common::{Protocol, ServiceProviderType}; use nym_sphinx::receiver::ReconstructedMessage; -use nym_task::TaskHandle; +use nym_task::ShutdownToken; use nym_wireguard::WireguardGatewayData; use nym_wireguard_types::PeerPublicKey; use rand::{prelude::IteratorRandom, thread_rng}; @@ -70,9 +70,6 @@ pub(crate) struct MixnetListener { // The mixnet client that we use to send and receive packets from the mixnet pub(crate) mixnet_client: nym_sdk::mixnet::MixnetClient, - // The task handle for the main loop - pub(crate) task_handle: TaskHandle, - // Registrations awaiting confirmation pub(crate) registred_and_free: RwLock, @@ -91,7 +88,6 @@ impl MixnetListener { free_private_network_ips: PrivateIPs, wireguard_gateway_data: WireguardGatewayData, mixnet_client: nym_sdk::mixnet::MixnetClient, - task_handle: TaskHandle, ecash_verifier: Arc, ) -> Self { let timeout_check_interval = @@ -99,7 
+95,6 @@ impl MixnetListener { MixnetListener { config, mixnet_client, - task_handle, registred_and_free: RwLock::new(RegistredAndFree::new(free_private_network_ips)), peer_manager: PeerManager::new(wireguard_gateway_data), ecash_verifier, @@ -812,14 +807,18 @@ impl MixnetListener { }) } - pub(crate) async fn run(mut self) -> Result<(), AuthenticatorError> { + pub(crate) async fn run( + mut self, + shutdown_token: ShutdownToken, + ) -> Result<(), AuthenticatorError> { tracing::info!("Using authenticator version {CURRENT_VERSION}"); - let mut task_client = self.task_handle.fork("main_loop"); - while !task_client.is_shutdown() { + loop { tokio::select! { - _ = task_client.recv() => { + biased; + _ = shutdown_token.cancelled() => { tracing::debug!("Authenticator [main loop]: received shutdown"); + break; }, _ = self.timeout_check_interval.next() => { if let Err(e) = self.remove_stale_registrations().await { diff --git a/gateway/src/node/internal_service_providers/authenticator/mod.rs b/gateway/src/node/internal_service_providers/authenticator/mod.rs index 12772b6a304..358bc5cb30c 100644 --- a/gateway/src/node/internal_service_providers/authenticator/mod.rs +++ b/gateway/src/node/internal_service_providers/authenticator/mod.rs @@ -7,7 +7,7 @@ use ipnetwork::IpNetwork; use nym_client_core::{HardcodedTopologyProvider, TopologyProvider}; use nym_credential_verification::ecash::EcashManager; use nym_sdk::{mixnet::Recipient, GatewayTransceiver}; -use nym_task::{TaskClient, TaskHandle}; +use nym_task::ShutdownTracker; use nym_wireguard::WireguardGatewayData; use std::{net::IpAddr, path::Path, sync::Arc, time::SystemTime}; @@ -40,7 +40,7 @@ pub struct Authenticator { wireguard_gateway_data: WireguardGatewayData, ecash_verifier: Arc, used_private_network_ips: Vec, - shutdown: Option, + shutdown: ShutdownTracker, on_start: Option>, } @@ -50,6 +50,7 @@ impl Authenticator { wireguard_gateway_data: WireguardGatewayData, used_private_network_ips: Vec, ecash_verifier: Arc, + shutdown: ShutdownTracker, ) -> Self { Self { config, @@ -59,18 +60,11 @@ impl Authenticator { ecash_verifier, wireguard_gateway_data, used_private_network_ips, - shutdown: None, + shutdown, on_start: None, } } - #[must_use] - #[allow(unused)] - pub fn with_shutdown(mut self, shutdown: TaskClient) -> Self { - self.shutdown = Some(shutdown); - self - } - #[must_use] #[allow(unused)] pub fn with_wait_for_gateway(mut self, wait_for_gateway: bool) -> Self { @@ -123,14 +117,10 @@ impl Authenticator { pub async fn run_service_provider(self) -> Result<(), AuthenticatorError> { // Used to notify tasks to shutdown. Not all tasks fully supports this (yet). 
- let task_handle: TaskHandle = self.shutdown.map(Into::into).unwrap_or_default(); - // Connect to the mixnet let mixnet_client = crate::node::internal_service_providers::authenticator::mixnet_client::create_mixnet_client( &self.config.base, - task_handle - .get_handle() - .named("nym_sdk::MixnetClient[AUTH]"), + self.shutdown.clone(), self.custom_gateway_transceiver, self.custom_topology_provider, self.wait_for_gateway, @@ -162,7 +152,6 @@ impl Authenticator { free_private_network_ips, self.wireguard_gateway_data, mixnet_client, - task_handle, self.ecash_verifier, ); @@ -176,6 +165,8 @@ impl Authenticator { } } - mixnet_listener.run().await + mixnet_listener + .run(self.shutdown.clone_shutdown_token()) + .await } } diff --git a/gateway/src/node/internal_service_providers/ip_packet_router/config/mod.rs b/gateway/src/node/internal_service_providers/ip_packet_router/config/mod.rs new file mode 100644 index 00000000000..6064f902661 --- /dev/null +++ b/gateway/src/node/internal_service_providers/ip_packet_router/config/mod.rs @@ -0,0 +1,52 @@ +use nym_bin_common::logging::LoggingSettings; +use nym_network_defaults::mainnet; +use url::Url; + +mod persistence; + +pub use crate::service_providers::ip_packet_router::config::persistence::IpPacketRouterPaths; +pub use nym_client_core::config::Config as BaseClientConfig; + +#[derive(Debug, Clone, PartialEq)] +pub struct Config { + pub base: BaseClientConfig, + + pub ip_packet_router: IpPacketRouter, + + pub storage_paths: IpPacketRouterPaths, +} + +impl Config { + pub fn validate(&self) -> bool { + // no other sections have explicit requirements (yet) + self.base.validate() + } + + #[doc(hidden)] + pub fn set_no_poisson_process(&mut self) { + self.base.set_no_poisson_process() + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct IpPacketRouter { + /// Disable Poisson sending rate. + pub disable_poisson_rate: bool, + + /// Specifies the url for an upstream source of the exit policy used by this node. 
+ pub upstream_exit_policy_url: Option, +} + +impl Default for IpPacketRouter { + fn default() -> Self { + IpPacketRouter { + disable_poisson_rate: true, + #[allow(clippy::expect_used)] + upstream_exit_policy_url: Some( + mainnet::EXIT_POLICY_URL + .parse() + .expect("invalid default exit policy URL"), + ), + } + } +} diff --git a/gateway/src/node/internal_service_providers/mod.rs b/gateway/src/node/internal_service_providers/mod.rs index cd4cd051b93..fa5a08ac192 100644 --- a/gateway/src/node/internal_service_providers/mod.rs +++ b/gateway/src/node/internal_service_providers/mod.rs @@ -17,9 +17,9 @@ use nym_network_requester::error::NetworkRequesterError; use nym_network_requester::NRServiceProviderBuilder; use nym_sdk::mixnet::Recipient; use nym_sdk::{GatewayTransceiver, LocalGateway, PacketRouter}; -use nym_task::TaskClient; +use nym_task::ShutdownTracker; use std::fmt::Display; -use tokio::task::JoinHandle; +use std::marker::PhantomData; use tracing::error; pub mod authenticator; @@ -91,12 +91,11 @@ impl RunnableServiceProvider for Authenticator { pub struct ServiceProviderBeingBuilt { on_start_rx: oneshot::Receiver, sp_builder: T, - sp_message_router_builder: SpMessageRouterBuilder, + sp_message_router_builder: SpMessageRouterBuilder, + shutdown_tracker: ShutdownTracker, } pub struct StartedServiceProvider { - pub sp_join_handle: JoinHandle<()>, - pub message_router_join_handle: JoinHandle<()>, pub on_start_data: T::OnStartData, pub handle: LocalEmbeddedClientHandle, } @@ -109,26 +108,31 @@ where pub(crate) fn new( on_start_rx: oneshot::Receiver, sp_builder: T, - sp_message_router_builder: SpMessageRouterBuilder, + sp_message_router_builder: SpMessageRouterBuilder, + shutdown_tracker: ShutdownTracker, ) -> Self { ServiceProviderBeingBuilt { on_start_rx, sp_builder, sp_message_router_builder, + shutdown_tracker, } } pub async fn start_service_provider( mut self, ) -> Result, GatewayError> { - let sp_join_handle = tokio::task::spawn(async move { - if let Err(err) = self.sp_builder.run_service_provider().await { - error!( - "the {} service provider encountered an error: {err}", - T::NAME - ) - } - }); + self.shutdown_tracker.try_spawn_named( + async move { + if let Err(err) = self.sp_builder.run_service_provider().await { + error!( + "the {} service provider encountered an error: {err}", + T::NAME + ) + } + }, + &format!("{}::Provider", T::NAME), + ); // TODO: if something is blocking during SP startup, the below will wait forever // we need to introduce additional timeouts here. 
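// --- Editor's note: illustrative usage sketch, not part of the patch ---
// How callers like the one below are expected to use the ShutdownTracker API added in
// common/task: spawn named/tracked tasks, then close the tracker and wait for it to
// drain. `some_worker` is a hypothetical future; the tracker methods (`try_spawn_named`,
// `spawn_with_shutdown`, `close_tracker`, `wait_for_tracker`) are the ones introduced
// earlier in this diff.

async fn some_worker() { /* hypothetical background job */ }

async fn run_with_tracker(tracker: nym_task::ShutdownTracker) {
    // named + tracked; the name is only visible in tokio_unstable / tokio-tracing builds
    tracker.try_spawn_named(some_worker(), "Gateway::SomeWorker");

    // tracked and cancelled once the root shutdown token is cancelled
    tracker.spawn_with_shutdown(some_worker());

    // once nothing new will be spawned, close the tracker and wait for tasks to finish
    tracker.close_tracker();
    tracker.wait_for_tracker().await;
}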
@@ -145,13 +149,10 @@ where }; let mix_sender = self.sp_message_router_builder.mix_sender(); - let message_router_join_handle = self - .sp_message_router_builder - .start_message_router(packet_router); + self.sp_message_router_builder + .start_message_router(packet_router, &self.shutdown_tracker); Ok(StartedServiceProvider { - sp_join_handle, - message_router_join_handle, handle: LocalEmbeddedClientHandle::new(on_start_data.address(), mix_sender), on_start_data, }) @@ -180,19 +181,19 @@ impl ExitServiceProviders { } } -pub struct SpMessageRouterBuilder { +pub struct SpMessageRouterBuilder { mix_sender: Option, mix_receiver: MixMessageReceiver, router_receiver: oneshot::Receiver, gateway_transceiver: Option, - shutdown: TaskClient, + + _typ: PhantomData, } -impl SpMessageRouterBuilder { +impl SpMessageRouterBuilder { pub(crate) fn new( node_identity: ed25519::PublicKey, forwarding_channel: MixForwardingSender, - shutdown: TaskClient, ) -> Self { let (mix_sender, mix_receiver) = mpsc::unbounded(); let (router_tx, router_rx) = oneshot::channel(); @@ -204,7 +205,7 @@ impl SpMessageRouterBuilder { mix_receiver, router_receiver: router_rx, gateway_transceiver: Some(transceiver), - shutdown, + _typ: Default::default(), } } @@ -224,7 +225,17 @@ impl SpMessageRouterBuilder { .expect("attempting to use the same mix sender twice") } - fn start_message_router(self, packet_router: PacketRouter) -> JoinHandle<()> { - MessageRouter::new(self.mix_receiver, packet_router).start_with_shutdown(self.shutdown) + fn start_message_router(self, packet_router: PacketRouter, shutdown_tracker: &ShutdownTracker) + where + T: RunnableServiceProvider, + { + let shutdown_token = shutdown_tracker.clone_shutdown_token(); + let message_router = MessageRouter::new(self.mix_receiver, packet_router); + shutdown_tracker.try_spawn_named( + async move { + message_router.run_with_shutdown(shutdown_token).await; + }, + &format!("{}::MessageRouter", T::NAME), + ); } } diff --git a/gateway/src/node/mod.rs b/gateway/src/node/mod.rs index a89a68cb239..e4542b0dd60 100644 --- a/gateway/src/node/mod.rs +++ b/gateway/src/node/mod.rs @@ -1,7 +1,6 @@ // Copyright 2020-2024 - Nym Technologies SA // SPDX-License-Identifier: GPL-3.0-only -use crate::config::Config; use crate::error::GatewayError; use crate::node::client_handling::websocket; use crate::node::internal_service_providers::{ @@ -19,7 +18,7 @@ use nym_network_defaults::NymNetworkDetails; use nym_network_requester::NRServiceProviderBuilder; use nym_node_metrics::events::MetricEventsSender; use nym_node_metrics::NymNodeMetrics; -use nym_task::{ShutdownToken, TaskClient}; +use nym_task::ShutdownTracker; use nym_topology::TopologyProvider; use nym_validator_client::nyxd::{Coin, CosmWasmClient}; use nym_validator_client::{nyxd, DirectSigningHttpRpcNyxdClient}; @@ -35,6 +34,7 @@ pub(crate) mod client_handling; pub(crate) mod internal_service_providers; mod stale_data_cleaner; +use crate::config::Config; use crate::node::internal_service_providers::authenticator::Authenticator; pub use client_handling::active_clients::ActiveClientsStore; pub use nym_gateway_stats_storage::PersistentStatsStorage; @@ -91,9 +91,7 @@ pub struct GatewayTasksBuilder { mnemonic: Arc>, - legacy_task_client: TaskClient, - - shutdown_token: ShutdownToken, + shutdown_tracker: ShutdownTracker, // populated and cached as necessary ecash_manager: Option>, @@ -103,14 +101,6 @@ pub struct GatewayTasksBuilder { wireguard_networks: Option>, } -impl Drop for GatewayTasksBuilder { - fn drop(&mut self) { - // disarm the 
shutdown as it was already used to construct relevant tasks and we don't want the builder - // to cause shutdown - self.legacy_task_client.disarm(); - } -} - impl GatewayTasksBuilder { #[allow(clippy::too_many_arguments)] pub fn new( @@ -121,8 +111,7 @@ impl GatewayTasksBuilder { metrics_sender: MetricEventsSender, metrics: NymNodeMetrics, mnemonic: Arc>, - legacy_task_client: TaskClient, - shutdown_token: ShutdownToken, + shutdown_tracker: ShutdownTracker, ) -> GatewayTasksBuilder { GatewayTasksBuilder { config, @@ -136,8 +125,7 @@ impl GatewayTasksBuilder { metrics_sender, metrics, mnemonic, - legacy_task_client, - shutdown_token, + shutdown_tracker, ecash_manager: None, wireguard_peers: None, wireguard_networks: None, @@ -227,17 +215,22 @@ impl GatewayTasksBuilder { }; let nyxd_client = self.build_nyxd_signing_client().await?; - let ecash_manager = Arc::new( - EcashManager::new( - handler_config, - nyxd_client, - self.identity_keypair.public_key().to_bytes(), - self.legacy_task_client.fork("ecash_manager"), - self.storage.clone(), - ) - .await?, + + let (ecash_manager, credential_handler) = EcashManager::new( + handler_config, + nyxd_client, + self.identity_keypair.public_key().to_bytes(), + self.storage.clone(), + ) + .await?; + + let shutdown_token = self.shutdown_tracker.clone_shutdown_token(); + self.shutdown_tracker.try_spawn_named( + async move { credential_handler.run(shutdown_token).await }, + "EcashCredentialHandler", ); - Ok(ecash_manager) + + Ok(Arc::new(ecash_manager)) } async fn ecash_manager(&mut self) -> Result, GatewayError> { @@ -274,7 +267,7 @@ impl GatewayTasksBuilder { self.config.gateway.websocket_bind_address, self.config.debug.maximum_open_connections, shared_state, - self.legacy_task_client.fork("websocket"), + self.shutdown_tracker.clone(), )) } @@ -290,19 +283,17 @@ impl GatewayTasksBuilder { let mut message_router_builder = SpMessageRouterBuilder::new( *self.identity_keypair.public_key(), self.mix_packet_sender.clone(), - self.legacy_task_client - .fork("network_requester_message_router"), ); let transceiver = message_router_builder.gateway_transceiver(); let (on_start_tx, on_start_rx) = oneshot::channel(); - let mut nr_builder = NRServiceProviderBuilder::new(nr_opts.config.clone()) - .with_shutdown(self.legacy_task_client.fork("network_requester_sp")) - .with_custom_gateway_transceiver(transceiver) - .with_wait_for_gateway(true) - .with_minimum_gateway_performance(0) - .with_custom_topology_provider(topology_provider) - .with_on_start(on_start_tx); + let mut nr_builder = + NRServiceProviderBuilder::new(nr_opts.config.clone(), self.shutdown_tracker.clone()) + .with_custom_gateway_transceiver(transceiver) + .with_wait_for_gateway(true) + .with_minimum_gateway_performance(0) + .with_custom_topology_provider(topology_provider) + .with_on_start(on_start_tx); if let Some(custom_mixnet) = &nr_opts.custom_mixnet_path { nr_builder = nr_builder.with_stored_topology(custom_mixnet)? 
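// --- Editor's note: illustrative pattern sketch, not part of the patch ---
// The loop shape this refactor converges on across components: a plain `loop` with a
// `biased` select whose shutdown arm comes first and explicitly `break`s, replacing the
// old `while !task_client.is_shutdown()` polling. `Work` and `next_item` are hypothetical
// stand-ins for whatever a given component actually processes.

struct Work;
async fn next_item() -> Option<Work> { None }

async fn run(shutdown_token: nym_task::ShutdownToken) {
    loop {
        tokio::select! {
            biased;
            _ = shutdown_token.cancelled() => {
                tracing::trace!("received shutdown");
                break;
            }
            item = next_item() => match item {
                Some(_work) => { /* handle the item */ }
                None => break,
            },
        }
    }
}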
@@ -312,6 +303,7 @@ impl GatewayTasksBuilder { on_start_rx, nr_builder, message_router_builder, + self.shutdown_tracker.clone(), )) } @@ -326,18 +318,17 @@ impl GatewayTasksBuilder { let mut message_router_builder = SpMessageRouterBuilder::new( *self.identity_keypair.public_key(), self.mix_packet_sender.clone(), - self.legacy_task_client.fork("ipr_message_router"), ); let transceiver = message_router_builder.gateway_transceiver(); let (on_start_tx, on_start_rx) = oneshot::channel(); - let mut ip_packet_router = IpPacketRouter::new(ip_opts.config.clone()) - .with_shutdown(self.legacy_task_client.fork("ipr_sp")) - .with_custom_gateway_transceiver(Box::new(transceiver)) - .with_wait_for_gateway(true) - .with_minimum_gateway_performance(0) - .with_custom_topology_provider(topology_provider) - .with_on_start(on_start_tx); + let mut ip_packet_router = + IpPacketRouter::new(ip_opts.config.clone(), self.shutdown_tracker.clone()) + .with_custom_gateway_transceiver(Box::new(transceiver)) + .with_wait_for_gateway(true) + .with_minimum_gateway_performance(0) + .with_custom_topology_provider(topology_provider) + .with_on_start(on_start_tx); if let Some(custom_mixnet) = &ip_opts.custom_mixnet_path { ip_packet_router = ip_packet_router.with_stored_topology(custom_mixnet)? @@ -347,6 +338,7 @@ impl GatewayTasksBuilder { on_start_rx, ip_packet_router, message_router_builder, + self.shutdown_tracker.clone(), )) } @@ -432,7 +424,6 @@ impl GatewayTasksBuilder { let mut message_router_builder = SpMessageRouterBuilder::new( *self.identity_keypair.public_key(), self.mix_packet_sender.clone(), - self.legacy_task_client.fork("authenticator_message_router"), ); let transceiver = message_router_builder.gateway_transceiver(); @@ -443,9 +434,9 @@ impl GatewayTasksBuilder { wireguard_data.inner.clone(), used_private_network_ips, ecash_manager, + self.shutdown_tracker.clone(), ) .with_custom_gateway_transceiver(transceiver) - .with_shutdown(self.legacy_task_client.fork("authenticator_sp")) .with_wait_for_gateway(true) .with_minimum_gateway_performance(0) .with_custom_topology_provider(topology_provider) @@ -459,13 +450,13 @@ impl GatewayTasksBuilder { on_start_rx, authenticator_server, message_router_builder, + self.shutdown_tracker.clone(), )) } pub fn build_stale_messages_cleaner(&self) -> StaleMessagesCleaner { StaleMessagesCleaner::new( &self.storage, - self.legacy_task_client.fork("stale_messages_cleaner"), self.config.debug.stale_messages_max_age, self.config.debug.stale_messages_cleaner_run_interval, ) @@ -476,7 +467,7 @@ impl GatewayTasksBuilder { &mut self, ) -> Result, Box> { let _ = self.metrics.clone(); - let _ = self.shutdown_token.clone(); + let _ = self.shutdown_tracker.clone(); unimplemented!("wireguard is not supported on this platform") } @@ -517,26 +508,22 @@ impl GatewayTasksBuilder { ecash_manager, self.metrics.clone(), all_peers, - self.legacy_task_client.fork("wireguard"), + self.shutdown_tracker.clone_shutdown_token(), wireguard_data, ) .await?; let server = router.build_server(&bind_address).await?; - let cancel_token: tokio_util::sync::CancellationToken = (*self.shutdown_token).clone(); - let axum_shutdown_receiver = cancel_token.clone().cancelled_owned(); + let cancel_token = self.shutdown_tracker.clone_shutdown_token(); let server_handle = tokio::spawn(async move { { info!("Started Wireguard Axum HTTP V2 server on {bind_address}"); - server.run(axum_shutdown_receiver).await + server.run(cancel_token.cancelled_owned()).await } }); - let shutdown_handles = 
nym_wireguard_private_metadata_server::ShutdownHandles::new( - server_handle, - wg_handle, - cancel_token, - ); + let shutdown_handles = + nym_wireguard_private_metadata_server::ShutdownHandles::new(server_handle, wg_handle); Ok(shutdown_handles) } diff --git a/gateway/src/node/stale_data_cleaner.rs b/gateway/src/node/stale_data_cleaner.rs index 271acfe1b55..172b865da65 100644 --- a/gateway/src/node/stale_data_cleaner.rs +++ b/gateway/src/node/stale_data_cleaner.rs @@ -2,16 +2,14 @@ // SPDX-License-Identifier: GPL-3.0-only use nym_gateway_storage::{GatewayStorage, InboxManager}; -use nym_task::TaskClient; +use nym_task::ShutdownToken; use std::error::Error; use std::time::Duration; use time::OffsetDateTime; -use tokio::task::JoinHandle; use tracing::{debug, trace, warn}; pub struct StaleMessagesCleaner { inbox_manager: InboxManager, - task_client: TaskClient, max_message_age: Duration, run_interval: Duration, } @@ -19,13 +17,11 @@ pub struct StaleMessagesCleaner { impl StaleMessagesCleaner { pub(crate) fn new( storage: &GatewayStorage, - task_client: TaskClient, max_message_age: Duration, run_interval: Duration, ) -> Self { StaleMessagesCleaner { inbox_manager: storage.inbox_manager().clone(), - task_client, max_message_age, run_interval, } @@ -36,13 +32,14 @@ impl StaleMessagesCleaner { self.inbox_manager.remove_stale(cutoff).await } - async fn run(&mut self) { + pub async fn run(&mut self, shutdown_token: ShutdownToken) { let mut interval = tokio::time::interval(self.run_interval); - while !self.task_client.is_shutdown() { + loop { tokio::select! { biased; - _ = self.task_client.recv() => { + _ = shutdown_token.cancelled() => { trace!("StaleMessagesCleaner: received shutdown"); + break; } _ = interval.tick() => { if let Err(err) = self.clean_up_stale_messages().await { @@ -53,8 +50,4 @@ impl StaleMessagesCleaner { } debug!("StaleMessagesCleaner: Exiting"); } - - pub fn start(mut self) -> JoinHandle<()> { - tokio::spawn(async move { self.run().await }) - } } diff --git a/nym-api/src/ecash/dkg/controller/mod.rs b/nym-api/src/ecash/dkg/controller/mod.rs index d5a8f804e3f..d5ee8cd14c6 100644 --- a/nym-api/src/ecash/dkg/controller/mod.rs +++ b/nym-api/src/ecash/dkg/controller/mod.rs @@ -282,8 +282,13 @@ impl DkgController { let mut last_polled = OffsetDateTime::now_utc(); let mut last_tick_duration = Default::default(); - while !shutdown.is_cancelled() { + loop { tokio::select! 
{ + biased; + _ = shutdown.cancelled() => { + trace!("DkgController: Received shutdown"); + break; + } _ = interval.tick() => { let now = OffsetDateTime::now_utc(); let tick_duration = now - last_polled; @@ -300,9 +305,6 @@ impl DkgController { error!("failed to update the DKG state: {err}") } } - _ = shutdown.cancelled() => { - trace!("DkgController: Received shutdown"); - } } } } @@ -319,7 +321,7 @@ impl DkgController { where R: Sync + Send + 'static, { - let shutdown_listener = shutdown_manager.clone_token("DKG controller"); + let shutdown_listener = shutdown_manager.clone_shutdown_token(); let dkg_controller = DkgController::new( config, nyxd_client, diff --git a/nym-api/src/ecash/state/mod.rs b/nym-api/src/ecash/state/mod.rs index 78ca592acf1..dd817e49527 100644 --- a/nym-api/src/ecash/state/mod.rs +++ b/nym-api/src/ecash/state/mod.rs @@ -138,7 +138,7 @@ impl EcashState { EcashBackgroundStateCleaner::new( global_config, storage.clone(), - shutdown_manager.clone_token("ecash-state-data-cleaner"), + shutdown_manager.clone_shutdown_token(), ), ), global: GlobalEcachState::new(contract_address), diff --git a/nym-api/src/epoch_operations/mod.rs b/nym-api/src/epoch_operations/mod.rs index 2322d4be249..e8deb036a68 100644 --- a/nym-api/src/epoch_operations/mod.rs +++ b/nym-api/src/epoch_operations/mod.rs @@ -266,7 +266,7 @@ impl EpochAdvancer { described_cache, storage.to_owned(), ); - let shutdown_listener = shutdown_manager.clone_token("epoch-advancer"); + let shutdown_listener = shutdown_manager.clone_shutdown_token(); tokio::spawn(async move { epoch_advancer.run(shutdown_listener).await }); } } diff --git a/nym-api/src/key_rotation/mod.rs b/nym-api/src/key_rotation/mod.rs index 44f11be7ef5..7d7c5220ccd 100644 --- a/nym-api/src/key_rotation/mod.rs +++ b/nym-api/src/key_rotation/mod.rs @@ -114,11 +114,12 @@ impl KeyRotationController { self.contract_cache.naive_wait_for_initial_values().await; self.handle_contract_cache_update().await; - while !shutdown_token.is_cancelled() { + loop { tokio::select! { biased; _ = shutdown_token.cancelled() => { trace!("KeyRotationController: Received shutdown"); + break; } _ = self.contract_cache_watcher.changed() => { self.handle_contract_cache_update().await diff --git a/nym-api/src/network_monitor/mod.rs b/nym-api/src/network_monitor/mod.rs index eaa6f8667e5..b5e8309a752 100644 --- a/nym-api/src/network_monitor/mod.rs +++ b/nym-api/src/network_monitor/mod.rs @@ -25,7 +25,7 @@ use nym_crypto::asymmetric::{ed25519, x25519}; use nym_sphinx::acknowledgements::AckKey; use nym_sphinx::params::PacketType; use nym_sphinx::receiver::MessageReceiver; -use nym_task::ShutdownManager; +use nym_task::{ShutdownManager, ShutdownToken}; use std::sync::Arc; use tracing::info; @@ -84,6 +84,7 @@ impl<'a> NetworkMonitorBuilder<'a> { pub(crate) async fn build( self, + shutdown_token: ShutdownToken, ) -> NetworkMonitorRunnables { // TODO: those keys change constant throughout the whole execution of the monitor. 
// and on top of that, they are used with ALL the gateways -> presumably this should change @@ -127,6 +128,7 @@ impl<'a> NetworkMonitorBuilder<'a> { gateway_status_update_sender, Arc::clone(&identity_keypair), bandwidth_controller, + shutdown_token, ); let received_processor = new_received_processor( @@ -170,9 +172,9 @@ impl NetworkMonitorRunnables { pub(crate) fn spawn_tasks(self, shutdown: &ShutdownManager) { let mut packet_receiver = self.packet_receiver; let mut monitor = self.monitor; - let shutdown_listener = shutdown.clone_token("NM-packet-receiver"); + let shutdown_listener = shutdown.clone_shutdown_token(); tokio::spawn(async move { packet_receiver.run(shutdown_listener).await }); - let shutdown_listener = shutdown.clone_token("NM-main"); + let shutdown_listener = shutdown.clone_shutdown_token(); tokio::spawn(async move { monitor.run(shutdown_listener).await }); } } @@ -202,12 +204,14 @@ fn new_packet_sender( gateways_status_updater: GatewayClientUpdateSender, local_identity: Arc, bandwidth_controller: BandwidthController, + shutdown_token: ShutdownToken, ) -> PacketSender { PacketSender::new( config, gateways_status_updater, local_identity, bandwidth_controller, + shutdown_token, ) } @@ -252,6 +256,7 @@ pub(crate) async fn start( nyxd_client, ); info!("Starting network monitor..."); - let runnables: NetworkMonitorRunnables = monitor_builder.build().await; + let runnables: NetworkMonitorRunnables = + monitor_builder.build(shutdown.clone_shutdown_token()).await; runnables.spawn_tasks(shutdown); } diff --git a/nym-api/src/network_monitor/monitor/mod.rs b/nym-api/src/network_monitor/monitor/mod.rs index d51793acb96..f51639897bb 100644 --- a/nym-api/src/network_monitor/monitor/mod.rs +++ b/nym-api/src/network_monitor/monitor/mod.rs @@ -334,20 +334,24 @@ impl Monitor { .await; let mut run_interval = tokio::time::interval(self.run_interval); - while !shutdown_token.is_cancelled() { + loop { tokio::select! { + biased; + _ = shutdown_token.cancelled() => { + trace!("UpdateHandler: Received shutdown"); + break; + } _ = run_interval.tick() => { tokio::select! { biased; _ = shutdown_token.cancelled() => { trace!("UpdateHandler: Received shutdown"); + break; } _ = self.test_run() => (), } } - _ = shutdown_token.cancelled() => { - trace!("UpdateHandler: Received shutdown"); - } + } } } diff --git a/nym-api/src/network_monitor/monitor/receiver.rs b/nym-api/src/network_monitor/monitor/receiver.rs index e5f27a1c0af..2428713c588 100644 --- a/nym-api/src/network_monitor/monitor/receiver.rs +++ b/nym-api/src/network_monitor/monitor/receiver.rs @@ -58,11 +58,12 @@ impl PacketReceiver { } pub(crate) async fn run(&mut self, shutdown_token: ShutdownToken) { - while !shutdown_token.is_cancelled() { + loop { tokio::select! 
{ biased; _ = shutdown_token.cancelled() => { trace!("UpdateHandler: Received shutdown"); + break; } // unwrap here is fine as it can only return a `None` if the PacketSender has died // and if that was the case, then the entire monitor is already in an undefined state diff --git a/nym-api/src/network_monitor/monitor/sender.rs b/nym-api/src/network_monitor/monitor/sender.rs index fa6ba454c9e..358fb8e3167 100644 --- a/nym-api/src/network_monitor/monitor/sender.rs +++ b/nym-api/src/network_monitor/monitor/sender.rs @@ -20,6 +20,7 @@ use nym_gateway_client::{ AcknowledgementReceiver, GatewayClient, MixnetMessageReceiver, PacketRouter, SharedGatewayKey, }; use nym_sphinx::forwarding::packet::MixPacket; +use nym_task::ShutdownToken; use pin_project::pin_project; use sqlx::__rt::timeout; use std::mem; @@ -91,6 +92,7 @@ impl GatewayPackets { struct FreshGatewayClientData { gateways_status_updater: GatewayClientUpdateSender, local_identity: Arc, + shutdown_token: ShutdownToken, gateway_response_timeout: Duration, bandwidth_controller: BandwidthController, disabled_credentials_mode: bool, @@ -127,11 +129,13 @@ impl PacketSender { gateways_status_updater: GatewayClientUpdateSender, local_identity: Arc, bandwidth_controller: BandwidthController, + shutdown_token: ShutdownToken, ) -> Self { PacketSender { fresh_gateway_client_data: Arc::new(FreshGatewayClientData { gateways_status_updater, local_identity, + shutdown_token, gateway_response_timeout: config.network_monitor.debug.gateway_response_timeout, bandwidth_controller, disabled_credentials_mode: config.network_monitor.debug.disabled_credentials_mode, @@ -154,10 +158,6 @@ impl PacketSender { GatewayClientHandle, (MixnetMessageReceiver, AcknowledgementReceiver), ) { - // I think the proper one should be passed around instead... 
- let task_client = - nym_task::TaskClient::dummy().named(format!("gateway-{}", config.gateway_identity)); - let (message_sender, message_receiver) = mpsc::unbounded(); // currently we do not care about acks at all, but we must keep the channel alive @@ -167,7 +167,7 @@ impl PacketSender { let gateway_packet_router = PacketRouter::new( ack_sender, message_sender, - task_client.fork("packet_router"), + fresh_gateway_client_data.shutdown_token.clone(), ); let shared_keys = fresh_gateway_client_data @@ -186,11 +186,11 @@ impl PacketSender { Some(fresh_gateway_client_data.bandwidth_controller.clone()), nym_statistics_common::clients::ClientStatsSender::new( None, - task_client.fork("client_stats_sender"), + fresh_gateway_client_data.shutdown_token.clone(), ), #[cfg(unix)] None, - task_client, + fresh_gateway_client_data.shutdown_token.clone(), ); ( diff --git a/nym-api/src/node_performance/contract_cache/mod.rs b/nym-api/src/node_performance/contract_cache/mod.rs index be430935a58..518f00a789f 100644 --- a/nym-api/src/node_performance/contract_cache/mod.rs +++ b/nym-api/src/node_performance/contract_cache/mod.rs @@ -45,7 +45,7 @@ pub(crate) async fn start_cache_refresher( .with_update_fn(move |main_cache, update| { refresher_update_fn(main_cache, update, values_to_retain) }) - .start(shutdown_manager.clone_token("performance-contract-cache-refresher")); + .start(shutdown_manager.clone_shutdown_token()); Ok(warmed_up_cache) } diff --git a/nym-api/src/node_status_api/cache/refresher.rs b/nym-api/src/node_status_api/cache/refresher.rs index ab323f918fe..a3af35d351b 100644 --- a/nym-api/src/node_status_api/cache/refresher.rs +++ b/nym-api/src/node_status_api/cache/refresher.rs @@ -69,11 +69,12 @@ impl NodeStatusCacheRefresher { pub async fn run(&mut self, shutdown_token: ShutdownToken) { let mut last_update = OffsetDateTime::now_utc(); let mut fallback_interval = time::interval(self.fallback_caching_interval); - while !shutdown_token.is_cancelled() { + loop { tokio::select! 
{ biased; _ = shutdown_token.cancelled() => { trace!("NodeStatusCacheRefresher: Received shutdown"); + break; } // Update node status cache when the contract cache / describe cache is updated Ok(_) = self.mixnet_contract_cache_listener.changed() => { @@ -81,6 +82,7 @@ impl NodeStatusCacheRefresher { _ = self.maybe_refresh(&mut fallback_interval, &mut last_update) => (), _ = shutdown_token.cancelled() => { trace!("NodeStatusCacheRefresher: Received shutdown"); + break; } } } @@ -89,6 +91,7 @@ impl NodeStatusCacheRefresher { _ = self.maybe_refresh(&mut fallback_interval, &mut last_update) => (), _ = shutdown_token.cancelled() => { trace!("NodeStatusCacheRefresher: Received shutdown"); + break; } } } @@ -99,6 +102,7 @@ impl NodeStatusCacheRefresher { _ = self.maybe_refresh(&mut fallback_interval, &mut last_update) => (), _ = shutdown_token.cancelled() => { trace!("NodeStatusCacheRefresher: Received shutdown"); + break; } } } diff --git a/nym-api/src/node_status_api/mod.rs b/nym-api/src/node_status_api/mod.rs index 4473e0c1551..7820d9a9536 100644 --- a/nym-api/src/node_status_api/mod.rs +++ b/nym-api/src/node_status_api/mod.rs @@ -50,6 +50,6 @@ pub(crate) fn start_cache_refresh( described_cache_cache_listener, performance_provider, ); - let shutdown_listener = shutdown_manager.clone_token("node-status-refresher"); + let shutdown_listener = shutdown_manager.clone_shutdown_token(); tokio::spawn(async move { nym_api_cache_refresher.run(shutdown_listener).await }); } diff --git a/nym-api/src/node_status_api/uptime_updater.rs b/nym-api/src/node_status_api/uptime_updater.rs index c39d4785b53..bfa1d661014 100644 --- a/nym-api/src/node_status_api/uptime_updater.rs +++ b/nym-api/src/node_status_api/uptime_updater.rs @@ -98,11 +98,12 @@ impl HistoricalUptimeUpdater { let start = Instant::now() + time_left; let mut interval = interval_at(start, ONE_DAY); - while !shutdown_token.is_cancelled() { + loop { tokio::select! { biased; _ = shutdown_token.cancelled() => { trace!("UpdateHandler: Received shutdown"); + break; } _ = interval.tick() => { info!("updating historical uptimes of nodes"); @@ -118,7 +119,7 @@ impl HistoricalUptimeUpdater { pub(crate) fn start(storage: NymApiStorage, shutdown: &ShutdownManager) { let uptime_updater = HistoricalUptimeUpdater::new(storage); - let shutdown_listener = shutdown.child_token("uptime-updater"); + let shutdown_listener = shutdown.child_shutdown_token(); tokio::spawn(async move { uptime_updater.run(shutdown_listener).await }); } } diff --git a/nym-api/src/signers_cache/mod.rs b/nym-api/src/signers_cache/mod.rs index c902a4d52c1..5b647f8699b 100644 --- a/nym-api/src/signers_cache/mod.rs +++ b/nym-api/src/signers_cache/mod.rs @@ -23,7 +23,7 @@ pub(crate) fn start_refresher( .named("signers-cache-refresher"); let shared_cache = refresher.get_shared_cache(); refresher.start_with_delay( - shutdown_manager.clone_token("signers-cache-refresher"), + shutdown_manager.clone_shutdown_token(), config.debug.refresher_start_delay, ); shared_cache diff --git a/nym-api/src/support/caching/refresher.rs b/nym-api/src/support/caching/refresher.rs index 27d8480c035..2b21397c891 100644 --- a/nym-api/src/support/caching/refresher.rs +++ b/nym-api/src/support/caching/refresher.rs @@ -239,11 +239,12 @@ where self.provider.wait_until_ready().await; let mut refresh_interval = interval(self.refreshing_interval); - while !shutdown_token.is_cancelled() { + loop { tokio::select! 
{ biased; _ = shutdown_token.cancelled() => { - trace!("{}: Received shutdown", self.name) + trace!("{}: Received shutdown", self.name); + break } _ = refresh_interval.tick() => self.refresh(&shutdown_token).await, // note: `Notify` is not cancellation safe, HOWEVER, there's only one listener, diff --git a/nym-api/src/support/cli/run.rs b/nym-api/src/support/cli/run.rs index 0ae75fe23e6..6c8750149ad 100644 --- a/nym-api/src/support/cli/run.rs +++ b/nym-api/src/support/cli/run.rs @@ -120,7 +120,7 @@ pub(crate) struct Args { } async fn start_nym_api_tasks(config: &Config) -> anyhow::Result { - let shutdown_manager = ShutdownManager::new("nym-api") + let shutdown_manager = ShutdownManager::build_new_default()? .with_shutdown_duration(Duration::from_secs(TASK_MANAGER_TIMEOUT_S)); let nyxd_client = nyxd::Client::new(config)?; @@ -256,8 +256,8 @@ async fn start_nym_api_tasks(config: &Config) -> anyhow::Result let describe_cache_refresh_requester = describe_cache_refresher.refresh_requester(); - let describe_cache_watcher = describe_cache_refresher - .start_with_watcher(shutdown_manager.clone_token("node-self-described-data-refresher")); + let describe_cache_watcher = + describe_cache_refresher.start_with_watcher(shutdown_manager.clone_shutdown_token()); let performance_provider = if config.performance_provider.use_performance_contract_data { if network_details @@ -289,8 +289,8 @@ async fn start_nym_api_tasks(config: &Config) -> anyhow::Result }; // start all the caches first - let contract_cache_watcher = mixnet_contract_cache_refresher - .start_with_watcher(shutdown_manager.clone_token("contracts-data-refresher")); + let contract_cache_watcher = + mixnet_contract_cache_refresher.start_with_watcher(shutdown_manager.clone_shutdown_token()); node_status_api::start_cache_refresh( &config.node_status_api, @@ -359,12 +359,12 @@ async fn start_nym_api_tasks(config: &Config) -> anyhow::Result contract_cache_watcher, mixnet_contract_cache_state, ) - .start(shutdown_manager.clone_token("KeyRotationController")); + .start(shutdown_manager.clone_shutdown_token()); let bind_address = config.base.bind_address.to_owned(); let server = router.build_server(&bind_address).await?; - let http_shutdown = shutdown_manager.clone_token("axum-http"); + let http_shutdown = shutdown_manager.clone_shutdown_token(); tokio::spawn(async move { { info!("Started Axum HTTP V2 server on {bind_address}"); @@ -372,7 +372,7 @@ async fn start_nym_api_tasks(config: &Config) -> anyhow::Result } }); - shutdown_manager.close(); + shutdown_manager.close_tracker(); Ok(shutdown_manager) } @@ -385,7 +385,7 @@ pub(crate) async fn execute(args: Args) -> anyhow::Result<()> { config.validate()?; - let shutdown_manager = start_nym_api_tasks(&config).await?; + let mut shutdown_manager = start_nym_api_tasks(&config).await?; shutdown_manager.run_until_shutdown().await; Ok(()) diff --git a/nym-node/Cargo.toml b/nym-node/Cargo.toml index 0dceb6aedbe..d13237cab1a 100644 --- a/nym-node/Cargo.toml +++ b/nym-node/Cargo.toml @@ -40,6 +40,7 @@ tracing-indicatif = { workspace = true } tracing-subscriber.workspace = true tokio = { workspace = true, features = ["macros", "sync", "rt-multi-thread"] } tokio-util = { workspace = true, features = ["codec"] } +tokio-stream = { workspace = true } toml = { workspace = true } url = { workspace = true, features = ["serde"] } zeroize = { workspace = true, features = ["zeroize_derive"] } @@ -129,7 +130,7 @@ criterion = { workspace = true, features = ["async_tokio"] } rand_chacha = { workspace = true } [features] 
-tokio-console = ["console-subscriber"] +tokio-console = ["console-subscriber", "nym-task/tokio-tracing"] [lints] workspace = true diff --git a/nym-node/src/config/helpers.rs b/nym-node/src/config/helpers.rs index 53769f9d809..cdbffb0f257 100644 --- a/nym-node/src/config/helpers.rs +++ b/nym-node/src/config/helpers.rs @@ -159,7 +159,6 @@ pub fn gateway_tasks_config(config: &Config) -> GatewayTasksConfig { .to_common_client_paths(), ip_packet_router_description: Default::default(), }, - logging: config.logging, }, custom_mixnet_path: None, diff --git a/nym-node/src/node/helpers.rs b/nym-node/src/node/helpers.rs index bdb00df7b1f..dbd8b15086f 100644 --- a/nym-node/src/node/helpers.rs +++ b/nym-node/src/node/helpers.rs @@ -187,7 +187,7 @@ pub(crate) async fn get_current_rotation_id( nym_apis: &[Url], fallback_nyxd: &[Url], ) -> Result { - let apis_client = NymApisClient::new(nym_apis, ShutdownToken::ephemeral())?; + let apis_client = NymApisClient::new(nym_apis, ShutdownToken::default())?; if let Ok(rotation_info) = apis_client.get_key_rotation_info().await.map(|r| r.details) { if rotation_info.is_epoch_stuck() { return Err(NymNodeError::StuckEpoch); diff --git a/nym-node/src/node/key_rotation/controller.rs b/nym-node/src/node/key_rotation/controller.rs index 0aa8aa5fa56..c7aee7d671f 100644 --- a/nym-node/src/node/key_rotation/controller.rs +++ b/nym-node/src/node/key_rotation/controller.rs @@ -349,7 +349,7 @@ impl KeyRotationController { let state_update_future = sleep(next_action.until_deadline()); pin_mut!(state_update_future); - while !self.shutdown_token.is_cancelled() { + loop { tokio::select! { biased; _ = self.shutdown_token.cancelled() => { diff --git a/nym-node/src/node/metrics/aggregator.rs b/nym-node/src/node/metrics/aggregator.rs index 7d67e925c20..3be11cedcf7 100644 --- a/nym-node/src/node/metrics/aggregator.rs +++ b/nym-node/src/node/metrics/aggregator.rs @@ -12,7 +12,6 @@ use std::any::TypeId; use std::collections::HashMap; use std::ops::DerefMut; use std::time::Duration; -use tokio::task::JoinHandle; use tokio::time::{interval_at, Instant}; use tracing::{debug, error, trace, warn}; @@ -25,11 +24,10 @@ pub(crate) struct MetricsAggregator { // registered_handlers: HashMap>, event_sender: MetricEventsSender, event_receiver: MetricEventsReceiver, - shutdown: ShutdownToken, } impl MetricsAggregator { - pub fn new(handlers_update_interval: Duration, shutdown: ShutdownToken) -> Self { + pub fn new(handlers_update_interval: Duration) -> Self { let (event_sender, event_receiver) = events_channels(); MetricsAggregator { @@ -37,7 +35,6 @@ impl MetricsAggregator { registered_handlers: Default::default(), event_sender, event_receiver, - shutdown, } } @@ -106,7 +103,7 @@ impl MetricsAggregator { } } - pub async fn run(&mut self) { + pub async fn run(&mut self, shutdown_token: ShutdownToken) { self.on_start().await; let start = Instant::now() + self.handlers_update_interval; @@ -117,7 +114,7 @@ impl MetricsAggregator { loop { tokio::select! 
{ biased; - _ = self.shutdown.cancelled() => { + _ = shutdown_token.cancelled() => { debug!("MetricsAggregator: Received shutdown"); break; } @@ -144,8 +141,4 @@ impl MetricsAggregator { } trace!("MetricsAggregator: Exiting"); } - - pub fn start(mut self) -> JoinHandle<()> { - tokio::spawn(async move { self.run().await }) - } } diff --git a/nym-node/src/node/metrics/console_logger.rs b/nym-node/src/node/metrics/console_logger.rs index c7db72960b2..35b25ff153b 100644 --- a/nym-node/src/node/metrics/console_logger.rs +++ b/nym-node/src/node/metrics/console_logger.rs @@ -1,15 +1,15 @@ // Copyright 2024 - Nym Technologies SA // SPDX-License-Identifier: GPL-3.0-only +use futures::StreamExt; use human_repr::HumanCount; use human_repr::HumanThroughput; use nym_node_metrics::NymNodeMetrics; -use nym_task::ShutdownToken; use std::time::Duration; use time::OffsetDateTime; -use tokio::task::JoinHandle; use tokio::time::{interval_at, Instant}; -use tracing::{info, trace}; +use tokio_stream::wrappers::IntervalStream; +use tracing::{error, info, trace}; struct AtLastUpdate { time: OffsetDateTime, @@ -49,20 +49,14 @@ pub(crate) struct ConsoleLogger { logging_delay: Duration, at_last_update: AtLastUpdate, metrics: NymNodeMetrics, - shutdown: ShutdownToken, } impl ConsoleLogger { - pub(crate) fn new( - logging_delay: Duration, - metrics: NymNodeMetrics, - shutdown: ShutdownToken, - ) -> Self { + pub(crate) fn new(logging_delay: Duration, metrics: NymNodeMetrics) -> Self { ConsoleLogger { logging_delay, at_last_update: AtLastUpdate::new(), metrics, - shutdown, } } @@ -123,23 +117,19 @@ impl ConsoleLogger { // TODO: add websocket-client traffic } - async fn run(&mut self) { + pub(crate) async fn run(&mut self) { trace!("Starting ConsoleLogger"); - let mut interval = interval_at(Instant::now() + self.logging_delay, self.logging_delay); - loop { - tokio::select! 
{ - biased; - _ = self.shutdown.cancelled() => { - trace!("ConsoleLogger: Received shutdown"); - break - } - _ = interval.tick() => self.log_running_stats().await, - }; + + let mut stream = IntervalStream::new(interval_at( + Instant::now() + self.logging_delay, + self.logging_delay, + )); + + while stream.next().await.is_some() { + self.log_running_stats().await } - trace!("ConsoleLogger: Exiting"); - } - pub(crate) fn start(mut self) -> JoinHandle<()> { - tokio::spawn(async move { self.run().await }) + // this should never get triggered + error!("console logger interval has been exhausted!") } } diff --git a/nym-node/src/node/mixnet/handler.rs b/nym-node/src/node/mixnet/handler.rs index aca227455a5..6a14277c1df 100644 --- a/nym-node/src/node/mixnet/handler.rs +++ b/nym-node/src/node/mixnet/handler.rs @@ -91,7 +91,6 @@ impl Drop for ConnectionHandler { impl ConnectionHandler { pub(crate) fn new(shared: &SharedData, remote_address: SocketAddr) -> Self { - let shutdown = shared.shutdown.child_token(remote_address.to_string()); shared.metrics.network.new_active_ingress_mixnet_client(); ConnectionHandler { @@ -103,7 +102,7 @@ impl ConnectionHandler { final_hop: shared.final_hop.clone(), noise_config: shared.noise_config.clone(), metrics: shared.metrics.clone(), - shutdown, + shutdown_token: shared.shutdown_token.child_token(), }, remote_address, pending_packets: PendingReplayCheckPackets::new(), @@ -369,7 +368,7 @@ impl ConnectionHandler { Some(Err(_)) => { // our mutex got poisoned - we have to shut down error!("CRITICAL FAILURE: replay bloomfilter mutex poisoning!"); - self.shared.shutdown.cancel(); + self.shared.shutdown_token.cancel(); return false; } }; @@ -394,7 +393,7 @@ impl ConnectionHandler { else { // our mutex got poisoned - we have to shut down error!("CRITICAL FAILURE: replay bloomfilter mutex poisoning!"); - self.shared.shutdown.cancel(); + self.shared.shutdown_token.cancel(); return; }; @@ -489,7 +488,7 @@ impl ConnectionHandler { loop { tokio::select! { biased; - _ = self.shared.shutdown.cancelled() => { + _ = self.shared.shutdown_token.cancelled() => { trace!("connection handler: received shutdown"); break } diff --git a/nym-node/src/node/mixnet/listener.rs b/nym-node/src/node/mixnet/listener.rs index f5900414b08..7fdca155a71 100644 --- a/nym-node/src/node/mixnet/listener.rs +++ b/nym-node/src/node/mixnet/listener.rs @@ -4,12 +4,10 @@ use crate::node::mixnet::SharedData; use nym_task::ShutdownToken; use std::net::SocketAddr; -use tokio::task::JoinHandle; use tracing::{debug, error, info, trace}; pub(crate) struct Listener { bind_address: SocketAddr, - shutdown: ShutdownToken, shared_data: SharedData, } @@ -17,19 +15,18 @@ impl Listener { pub(crate) fn new(bind_address: SocketAddr, shared_data: SharedData) -> Self { Listener { bind_address, - shutdown: shared_data.shutdown.clone_with_suffix("socket-listener"), shared_data, } } - pub(crate) async fn run(&mut self) { + pub(crate) async fn run(&mut self, shutdown: ShutdownToken) { info!("attempting to run mixnet listener on {}", self.bind_address); let tcp_listener = match tokio::net::TcpListener::bind(self.bind_address).await { Ok(listener) => listener, Err(err) => { error!("Failed to bind to {}: {err}. Are you sure nothing else is running on the specified port and your user has sufficient permission to bind to the requested address?", self.bind_address); - self.shutdown.cancel(); + shutdown.cancel(); return; } }; @@ -37,7 +34,7 @@ impl Listener { loop { tokio::select! 
{ biased; - _ = self.shutdown.cancelled() => { + _ = shutdown.cancelled() => { trace!("mixnet listener: received shutdown"); break } @@ -48,8 +45,4 @@ impl Listener { } debug!("mixnet socket listener: Exiting"); } - - pub(crate) fn start(mut self) -> JoinHandle<()> { - tokio::spawn(async move { self.run().await }) - } } diff --git a/nym-node/src/node/mixnet/packet_forwarding/mod.rs b/nym-node/src/node/mixnet/packet_forwarding/mod.rs index e397124ee08..d76e4cea555 100644 --- a/nym-node/src/node/mixnet/packet_forwarding/mod.rs +++ b/nym-node/src/node/mixnet/packet_forwarding/mod.rs @@ -26,16 +26,10 @@ pub struct PacketForwarder { packet_sender: MixForwardingSender, packet_receiver: MixForwardingReceiver, - shutdown: ShutdownToken, } impl PacketForwarder { - pub fn new( - client: C, - routing_filter: F, - metrics: NymNodeMetrics, - shutdown: ShutdownToken, - ) -> Self { + pub fn new(client: C, routing_filter: F, metrics: NymNodeMetrics) -> Self { let (packet_sender, packet_receiver) = mix_forwarding_channels(); PacketForwarder { @@ -45,7 +39,6 @@ impl PacketForwarder { routing_filter, packet_sender, packet_receiver, - shutdown, } } @@ -127,7 +120,7 @@ impl PacketForwarder { .update_packet_forwarder_queue_size(channel_size) } - pub async fn run(&mut self) + pub async fn run(&mut self, shutdown_token: ShutdownToken) where C: SendWithoutResponse, F: RoutingFilter, @@ -137,7 +130,7 @@ impl PacketForwarder { loop { tokio::select! { biased; - _ = self.shutdown.cancelled() => { + _ = shutdown_token.cancelled() => { debug!("PacketForwarder: Received shutdown"); break; } diff --git a/nym-node/src/node/mixnet/shared/mod.rs b/nym-node/src/node/mixnet/shared/mod.rs index 7dee952d44f..67dea46a7a8 100644 --- a/nym-node/src/node/mixnet/shared/mod.rs +++ b/nym-node/src/node/mixnet/shared/mod.rs @@ -63,7 +63,7 @@ impl ProcessingConfig { } } -// explicitly do NOT derive clone as we want to manually apply relevant suffixes to the task clients +// explicitly do NOT derive clone as we want the childs to use CHILD shutdown tokens pub(crate) struct SharedData { pub(super) processing_config: ProcessingConfig, pub(super) sphinx_keys: ActiveSphinxKeys, @@ -79,7 +79,8 @@ pub(crate) struct SharedData { pub(super) noise_config: NoiseConfig, pub(super) metrics: NymNodeMetrics, - pub(super) shutdown: ShutdownToken, + + pub(super) shutdown_token: ShutdownToken, } fn convert_to_metrics_version(processed: MixPacketVersion) -> PacketKind { @@ -99,7 +100,7 @@ impl SharedData { final_hop: SharedFinalHopData, noise_config: NoiseConfig, metrics: NymNodeMetrics, - shutdown: ShutdownToken, + shutdown_token: ShutdownToken, ) -> Self { SharedData { processing_config, @@ -109,7 +110,7 @@ impl SharedData { final_hop, noise_config, metrics, - shutdown, + shutdown_token, } } @@ -188,10 +189,10 @@ impl SharedData { .mixnet_forwarder .forward_packet(PacketToForward::new(packet, delay_until)) .is_err() - && !self.shutdown.is_cancelled() + && !self.shutdown_token.is_cancelled() { error!("failed to forward sphinx packet on the channel while the process is not going through the shutdown!"); - self.shutdown.cancel(); + self.shutdown_token.cancel(); } } diff --git a/nym-node/src/node/mod.rs b/nym-node/src/node/mod.rs index 2fca2469979..fc3ad47599e 100644 --- a/nym-node/src/node/mod.rs +++ b/nym-node/src/node/mod.rs @@ -54,7 +54,7 @@ use nym_noise::config::{NoiseConfig, NoiseNetworkView}; use nym_noise_keys::VersionedNoiseKey; use nym_sphinx_acknowledgements::AckKey; use nym_sphinx_addressing::Recipient; -use nym_task::{ShutdownManager, 
ShutdownToken, TaskClient}; +use nym_task::{ShutdownManager, ShutdownToken, ShutdownTracker}; use nym_validator_client::UserAgent; use nym_verloc::measurements::SharedVerlocStats; use nym_verloc::{self, measurements::VerlocMeasurer}; @@ -465,19 +465,21 @@ impl NymNode { wireguard: Some(wireguard_data), config, accepted_operator_terms_and_conditions: false, - shutdown_manager: ShutdownManager::new("NymNode") - .with_legacy_task_manager() - .with_default_shutdown_signals() + shutdown_manager: ShutdownManager::build_new_default() .map_err(|source| NymNodeError::ShutdownSignalFailure { source })?, }) } - pub(crate) fn config(&self) -> &Config { - &self.config + pub(crate) fn shutdown_tracker(&self) -> &ShutdownTracker { + self.shutdown_manager.shutdown_tracker() } - pub(crate) fn shutdown_token>(&self, child_suffix: S) -> ShutdownToken { - self.shutdown_manager.clone_token(child_suffix) + pub(crate) fn shutdown_token(&self) -> ShutdownToken { + self.shutdown_manager.clone_shutdown_token() + } + + pub(crate) fn config(&self) -> &Config { + &self.config } pub(crate) fn with_accepted_operator_terms_and_conditions( @@ -561,7 +563,7 @@ impl NymNode { self.config.mixnet.nym_api_urls.clone(), self.config.debug.topology_cache_ttl, self.config.debug.routing_nodes_check_interval, - self.shutdown_manager.clone_token("network-refresher"), + self.shutdown_manager.clone_shutdown_token(), ) .await } @@ -605,8 +607,6 @@ impl NymNode { metrics_sender: MetricEventsSender, active_clients_store: ActiveClientsStore, mix_packet_sender: MixForwardingSender, - legacy_task_client: TaskClient, - shutdown_token: ShutdownToken, ) -> Result<(), NymNodeError> { let config = gateway_tasks_config(&self.config); @@ -624,8 +624,7 @@ impl NymNode { metrics_sender, self.metrics.clone(), self.entry_gateway.mnemonic.clone(), - legacy_task_client, - shutdown_token, + self.shutdown_tracker().clone(), ); // if we're running in entry mode, start the websocket @@ -634,10 +633,11 @@ impl NymNode { "starting the clients websocket... 
on {}", self.config.gateway_tasks.ws_bind_address ); - let websocket = gateway_tasks_builder + let mut websocket = gateway_tasks_builder .build_websocket_listener(active_clients_store.clone()) .await?; - websocket.start(); + self.shutdown_tracker() + .try_spawn_named(async move { websocket.run().await }, "EntryWebsocket"); } else { info!("node not running in entry mode: the websocket will remain closed"); } @@ -697,8 +697,12 @@ impl NymNode { } // start task for removing stale and un-retrieved client messages - let stale_messages_cleaner = gateway_tasks_builder.build_stale_messages_cleaner(); - stale_messages_cleaner.start(); + let mut stale_messages_cleaner = gateway_tasks_builder.build_stale_messages_cleaner(); + let shutdown_token = self.shutdown_token(); + self.shutdown_tracker().try_spawn_named( + async move { stale_messages_cleaner.run(shutdown_token).await }, + "StaleMessagesCleaner", + ); Ok(()) } @@ -875,25 +879,23 @@ impl NymNode { let mut verloc_measurer = VerlocMeasurer::new( config, self.ed25519_identity_keys.clone(), - self.shutdown_manager.clone_token("verloc"), + self.shutdown_manager.clone_shutdown_token(), ); verloc_measurer.set_shared_state(self.verloc_stats.clone()); - tokio::spawn(async move { verloc_measurer.run().await }); + self.shutdown_manager + .try_spawn_named(async move { verloc_measurer.run().await }, "VerlocMeasurer"); } pub(crate) fn setup_metrics_backend( &self, active_clients_store: ActiveClientsStore, active_egress_mixnet_connections: ActiveConnections, - shutdown: ShutdownToken, ) -> MetricEventsSender { info!("setting up node metrics..."); // aggregator (to listen for any metrics events) - let mut metrics_aggregator = MetricsAggregator::new( - self.config.metrics.debug.aggregator_update_rate, - shutdown.clone_with_suffix("aggregator"), - ); + let mut metrics_aggregator = + MetricsAggregator::new(self.config.metrics.debug.aggregator_update_rate); // >>>> START: register all relevant handlers for custom events @@ -950,18 +952,25 @@ impl NymNode { // console logger to preserve old mixnode functionalities if self.config.metrics.debug.log_stats_to_console { - ConsoleLogger::new( + let mut console_logger = ConsoleLogger::new( self.config.metrics.debug.console_logging_update_interval, self.metrics.clone(), - shutdown.clone_with_suffix("metrics-console-logger"), - ) - .start(); + ); + + self.shutdown_tracker().try_spawn_named_with_shutdown( + async move { console_logger.run().await }, + "ConsoleLogger", + ); } let events_sender = metrics_aggregator.sender(); // spawn the aggregator task - metrics_aggregator.start(); + let shutdown_token = self.shutdown_token(); + self.shutdown_tracker().try_spawn_named( + async move { metrics_aggregator.run(shutdown_token).await }, + "MetricsAggregator", + ); events_sender } @@ -983,14 +992,15 @@ impl NymNode { sphinx_keys.keys.primary_key_rotation_id(), sphinx_keys.keys.secondary_key_rotation_id(), self.metrics.clone(), - self.shutdown_manager - .clone_token("replay-detection-background-flush"), + self.shutdown_manager.clone_shutdown_token(), ) .await?; let bloomfilters_manager = replay_detection_background.bloomfilters_manager(); - self.shutdown_manager - .spawn(async move { replay_detection_background.run().await }); + self.shutdown_manager.try_spawn_named( + async move { replay_detection_background.run().await }, + "ReplayDetection", + ); Ok(bloomfilters_manager) } @@ -998,7 +1008,7 @@ impl NymNode { fn setup_nym_apis_client(&self) -> Result { NymApisClient::new( &self.config.mixnet.nym_api_urls, - 
self.shutdown_manager.clone_token("nym-apis-client"), + self.shutdown_manager.clone_shutdown_token(), ) } @@ -1029,7 +1039,7 @@ impl NymNode { nym_apis_client, replay_protection_manager, managed_keys, - self.shutdown_manager.clone_token("key-rotation-controller"), + self.shutdown_manager.clone_shutdown_token(), ); rotation_controller.start(); @@ -1042,7 +1052,6 @@ impl NymNode { replay_protection_bloomfilter: ReplayProtectionBloomfilters, routing_filter: F, noise_config: NoiseConfig, - shutdown: ShutdownToken, ) -> Result<(MixForwardingSender, ActiveConnections), NymNodeError> where F: RoutingFilter + Send + Sync + 'static, @@ -1073,14 +1082,16 @@ impl NymNode { ); let active_connections = mixnet_client.active_connections(); - let mut packet_forwarder = PacketForwarder::new( - mixnet_client, - routing_filter, - self.metrics.clone(), - shutdown.clone_with_suffix("mix-packet-forwarder"), - ); + let mut packet_forwarder = + PacketForwarder::new(mixnet_client, routing_filter, self.metrics.clone()); let mix_packet_sender = packet_forwarder.sender(); - tokio::spawn(async move { packet_forwarder.run().await }); + + let shutdown_token = self.shutdown_token(); + + self.shutdown_tracker().try_spawn_named( + async move { packet_forwarder.run(shutdown_token).await }, + "PacketForwarder", + ); let final_hop_data = SharedFinalHopData::new( active_clients_store.clone(), @@ -1095,14 +1106,21 @@ impl NymNode { final_hop_data, noise_config, self.metrics.clone(), - shutdown, + self.shutdown_token(), + ); + + let mut mixnet_listener = mixnet::Listener::new(self.config.mixnet.bind_address, shared); + + let shutdown_token = self.shutdown_token(); + self.shutdown_tracker().try_spawn_named( + async move { mixnet_listener.run(shutdown_token).await }, + "MixnetListener", ); - mixnet::Listener::new(self.config.mixnet.bind_address, shared).start(); Ok((mix_packet_sender, active_connections)) } - pub(crate) async fn run_minimal_mixnet_processing(self) -> Result<(), NymNodeError> { + pub(crate) async fn run_minimal_mixnet_processing(mut self) -> Result<(), NymNodeError> { let noise_config = nym_noise::config::NoiseConfig::new( self.x25519_noise_keys.clone(), NoiseNetworkView::new_empty(), @@ -1115,11 +1133,10 @@ impl NymNode { ReplayProtectionBloomfilters::new_disabled(), OpenFilter, noise_config, - self.shutdown_manager.clone_token("mixnet-traffic"), ) .await?; - self.shutdown_manager.close(); + self.shutdown_manager.close_tracker(); self.shutdown_manager.run_until_shutdown().await; Ok(()) @@ -1137,16 +1154,17 @@ impl NymNode { let http_server = self.build_http_server().await?; let bind_address = self.config.http.bind_address; - let server_shutdown = self.shutdown_manager.clone_token("http-server"); + let server_shutdown = self.shutdown_manager.clone_shutdown_token(); - self.shutdown_manager.spawn(async move { - { + self.shutdown_manager.try_spawn_named( + async move { info!("starting NymNodeHTTPServer on {bind_address}"); http_server .with_graceful_shutdown(async move { server_shutdown.cancelled().await }) .await - } - }); + }, + "HttpApi", + ); let nym_apis_client = self.setup_nym_apis_client()?; @@ -1172,14 +1190,12 @@ impl NymNode { bloomfilters_manager.bloomfilters(), network_refresher.routing_filter(), noise_config, - self.shutdown_manager.clone_token("mixnet-traffic"), ) .await?; let metrics_sender = self.setup_metrics_backend( active_clients_store.clone(), active_egress_mixnet_connections, - self.shutdown_manager.clone_token("metrics"), ); self.start_gateway_tasks( @@ -1187,8 +1203,6 @@ impl NymNode { 
metrics_sender, active_clients_store, mix_packet_sender, - self.shutdown_manager.subscribe_legacy("gateway-tasks"), - self.shutdown_manager.child_token("gateway-tasks"), ) .await?; @@ -1196,7 +1210,7 @@ impl NymNode { .await?; network_refresher.start(); - self.shutdown_manager.close(); + self.shutdown_manager.close_tracker(); Ok(self.shutdown_manager) } diff --git a/nym-node/src/node/shared_network.rs b/nym-node/src/node/shared_network.rs index 7fb88b31b0a..f6a8164862a 100644 --- a/nym-node/src/node/shared_network.rs +++ b/nym-node/src/node/shared_network.rs @@ -378,11 +378,12 @@ impl NetworkRefresher { let mut pending_check_interval = interval(self.pending_check_interval); pending_check_interval.reset(); - while !self.shutdown_token.is_cancelled() { + loop { tokio::select! { biased; _ = self.shutdown_token.cancelled() => { - trace!("NetworkRefresher: Received shutdown"); + trace!("NetworkRefresher: Received shutdown"); + break; } _ = pending_check_interval.tick() => { self.inspect_pending().await; diff --git a/nym-node/src/throughput_tester/mod.rs b/nym-node/src/throughput_tester/mod.rs index aaa23607269..81b8f38c4de 100644 --- a/nym-node/src/throughput_tester/mod.rs +++ b/nym-node/src/throughput_tester/mod.rs @@ -132,7 +132,7 @@ pub(crate) fn test_mixing_throughput( let mut tasks_handles = Vec::new(); for (sender_id, stats) in stats.iter().enumerate() { - let token = nym_node.shutdown_token(format!("dummy-load-client-{sender_id}")); + let token = nym_node.shutdown_token(); let client_future = run_testing_client( sender_id, @@ -152,7 +152,7 @@ pub(crate) fn test_mixing_throughput( header_span, stats, output_directory, - nym_node.shutdown_token("global-stats"), + nym_node.shutdown_token(), ); let stats_handle = tester.clients_runtime.spawn(async move { diff --git a/nym-signers-monitor/src/monitor.rs b/nym-signers-monitor/src/monitor.rs index caae9cbe1f6..e1136f77183 100644 --- a/nym-signers-monitor/src/monitor.rs +++ b/nym-signers-monitor/src/monitor.rs @@ -190,15 +190,15 @@ impl SignersMonitor { } pub(crate) async fn run(&mut self) -> anyhow::Result<()> { - let shutdown_manager = - ShutdownManager::new("nym-signers-monitor").with_default_shutdown_signals()?; + let mut shutdown_manager = ShutdownManager::build_new_default()?; + let root_token = shutdown_manager.clone_shutdown_token(); let mut check_interval = interval(self.check_interval); - while !shutdown_manager.root_token.is_cancelled() { + while !root_token.is_cancelled() { tokio::select! 
{ biased; - _ = shutdown_manager.root_token.cancelled() => { + _ = root_token.cancelled() => { info!("received shutdown"); break; } @@ -211,7 +211,6 @@ impl SignersMonitor { } } - shutdown_manager.close(); shutdown_manager.run_until_shutdown().await; if let Err(err) = self.send_shutdown_notification().await { diff --git a/nym-statistics-api/src/main.rs b/nym-statistics-api/src/main.rs index 4af7c83a363..c3e8ec919bb 100644 --- a/nym-statistics-api/src/main.rs +++ b/nym-statistics-api/src/main.rs @@ -27,19 +27,17 @@ async fn main() -> anyhow::Result<()> { .await?; tracing::info!("Connection to database successful"); - let shutdown_manager = ShutdownManager::new("nym-statistics-api"); + let mut shutdown_manager = ShutdownManager::build_new_default()?; - let network_refresher = NetworkRefresher::initialise_new( - args.nym_api_url, - shutdown_manager.child_token("network-refresher"), - ) - .await; + let network_refresher = + NetworkRefresher::initialise_new(args.nym_api_url, shutdown_manager.child_shutdown_token()) + .await; let http_server = http::server::build_http_api(storage, network_refresher.network_view(), args.http_port) .await .expect("Failed to build http server"); - let server_shutdown = shutdown_manager.clone_token("http-api-server"); + let server_shutdown = shutdown_manager.clone_shutdown_token(); // Starting tasks shutdown_manager.spawn(async move { http_server.run(server_shutdown).await }); @@ -47,7 +45,7 @@ async fn main() -> anyhow::Result<()> { tracing::info!("Started HTTP server on port {}", args.http_port); - shutdown_manager.close(); + shutdown_manager.close_tracker(); shutdown_manager.run_until_shutdown().await; Ok(()) diff --git a/nym-statistics-api/src/network_view.rs b/nym-statistics-api/src/network_view.rs index ca8352d6b0c..c8e41f9623a 100644 --- a/nym-statistics-api/src/network_view.rs +++ b/nym-statistics-api/src/network_view.rs @@ -142,11 +142,12 @@ impl NetworkRefresher { let mut full_refresh_interval = interval(self.full_refresh_interval); full_refresh_interval.reset(); - while !self.shutdown_token.is_cancelled() { + loop { tokio::select! 
{ biased; _ = self.shutdown_token.cancelled() => { trace!("NetworkRefresher: Received shutdown"); + break; } _ = full_refresh_interval.tick() => { if self.refresh_network_nodes().await.is_err() { diff --git a/nym-validator-rewarder/src/error.rs b/nym-validator-rewarder/src/error.rs index 48150139dfe..211d9b05221 100644 --- a/nym-validator-rewarder/src/error.rs +++ b/nym-validator-rewarder/src/error.rs @@ -15,6 +15,9 @@ use thiserror::Error; #[derive(Debug, Error)] pub enum NymRewarderError { + #[error(transparent)] + IoFailure(#[from] io::Error), + #[error("experienced internal database error: {0}")] InternalDatabaseError(#[from] sqlx::Error), diff --git a/nym-validator-rewarder/src/rewarder/mod.rs b/nym-validator-rewarder/src/rewarder/mod.rs index ed38ba1d3c7..3674989fa2d 100644 --- a/nym-validator-rewarder/src/rewarder/mod.rs +++ b/nym-validator-rewarder/src/rewarder/mod.rs @@ -14,12 +14,11 @@ use futures::future::{FusedFuture, OptionFuture}; use futures::FutureExt; use nym_crypto::asymmetric::ed25519; use nym_ecash_time::{ecash_today, ecash_today_date, EcashTime}; -use nym_task::TaskManager; +use nym_task::ShutdownManager; use nym_validator_client::nyxd::{AccountId, Coin, Hash}; use nyxd_scraper::NyxdScraper; use std::sync::Arc; use time::Date; -use tokio::pin; use tracing::{error, info, instrument, warn}; pub(crate) mod block_signing; @@ -540,7 +539,7 @@ impl Rewarder { async fn main_loop( mut self, - mut task_manager: TaskManager, + mut shutdown_manager: ShutdownManager, mut scraper_cancellation: impl FusedFuture + Unpin, ) { let mut block_signing_epoch_ticker = self @@ -550,17 +549,13 @@ impl Rewarder { // runs daily let mut ticketbook_issuance_ticker = end_of_day_ticker(); - let shutdown_future = task_manager.catch_interrupt(); - pin!(shutdown_future); + let mut shutdown_signals = shutdown_manager.detach_shutdown_signals(); loop { tokio::select! 
{ biased; - interrupt_res = &mut shutdown_future => { + _ = shutdown_signals.wait_for_signal() => { info!("received interrupt"); - if let Err(err) = interrupt_res { - error!("runtime interrupt failure: {err}") - } break; } _ = &mut scraper_cancellation, if !scraper_cancellation.is_terminated() => { @@ -575,13 +570,19 @@ impl Rewarder { if let Some(epoch_signing) = self.epoch_signing { epoch_signing.nyxd_scraper.stop().await; } + + // in case we received cancellation from the scraper, kill other tasks (currently none) + if !shutdown_manager.is_cancelled() { + shutdown_manager.send_cancellation(); + } + shutdown_manager.run_until_shutdown().await } pub async fn run(mut self) -> Result<(), NymRewarderError> { info!("Starting nym validators rewarder"); // setup shutdowns - let task_manager = TaskManager::new(5); + let shutdown_manager = ShutdownManager::build_new_default()?; let scraper_cancellation = self.setup_tasks().await?; if let Err(err) = self.startup_resync().await { @@ -598,7 +599,7 @@ impl Rewarder { return Err(err); } - self.main_loop(task_manager, scraper_cancellation).await; + self.main_loop(shutdown_manager, scraper_cancellation).await; Ok(()) } diff --git a/nym-wallet/Cargo.lock b/nym-wallet/Cargo.lock index 26ccca2b9aa..248771db3fc 100644 --- a/nym-wallet/Cargo.lock +++ b/nym-wallet/Cargo.lock @@ -4136,6 +4136,7 @@ dependencies = [ name = "nym-crypto" version = "0.4.0" dependencies = [ + "base64 0.22.1", "bs58", "ed25519-dalek", "nym-pemstore", diff --git a/sdk/rust/nym-sdk/src/lib.rs b/sdk/rust/nym-sdk/src/lib.rs index f79798e32ea..8268a039c3d 100644 --- a/sdk/rust/nym-sdk/src/lib.rs +++ b/sdk/rust/nym-sdk/src/lib.rs @@ -26,7 +26,5 @@ pub use nym_network_defaults::{ ChainDetails, DenomDetails, DenomDetailsOwned, NymContracts, NymNetworkDetails, ValidatorDetails, }; +pub use nym_task::ShutdownToken; pub use nym_validator_client::UserAgent; -// we have to re-expose TaskClient since we're allowing custom shutdown in public API -// (which is quite a shame if you ask me...) -pub use nym_task::TaskClient; diff --git a/sdk/rust/nym-sdk/src/mixnet/client.rs b/sdk/rust/nym-sdk/src/mixnet/client.rs index 90d1b218b66..1949ef534a4 100644 --- a/sdk/rust/nym-sdk/src/mixnet/client.rs +++ b/sdk/rust/nym-sdk/src/mixnet/client.rs @@ -8,8 +8,6 @@ use crate::mixnet::{CredentialStorage, MixnetClient, Recipient}; use crate::GatewayTransceiver; use crate::NymNetworkDetails; use crate::{Error, Result}; -use futures::channel::mpsc; -use futures::StreamExt; use log::{debug, warn}; use nym_client_core::client::base_client::storage::helpers::{ get_active_gateway_identity, get_all_registered_identities, has_gateway_details, @@ -31,7 +29,7 @@ use nym_client_core::init::types::{GatewaySelectionSpecification, GatewaySetup}; use nym_credentials_interface::TicketType; use nym_crypto::hkdf::DerivationMaterial; use nym_socks5_client_core::config::Socks5; -use nym_task::{TaskClient, TaskHandle, TaskStatus}; +use nym_task::ShutdownTracker; use nym_topology::provider_trait::TopologyProvider; use nym_validator_client::{nyxd, QueryHttpRpcNyxdClient, UserAgent}; use rand::rngs::OsRng; @@ -54,7 +52,7 @@ pub struct MixnetClientBuilder { wait_for_gateway: bool, custom_topology_provider: Option>, custom_gateway_transceiver: Option>, - custom_shutdown: Option, + custom_shutdown: Option, force_tls: bool, user_agent: Option, #[cfg(unix)] @@ -266,7 +264,7 @@ where /// Use an externally managed shutdown mechanism. 
#[must_use] - pub fn custom_shutdown(mut self, shutdown: TaskClient) -> Self { + pub fn custom_shutdown(mut self, shutdown: ShutdownTracker) -> Self { self.custom_shutdown = Some(shutdown); self } @@ -380,7 +378,7 @@ where force_tls: bool, /// Allows passing an externally controlled shutdown handle. - custom_shutdown: Option, + custom_shutdown: Option, user_agent: Option, @@ -740,7 +738,13 @@ where let debug_config = self.config.debug_config; let packet_type = self.config.debug_config.traffic.packet_type; let (mut started_client, nym_address) = self.connect_to_mixnet_common().await?; - let (socks5_status_tx, mut socks5_status_rx) = mpsc::channel(128); + + // TODO: more graceful handling here, surely both variants should work... I think? + let Some(task_manager) = started_client.shutdown_handle else { + return Err(Error::new_unsupported( + "connecting with socks5 is currently unsupported with custom shutdown", + )); + }; let client_input = started_client.client_input.register_producer(); let client_output = started_client.client_output.register_consumer(); @@ -753,40 +757,14 @@ where client_output, client_state.clone(), nym_address, - started_client.task_handle.get_handle(), + task_manager.shutdown_tracker_owned(), packet_type, ); - // TODO: more graceful handling here, surely both variants should work... I think? - if let TaskHandle::Internal(task_manager) = &mut started_client.task_handle { - task_manager - .start_status_listener(socks5_status_tx, TaskStatus::Ready) - .await; - match socks5_status_rx - .next() - .await - .ok_or(Error::Socks5NotStarted)? - .as_any() - .downcast_ref::() - .ok_or(Error::Socks5NotStarted)? - { - TaskStatus::Ready => { - log::debug!("Socks5 connected"); - } - TaskStatus::ReadyWithGateway(gateway) => { - log::debug!("Socks5 connected to {gateway}"); - } - } - } else { - return Err(Error::new_unsupported( - "connecting with socks5 is currently unsupported with custom shutdown", - )); - } - Ok(Socks5MixnetClient { nym_address, client_state, - task_handle: started_client.task_handle, + task_handle: task_manager, socks5_config, }) } @@ -833,7 +811,7 @@ where client_state, reconstructed_receiver, stats_events_reporter, - started_client.task_handle, + started_client.shutdown_handle, None, started_client.client_request_sender, started_client.forget_me, diff --git a/sdk/rust/nym-sdk/src/mixnet/native_client.rs b/sdk/rust/nym-sdk/src/mixnet/native_client.rs index 7c8089d3243..d9c92030f46 100644 --- a/sdk/rust/nym-sdk/src/mixnet/native_client.rs +++ b/sdk/rust/nym-sdk/src/mixnet/native_client.rs @@ -19,7 +19,7 @@ use nym_sphinx::{params::PacketType, receiver::ReconstructedMessage}; use nym_statistics_common::clients::{ClientStatsEvents, ClientStatsSender}; use nym_task::{ connections::{ConnectionCommandSender, LaneQueueLengths}, - TaskHandle, + ShutdownManager, }; use nym_topology::{NymRouteProvider, NymTopology}; use std::pin::Pin; @@ -54,7 +54,7 @@ pub struct MixnetClient { pub(crate) stats_events_reporter: ClientStatsSender, /// The task manager that controls all the spawned tasks that the clients uses to do it's job. 
- pub(crate) task_handle: TaskHandle, + pub(crate) shutdown_handle: Option, pub(crate) packet_type: Option, // internal state used for the `Stream` implementation @@ -74,7 +74,7 @@ impl MixnetClient { client_state: ClientState, reconstructed_receiver: ReconstructedMessagesReceiver, stats_events_reporter: ClientStatsSender, - task_handle: TaskHandle, + task_handle: Option, packet_type: Option, client_request_sender: ClientRequestSender, forget_me: ForgetMe, @@ -88,7 +88,7 @@ impl MixnetClient { client_state, reconstructed_receiver, stats_events_reporter, - task_handle, + shutdown_handle: task_handle, packet_type, _buffered: Vec::new(), client_request_sender, @@ -221,9 +221,9 @@ impl MixnetClient { self.stats_events_reporter.clone() } - /// Disconnect from the mixnet. Currently it is not supported to reconnect a disconnected + /// Disconnect from the mixnet. Currently, it is not supported to reconnect a disconnected /// client. - pub async fn disconnect(mut self) { + pub async fn disconnect(self) { if self.forget_me.any() { log::debug!("Sending forget me request: {:?}", self.forget_me); match self.send_forget_me().await { @@ -240,13 +240,9 @@ impl MixnetClient { tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; } - if let TaskHandle::Internal(task_manager) = &mut self.task_handle { - task_manager.signal_shutdown().ok(); - task_manager.wait_for_shutdown().await; + if let Some(mut task_manager) = self.shutdown_handle { + task_manager.perform_shutdown().await; } - - // note: it's important to take ownership of the struct as if the shutdown is `TaskHandle::External`, - // it must be dropped to finalize the shutdown } pub async fn send_forget_me(&self) -> Result<()> { diff --git a/sdk/rust/nym-sdk/src/mixnet/socks5_client.rs b/sdk/rust/nym-sdk/src/mixnet/socks5_client.rs index 6cbf837c296..cb4e826bd56 100644 --- a/sdk/rust/nym-sdk/src/mixnet/socks5_client.rs +++ b/sdk/rust/nym-sdk/src/mixnet/socks5_client.rs @@ -1,7 +1,7 @@ use nym_client_core::client::base_client::ClientState; use nym_socks5_client_core::config::Socks5; use nym_sphinx::addressing::clients::Recipient; -use nym_task::{connections::LaneQueueLengths, TaskHandle}; +use nym_task::{connections::LaneQueueLengths, ShutdownManager}; use nym_topology::NymTopology; @@ -18,7 +18,7 @@ pub struct Socks5MixnetClient { pub(crate) client_state: ClientState, /// The task manager that controls all the spawned tasks that the clients uses to do it's job. - pub(crate) task_handle: TaskHandle, + pub(crate) task_handle: ShutdownManager, /// SOCKS5 configuration parameters. pub(crate) socks5_config: Socks5, @@ -82,12 +82,6 @@ impl Socks5MixnetClient { /// Disconnect from the mixnet. Currently it is not supported to reconnect a disconnected /// client. pub async fn disconnect(mut self) { - if let TaskHandle::Internal(task_manager) = &mut self.task_handle { - task_manager.signal_shutdown().ok(); - task_manager.wait_for_shutdown().await; - } - - // note: it's important to take ownership of the struct as if the shutdown is `TaskHandle::External`, - // it must be dropped to finalize the shutdown + self.task_handle.run_until_shutdown().await; } } diff --git a/sdk/rust/nym-sdk/src/tcp_proxy/tcp_proxy_client.rs b/sdk/rust/nym-sdk/src/tcp_proxy/tcp_proxy_client.rs index 85f9682e629..f00a5b30c14 100644 --- a/sdk/rust/nym-sdk/src/tcp_proxy/tcp_proxy_client.rs +++ b/sdk/rust/nym-sdk/src/tcp_proxy/tcp_proxy_client.rs @@ -223,7 +223,7 @@ impl NymProxyClient { loop { tokio::select! 
{ _ = &mut rx => { - info!("Closing write end of session: {} in {} seconds", session_id, close_timeout); + info!("Closing write end of session: {session_id} in {close_timeout} seconds"); break } Some(message) = client.next() => { diff --git a/service-providers/ip-packet-router/src/cli/run.rs b/service-providers/ip-packet-router/src/cli/run.rs index bf71e79dea0..2108efcc86e 100644 --- a/service-providers/ip-packet-router/src/cli/run.rs +++ b/service-providers/ip-packet-router/src/cli/run.rs @@ -3,6 +3,7 @@ use crate::cli::{override_config, OverrideConfig}; use clap::Args; use nym_client_core::cli_helpers::client_run::CommonClientRunArgs; use nym_ip_packet_router::error::IpPacketRouterError; +use nym_task::ShutdownManager; #[allow(clippy::struct_excessive_bools)] #[derive(Args, Clone)] @@ -27,10 +28,17 @@ pub(crate) async fn execute(args: &Run) -> Result<(), IpPacketRouterError> { log::debug!("Using config: {config:#?}"); log::info!("Starting ip packet router service provider"); - let mut server = nym_ip_packet_router::IpPacketRouter::new(config); + let mut shutdown_manager = ShutdownManager::build_new_default()?; + let mut server = nym_ip_packet_router::IpPacketRouter::new( + config, + shutdown_manager.shutdown_tracker_owned(), + ); if let Some(custom_mixnet) = &args.common_args.custom_mixnet { server = server.with_stored_topology(custom_mixnet)? } - server.run_service_provider().await + tokio::spawn(server.run_service_provider()); + + shutdown_manager.run_until_shutdown().await; + Ok(()) } diff --git a/service-providers/ip-packet-router/src/ip_packet_router.rs b/service-providers/ip-packet-router/src/ip_packet_router.rs index 4e976e96f97..1e38d335952 100644 --- a/service-providers/ip-packet-router/src/ip_packet_router.rs +++ b/service-providers/ip-packet-router/src/ip_packet_router.rs @@ -12,7 +12,7 @@ use nym_client_core::{ TopologyProvider, }; use nym_sdk::mixnet::Recipient; -use nym_task::{TaskClient, TaskHandle}; +use nym_task::ShutdownTracker; use crate::{config::Config, error::IpPacketRouterError, request_filter::RequestFilter}; @@ -39,29 +39,22 @@ pub struct IpPacketRouter { wait_for_gateway: bool, custom_topology_provider: Option>, custom_gateway_transceiver: Option>, - shutdown: Option, + shutdown: ShutdownTracker, on_start: Option>, } impl IpPacketRouter { - pub fn new(config: Config) -> Self { + pub fn new(config: Config, shutdown: ShutdownTracker) -> Self { Self { config, wait_for_gateway: false, custom_topology_provider: None, custom_gateway_transceiver: None, - shutdown: None, + shutdown, on_start: None, } } - #[must_use] - #[allow(unused)] - pub fn with_shutdown(mut self, shutdown: TaskClient) -> Self { - self.shutdown = Some(shutdown); - self - } - #[must_use] #[allow(unused)] pub fn with_custom_gateway_transceiver( @@ -130,12 +123,11 @@ impl IpPacketRouter { clients::ConnectedClients, mixnet_listener::MixnetListener, request_filter::RequestFilter, tun_listener::TunListener, }; - let task_handle: TaskHandle = self.shutdown.map(Into::into).unwrap_or_default(); // Connect to the mixnet let mixnet_client = crate::mixnet_client::create_mixnet_client( &self.config.base, - task_handle.get_handle().named("nym_sdk::MixnetClient[IPR]"), + self.shutdown.clone(), self.custom_gateway_transceiver, self.custom_topology_provider, self.wait_for_gateway, @@ -163,7 +155,7 @@ impl IpPacketRouter { let tun_listener = TunListener { tun_reader, - task_client: task_handle.get_handle(), + shutdown_token: self.shutdown.clone_shutdown_token(), connected_clients: connected_clients_rx, }; 
tun_listener.start(); @@ -176,7 +168,7 @@ impl IpPacketRouter { request_filter: request_filter.clone(), tun_writer, mixnet_client, - task_handle, + shutdown_token: self.shutdown.clone_shutdown_token(), connected_clients, }; diff --git a/service-providers/ip-packet-router/src/mixnet_client.rs b/service-providers/ip-packet-router/src/mixnet_client.rs index 18bed08c24f..8326a5854a0 100644 --- a/service-providers/ip-packet-router/src/mixnet_client.rs +++ b/service-providers/ip-packet-router/src/mixnet_client.rs @@ -3,7 +3,7 @@ use nym_client_core::{config::disk_persistence::CommonClientPaths, TopologyProvider}; use nym_sdk::{GatewayTransceiver, NymNetworkDetails}; -use nym_task::TaskClient; +use nym_task::ShutdownTracker; use crate::{config::BaseClientConfig, error::IpPacketRouterError}; @@ -13,7 +13,7 @@ use crate::{config::BaseClientConfig, error::IpPacketRouterError}; // TODO: refactor this function and its arguments pub(crate) async fn create_mixnet_client( config: &BaseClientConfig, - shutdown: TaskClient, + shutdown: ShutdownTracker, custom_transceiver: Option>, custom_topology_provider: Option>, wait_for_gateway: bool, diff --git a/service-providers/ip-packet-router/src/mixnet_listener.rs b/service-providers/ip-packet-router/src/mixnet_listener.rs index fe622f8aa03..b29493da8f1 100644 --- a/service-providers/ip-packet-router/src/mixnet_listener.rs +++ b/service-providers/ip-packet-router/src/mixnet_listener.rs @@ -1,15 +1,6 @@ // Copyright 2025 - Nym Technologies SA // SPDX-License-Identifier: GPL-3.0-only -use futures::StreamExt; -use nym_ip_packet_requests::codec::MultiIpPacketCodec; -use nym_sdk::mixnet::MixnetMessageSender; -use nym_sphinx::receiver::ReconstructedMessage; -use nym_task::TaskHandle; -use std::{net::SocketAddr, time::Duration}; -use tokio::io::AsyncWriteExt; -use tokio_util::codec::FramedRead; - use crate::{ clients::{ConnectedClientHandler, ConnectedClients}, config::Config, @@ -30,6 +21,14 @@ use crate::{ request_filter::RequestFilter, util::parse_ip::ParsedPacket, }; +use futures::StreamExt; +use nym_ip_packet_requests::codec::MultiIpPacketCodec; +use nym_sdk::mixnet::MixnetMessageSender; +use nym_sphinx::receiver::ReconstructedMessage; +use nym_task::ShutdownToken; +use std::{net::SocketAddr, time::Duration}; +use tokio::io::AsyncWriteExt; +use tokio_util::codec::FramedRead; #[cfg(not(target_os = "linux"))] type TunDevice = crate::non_linux_dummy::DummyDevice; @@ -52,7 +51,7 @@ pub(crate) struct MixnetListener { pub(crate) mixnet_client: nym_sdk::mixnet::MixnetClient, // The task handle for the main loop - pub(crate) task_handle: TaskHandle, + pub(crate) shutdown_token: ShutdownToken, // The map of connected clients that the mixnet listener keeps track of. It monitors // activity and disconnects clients that have been inactive for too long. @@ -138,7 +137,7 @@ impl MixnetListener { Ok(responses) } - // Receving a static connect request from a client with an IP provided that we assign to them, + // Receiving a static connect request from a client with an IP provided that we assign to them, // if it's available. If it's not available, we send a failure response. async fn on_static_connect_request( &mut self, @@ -463,13 +462,14 @@ impl MixnetListener { } pub(crate) async fn run(mut self) -> Result<()> { - let mut task_client = self.task_handle.fork("main_loop"); let mut disconnect_timer = tokio::time::interval(DISCONNECT_TIMER_INTERVAL); - while !task_client.is_shutdown() { + loop { tokio::select! 
{ - _ = task_client.recv() => { + biased; + _ = self.shutdown_token.cancelled() => { log::debug!("IpPacketRouter [main loop]: received shutdown"); + break; }, _ = disconnect_timer.tick() => { self.handle_disconnect_timer().await; diff --git a/service-providers/ip-packet-router/src/tun_listener.rs b/service-providers/ip-packet-router/src/tun_listener.rs index 072ade7ab03..f1b663fc033 100644 --- a/service-providers/ip-packet-router/src/tun_listener.rs +++ b/service-providers/ip-packet-router/src/tun_listener.rs @@ -4,15 +4,14 @@ use std::collections::HashMap; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use crate::clients::{ConnectEvent, ConnectedClientEvent, DisconnectEvent}; +use crate::{error::Result, util::parse_ip::parse_dst_addr}; use nym_ip_packet_requests::IpPair; -use nym_task::TaskClient; +use nym_task::ShutdownToken; #[cfg(target_os = "linux")] use tokio::io::AsyncReadExt; use tokio::sync::mpsc; -use crate::clients::{ConnectEvent, ConnectedClientEvent, DisconnectEvent}; -use crate::{error::Result, util::parse_ip::parse_dst_addr}; - // The TUN listener keeps a local map of the connected clients that has its state updated by the // mixnet listener. Basically it's just so that we don't have to have mutexes around shared state. // It's even ok if this is slightly out of date @@ -79,7 +78,7 @@ impl ConnectedClientsListener { #[cfg(target_os = "linux")] pub(crate) struct TunListener { pub(crate) tun_reader: tokio::io::ReadHalf, - pub(crate) task_client: TaskClient, + pub(crate) shutdown_token: ShutdownToken, pub(crate) connected_clients: ConnectedClientsListener, } @@ -113,10 +112,12 @@ impl TunListener { async fn run(mut self) -> Result<()> { let mut buf = [0u8; 65535]; - while !self.task_client.is_shutdown() { + loop { tokio::select! { - _ = self.task_client.recv() => { + biased; + _ = self.shutdown_token.cancelled() => { log::trace!("TunListener: received shutdown"); + break; }, // TODO: ConnectedClientsListener::update should poll the channel instead event = self.connected_clients.connected_client_rx.recv() => match event { diff --git a/service-providers/network-requester/src/cli/run.rs b/service-providers/network-requester/src/cli/run.rs index 09938ac46a4..3cb96948d0b 100644 --- a/service-providers/network-requester/src/cli/run.rs +++ b/service-providers/network-requester/src/cli/run.rs @@ -8,6 +8,7 @@ use crate::{ }; use clap::Args; use nym_client_core::cli_helpers::client_run::CommonClientRunArgs; +use nym_task::ShutdownManager; #[allow(clippy::struct_excessive_bools)] #[derive(Args, Clone)] @@ -58,10 +59,17 @@ pub(crate) async fn execute(args: &Run) -> Result<(), NetworkRequesterError> { } log::info!("Starting socks5 service provider"); - let mut server = crate::core::NRServiceProviderBuilder::new(config); + let mut shutdown_manager = ShutdownManager::build_new_default()?; + let mut server = crate::core::NRServiceProviderBuilder::new( + config, + shutdown_manager.shutdown_tracker_owned(), + ); if let Some(custom_mixnet) = &args.common_args.custom_mixnet { server = server.with_stored_topology(custom_mixnet)? 
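// Editor's note, not part of the patch: the cancellation-aware main loop that the mixnet
// listener, TUN listener and (below) the NR service provider all converge on in this diff.
// Only ShutdownToken::cancelled and the `biased` select pattern are taken from the patch;
// `work_stream`/`handle` are placeholders, and futures::StreamExt is assumed for `next()`.
loop {
    tokio::select! {
        // `biased` evaluates branches in order, so a pending shutdown always wins the race
        biased;
        _ = shutdown_token.cancelled() => {
            log::debug!("received shutdown");
            break;
        }
        item = work_stream.next() => match item {
            Some(item) => handle(item).await,
            None => break, // upstream closed, nothing left to do
        },
    }
}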
} - server.run_service_provider().await + tokio::spawn(server.run_service_provider()); + + shutdown_manager.run_until_shutdown().await; + Ok(()) } diff --git a/service-providers/network-requester/src/core.rs b/service-providers/network-requester/src/core.rs index 513e187d646..3bf865a91b8 100644 --- a/service-providers/network-requester/src/core.rs +++ b/service-providers/network-requester/src/core.rs @@ -9,7 +9,7 @@ use crate::{reply, socks5}; use async_trait::async_trait; use futures::channel::{mpsc, oneshot}; use futures::stream::StreamExt; -use log::{debug, warn}; +use log::{debug, error, warn}; use nym_bin_common::bin_info_owned; use nym_client_core::client::mix_traffic::transceiver::GatewayTransceiver; use nym_client_core::config::disk_persistence::CommonClientPaths; @@ -34,8 +34,7 @@ use nym_sphinx::anonymous_replies::requests::AnonymousSenderTag; use nym_sphinx::params::{PacketSize, PacketType}; use nym_sphinx::receiver::ReconstructedMessage; use nym_task::connections::LaneQueueLengths; -use nym_task::manager::TaskHandle; -use nym_task::TaskClient; +use nym_task::ShutdownTracker; use std::path::Path; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -67,7 +66,7 @@ pub struct NRServiceProviderBuilder { wait_for_gateway: bool, custom_topology_provider: Option>, custom_gateway_transceiver: Option>, - shutdown: Option, + shutdown: ShutdownTracker, on_start: Option>, } @@ -79,7 +78,7 @@ pub struct NRServiceProvider { controller_sender: ControllerSender, mix_input_sender: MixProxySender, - shutdown: TaskHandle, + shutdown: ShutdownTracker, } #[async_trait] @@ -148,26 +147,17 @@ impl ServiceProvider for NRServiceProvider { } impl NRServiceProviderBuilder { - pub fn new(config: Config) -> NRServiceProviderBuilder { + pub fn new(config: Config, shutdown: ShutdownTracker) -> NRServiceProviderBuilder { NRServiceProviderBuilder { config, wait_for_gateway: false, custom_topology_provider: None, custom_gateway_transceiver: None, - shutdown: None, + shutdown, on_start: None, } } - #[must_use] - // this is a false positive, this method is actually called when used as a library - // but clippy complains about it when building the binary - #[allow(unused)] - pub fn with_shutdown(mut self, shutdown: TaskClient) -> Self { - self.shutdown = Some(shutdown); - self - } - #[must_use] // this is a false positive, this method is actually called when used as a library // but clippy complains about it when building the binary @@ -233,13 +223,10 @@ impl NRServiceProviderBuilder { /// Start all subsystems pub async fn run_service_provider(self) -> Result<(), NetworkRequesterError> { - // Used to notify tasks to shutdown. Not all tasks fully supports this (yet). - let shutdown: TaskHandle = self.shutdown.map(Into::into).unwrap_or_default(); - // Connect to the mixnet let mixnet_client = create_mixnet_client( &self.config.base, - shutdown.get_handle().named("nym_sdk::MixnetClient[NR]"), + self.shutdown.clone(), self.custom_gateway_transceiver, self.custom_topology_provider, self.wait_for_gateway, @@ -254,9 +241,7 @@ impl NRServiceProviderBuilder { // Controller for managing all active connections. 
let (mut active_connections_controller, controller_sender) = Controller::new( mixnet_client.connection_command_sender(), - shutdown - .get_handle() - .named("nym_socks5_proxy_helpers::connection_controller::Controller"), + self.shutdown.clone_shutdown_token(), ); tokio::spawn(async move { @@ -285,7 +270,7 @@ impl NRServiceProviderBuilder { mixnet_client, controller_sender, mix_input_sender, - shutdown, + shutdown: self.shutdown, }; log::info!("The address of this client is: {self_address}"); @@ -307,12 +292,13 @@ impl NRServiceProviderBuilder { impl NRServiceProvider { async fn run(&mut self) -> Result<(), NetworkRequesterError> { - let mut shutdown = self.shutdown.fork("main_loop"); - while !shutdown.is_shutdown() { + let shutdown = self.shutdown.clone_shutdown_token(); + loop { tokio::select! { biased; - _ = shutdown.recv() => { - debug!("NRServiceProvider [main loop]: received shutdown") + _ = shutdown.cancelled() => { + debug!("NRServiceProvider [main loop]: received shutdown"); + break }, msg = self.mixnet_client.next() => match msg { Some(msg) => self.on_message(msg).await, @@ -378,7 +364,7 @@ impl NRServiceProvider { controller_sender: ControllerSender, mix_input_sender: MixProxySender, lane_queue_lengths: LaneQueueLengths, - mut shutdown: TaskClient, + shutdown: ShutdownTracker, ) { let mut conn = match socks5::tcp::Connection::new( connection_id, @@ -390,7 +376,6 @@ impl NRServiceProvider { Ok(conn) => conn, Err(err) => { log::error!("error while connecting to {remote_addr}: {err}",); - shutdown.disarm(); // inform the remote that the connection is closed before it even was established let mixnet_message = MixnetMessage::new_network_data_response( @@ -472,11 +457,12 @@ impl NRServiceProvider { let controller_sender_clone = self.controller_sender.clone(); let mix_input_sender_clone = self.mix_input_sender.clone(); let lane_queue_lengths_clone = self.mixnet_client.shared_lane_queue_lengths(); - let mut shutdown = self.shutdown.get_handle(); // we're just cloning the underlying pointer, nothing expensive is happening here let request_filter = self.request_filter.clone(); + let proxy_shutdown = self.shutdown.clone(); + // at this point move it into the separate task // because we might have to resolve the underlying address and it can take some time // during which we don't want to block other incoming requests @@ -491,11 +477,10 @@ impl NRServiceProvider { log_msg, ); - mix_input_sender_clone - .send(error_msg) - .await - .expect("InputMessageReceiver has stopped receiving!"); - shutdown.disarm(); + if mix_input_sender_clone.send(error_msg).await.is_err() { + // don't disarm the shutdown, do cause global shutdown here! 
+ error!("InputMessageReceiver has stopped receiving!"); + } return; } @@ -509,7 +494,7 @@ impl NRServiceProvider { controller_sender_clone, mix_input_sender_clone, lane_queue_lengths_clone, - shutdown, + proxy_shutdown, ) .await }); @@ -563,7 +548,7 @@ impl NRServiceProvider { // TODO: refactor this function and its arguments async fn create_mixnet_client( config: &BaseClientConfig, - shutdown: TaskClient, + shutdown: ShutdownTracker, custom_transceiver: Option>, custom_topology_provider: Option>, wait_for_gateway: bool, diff --git a/service-providers/network-requester/src/socks5/tcp.rs b/service-providers/network-requester/src/socks5/tcp.rs index a1e3eabec2f..63f6f249f32 100644 --- a/service-providers/network-requester/src/socks5/tcp.rs +++ b/service-providers/network-requester/src/socks5/tcp.rs @@ -9,7 +9,7 @@ use nym_socks5_proxy_helpers::proxy_runner::{MixProxySender, ProxyRunner}; use nym_socks5_requests::{ConnectionId, RemoteAddress, Socks5Request}; use nym_sphinx::params::PacketSize; use nym_task::connections::LaneQueueLengths; -use nym_task::TaskClient; +use nym_task::ShutdownTracker; use std::io; use tokio::net::TcpStream; @@ -47,7 +47,7 @@ impl Connection { mix_receiver: ConnectionReceiver, mix_sender: MixProxySender, lane_queue_lengths: LaneQueueLengths, - shutdown: TaskClient, + shutdown: ShutdownTracker, ) { let stream = self.conn.take().unwrap(); let remote_source_address = "???".to_string(); // we don't know ip address of requester diff --git a/wasm/client/src/client.rs b/wasm/client/src/client.rs index e82a8cb78d9..4e70186f73c 100644 --- a/wasm/client/src/client.rs +++ b/wasm/client/src/client.rs @@ -28,7 +28,7 @@ use wasm_client_core::helpers::{ add_gateway, generate_new_client_keys, parse_recipient, parse_sender_tag, }; use wasm_client_core::nym_task::connections::TransmissionLane; -use wasm_client_core::nym_task::TaskManager; +use wasm_client_core::nym_task::ShutdownManager; use wasm_client_core::storage::core_client_traits::FullWasmClientStorage; use wasm_client_core::storage::wasm_client_traits::WasmClientStorage; use wasm_client_core::storage::ClientStorage; @@ -59,7 +59,7 @@ pub struct NymClient { // even though we don't use graceful shutdowns, other components rely on existence of this struct // and if it's dropped, everything will start going offline - _task_manager: TaskManager, + _task_manager: ShutdownManager, packet_type: PacketType, } @@ -249,8 +249,10 @@ impl NymClientBuilder { client_input: Arc::new(client_input), client_state: Arc::new(started_client.client_state), _full_topology: None, - // this cannot failed as we haven't passed an external task manager - _task_manager: started_client.task_handle.try_into_task_manager().unwrap(), + // this cannot fail as we haven't passed an external task manager + _task_manager: started_client + .shutdown_handle + .expect("shutdown manager missing"), packet_type, }) } diff --git a/wasm/client/src/config.rs b/wasm/client/src/config.rs index 8c841c4c167..9bb19ad8c1f 100644 --- a/wasm/client/src/config.rs +++ b/wasm/client/src/config.rs @@ -16,7 +16,7 @@ use wasm_client_core::config::{new_base_client_config, BaseClientConfig, ConfigD pub const DEFAULT_CLIENT_ID: &str = "nym-mixnet-client"; #[wasm_bindgen] -#[derive(Debug, Deserialize, PartialEq, Serialize)] +#[derive(Debug, Clone, Deserialize, PartialEq, Serialize)] #[serde(deny_unknown_fields)] pub struct ClientConfig { pub(crate) base: BaseClientConfig, diff --git a/wasm/mix-fetch/src/client.rs b/wasm/mix-fetch/src/client.rs index 2a2aeb0ab54..e462d5835ee 100644 --- 
a/wasm/mix-fetch/src/client.rs +++ b/wasm/mix-fetch/src/client.rs @@ -20,7 +20,7 @@ use wasm_client_core::client::base_client::{BaseClientBuilder, ClientInput, Clie use wasm_client_core::client::inbound_messages::InputMessage; use wasm_client_core::helpers::{add_gateway, generate_new_client_keys}; use wasm_client_core::nym_task::connections::TransmissionLane; -use wasm_client_core::nym_task::TaskManager; +use wasm_client_core::nym_task::ShutdownManager; use wasm_client_core::storage::core_client_traits::FullWasmClientStorage; use wasm_client_core::storage::wasm_client_traits::WasmClientStorage; use wasm_client_core::storage::ClientStorage; @@ -41,7 +41,7 @@ pub struct MixFetchClient { requests: ActiveRequests, // this has to be guarded by a mutex to be able to disconnect with an immutable reference - _task_manager: Mutex, + _shutdown_manager: Mutex, } #[wasm_bindgen] @@ -187,8 +187,12 @@ impl MixFetchClientBuilder { self_address, client_input, requests: active_requests, - // this cannot failed as we haven't passed an external task manager - _task_manager: Mutex::new(started_client.task_handle.try_into_task_manager().unwrap()), + // this cannot fail as we haven't passed an external task manager + _shutdown_manager: Mutex::new( + started_client + .shutdown_handle + .expect("shutdown manager missing"), + ), }) } } @@ -229,11 +233,11 @@ impl MixFetchClient { self.invalidated.store(true, Ordering::Relaxed); console_log!("sending shutdown signal"); - let mut shutdown_guard = self._task_manager.lock().await; - shutdown_guard.signal_shutdown().ok(); + let mut shutdown_guard = self._shutdown_manager.lock().await; + shutdown_guard.send_cancellation(); console_log!("waiting for shutdown to complete"); - shutdown_guard.wait_for_shutdown().await; + shutdown_guard.run_until_shutdown().await; self.requests.invalidate_all().await; diff --git a/wasm/node-tester/src/tester.rs b/wasm/node-tester/src/tester.rs index 8ab6b239188..9a48951ea7b 100644 --- a/wasm/node-tester/src/tester.rs +++ b/wasm/node-tester/src/tester.rs @@ -15,7 +15,6 @@ use js_sys::Promise; use nym_node_tester_utils::receiver::SimpleMessageReceiver; use nym_node_tester_utils::tester::LegacyMixLayer; use nym_node_tester_utils::{NodeTester, PacketSize, PreparedFragment}; -use nym_task::TaskManager; use rand::rngs::OsRng; use serde::{Deserialize, Serialize}; use std::collections::HashSet; @@ -32,6 +31,7 @@ use wasm_client_core::client::mix_traffic::transceiver::PacketRouter; use wasm_client_core::helpers::{ current_network_topology_async, setup_from_topology, EphemeralCredentialStorage, }; +use wasm_client_core::nym_task::ShutdownManager; use wasm_client_core::storage::ClientStorage; use wasm_client_core::topology::WasmFriendlyNymTopology; use wasm_client_core::{ @@ -70,7 +70,7 @@ pub struct NymNodeTester { // even though we don't use graceful shutdowns, other components rely on existence of this struct // and if it's dropped, everything will start going offline - _task_manager: TaskManager, + _task_manager: ShutdownManager, } #[wasm_bindgen] @@ -159,7 +159,7 @@ impl NymNodeTesterBuilder { } async fn _setup_client(mut self) -> Result { - let task_manager = TaskManager::default(); + let task_manager = ShutdownManager::new_without_signals(); let storage_id = if let Some(client_id) = &self.id { format!("{NODE_TESTER_ID}-{client_id}") @@ -181,17 +181,13 @@ impl NymNodeTesterBuilder { let (mixnet_message_sender, mixnet_message_receiver) = mpsc::unbounded(); let (ack_sender, ack_receiver) = mpsc::unbounded(); - let gateway_task = 
task_manager.subscribe().named("gateway_client"); - let packet_router = PacketRouter::new( - ack_sender, - mixnet_message_sender, - gateway_task.fork("packet_router"), - ); + let gateway_task = task_manager.clone_shutdown_token(); + let packet_router = + PacketRouter::new(ack_sender, mixnet_message_sender, gateway_task.clone()); let gateway_identity = gateway_info.gateway_id; - let mut stats_sender_task = task_manager.subscribe().named("stats_sender"); - stats_sender_task.disarm(); + let stats_sender_task = task_manager.clone_shutdown_token(); let mut gateway_client = if let Some(existing_client) = initialisation_result.authenticated_ephemeral_client { @@ -199,7 +195,7 @@ impl NymNodeTesterBuilder { packet_router, self.bandwidth_controller.take(), ClientStatsSender::new(None, stats_sender_task), - gateway_task, + gateway_task.clone(), ) } else { let cfg = GatewayConfig::new( @@ -250,10 +246,10 @@ impl NymNodeTesterBuilder { mixnet_message_receiver, ack_receiver, processed_sender, - task_manager.subscribe(), + task_manager.clone_shutdown_token(), ); - nym_task::spawn(async move { receiver.run().await }); + nym_task::spawn_future(async move { receiver.run().await }); Ok(NymNodeTester { test_in_progress: Arc::new(AtomicBool::new(false)),
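// Editor's note, not part of the patch: how the wasm clients above use the same API after this
// change. Real names from this diff: ShutdownManager::new_without_signals, clone_shutdown_token,
// send_cancellation, run_until_shutdown; the rationale for skipping signals (no ctrl-c/unix
// signal handling on wasm targets) is the editor's assumption.
let mut shutdown_manager = ShutdownManager::new_without_signals();

// subsystems get cloned shutdown tokens instead of named TaskClient subscriptions
let gateway_token = shutdown_manager.clone_shutdown_token();

// an explicit disconnect first requests cancellation, then waits for the tasks to finish
// (mix-fetch keeps the manager behind a Mutex so this can be driven from a shared reference)
shutdown_manager.send_cancellation();
shutdown_manager.run_until_shutdown().await;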