diff --git a/common/task/src/cancellation.rs b/common/task/src/cancellation.rs index 9e120b4ec55..35009cc18ef 100644 --- a/common/task/src/cancellation.rs +++ b/common/task/src/cancellation.rs @@ -5,6 +5,7 @@ use crate::{TaskClient, TaskManager}; use futures::stream::FuturesUnordered; use futures::StreamExt; use std::future::Future; +use std::mem; use std::ops::Deref; use std::pin::Pin; use std::time::Duration; @@ -185,12 +186,21 @@ impl ShutdownDropGuard { } } +#[derive(Default)] +pub struct ShutdownSignals(JoinSet<()>); + +impl ShutdownSignals { + pub async fn wait_for_signal(&mut self) { + self.0.join_next().await; + } +} + pub struct ShutdownManager { pub root_token: ShutdownToken, legacy_task_manager: Option, - shutdown_signals: JoinSet<()>, + shutdown_signals: ShutdownSignals, // the reason I'm not using a `JoinSet` is because it forces us to use futures with the same `::Output` type tracker: TaskTracker, @@ -261,7 +271,7 @@ impl ShutdownManager { F: Send + 'static, { let shutdown_token = self.root_token.clone(); - self.shutdown_signals.spawn(async move { + self.shutdown_signals.0.spawn(async move { shutdown.await; info!("sending cancellation after receiving shutdown signal"); @@ -356,9 +366,20 @@ impl ShutdownManager { wait_futures.next().await; } - pub async fn wait_for_shutdown_signal(mut self) { - self.shutdown_signals.join_next().await; + pub fn detach_shutdown_signals(&mut self) -> ShutdownSignals { + mem::take(&mut self.shutdown_signals) + } + + pub fn replace_shutdown_signals(&mut self, signals: ShutdownSignals) { + self.shutdown_signals = signals; + } + + // cancellation safe + pub async fn wait_for_shutdown_signal(&mut self) { + self.shutdown_signals.0.join_next().await; + } + pub async fn perform_shutdown(mut self) { if let Some(legacy_manager) = self.legacy_task_manager.as_mut() { info!("attempting to shutdown legacy tasks"); let _ = legacy_manager.signal_shutdown(); @@ -367,4 +388,10 @@ impl ShutdownManager { info!("waiting for tasks to finish... (press ctrl-c to force)"); self.finish_shutdown().await; } + + pub async fn run_until_shutdown(mut self) { + self.wait_for_shutdown_signal().await; + + self.perform_shutdown().await; + } } diff --git a/gateway/src/node/internal_service_providers.rs b/gateway/src/node/internal_service_providers.rs index b26d9f562f8..efa6d28e43d 100644 --- a/gateway/src/node/internal_service_providers.rs +++ b/gateway/src/node/internal_service_providers.rs @@ -128,6 +128,8 @@ where } }); + // TODO: if something is blocking during SP startup, the below will wait forever + // we need to introduce additional timeouts here. let on_start_data = self .on_start_rx .await diff --git a/nym-node/src/node/mod.rs b/nym-node/src/node/mod.rs index 6349b564135..f3041b75b61 100644 --- a/nym-node/src/node/mod.rs +++ b/nym-node/src/node/mod.rs @@ -1114,12 +1114,12 @@ impl NymNode { .await?; self.shutdown_manager.close(); - self.shutdown_manager.wait_for_shutdown_signal().await; + self.shutdown_manager.run_until_shutdown().await; Ok(()) } - pub(crate) async fn run(mut self) -> Result<(), NymNodeError> { + async fn start_nym_node_tasks(mut self) -> Result { info!("starting Nym Node {} with the following modes: mixnode: {}, entry: {}, exit: {}, wireguard: {}", self.ed25519_identity_key(), self.config.modes.mixnode, @@ -1189,9 +1189,27 @@ impl NymNode { .await?; network_refresher.start(); - self.shutdown_manager.close(); - self.shutdown_manager.wait_for_shutdown_signal().await; + + Ok(self.shutdown_manager) + } + + pub(crate) async fn run(mut self) -> Result<(), NymNodeError> { + let mut shutdown_signals = self.shutdown_manager.detach_shutdown_signals(); + + // listen for shutdown signal in case we received it when attempting to spawn all the tasks + tokio::select! { + _ = shutdown_signals.wait_for_signal() => { + info!("received shutdown signal during setup - exiting"); + // ideally we'd also do some cleanup here, but currently there's no easy way to access the handles + return Ok(()) + } + startup_result = self.start_nym_node_tasks() => { + let mut shutdown_manager = startup_result?; + shutdown_manager.replace_shutdown_signals(shutdown_signals); + shutdown_manager.run_until_shutdown().await; + } + } Ok(()) } diff --git a/nym-statistics-api/src/main.rs b/nym-statistics-api/src/main.rs index 4175bbfd071..4af7c83a363 100644 --- a/nym-statistics-api/src/main.rs +++ b/nym-statistics-api/src/main.rs @@ -48,7 +48,7 @@ async fn main() -> anyhow::Result<()> { tracing::info!("Started HTTP server on port {}", args.http_port); shutdown_manager.close(); - shutdown_manager.wait_for_shutdown_signal().await; + shutdown_manager.run_until_shutdown().await; Ok(()) }