/// Controls how often `sleep_unless` re-checks its stop predicate.
#[derive(Clone, Copy, Debug)]
pub enum PollStrategy {
    /// Poll at a fixed interval.
    Uniform { interval: time::Duration },
    /// Poll with exponential backoff up to a maximum interval.
    ///
    /// The wait starts at `initial_interval` and is multiplied by `factor`
    /// after every sleep, saturating at `max_interval`. A `factor` that is
    /// `<= 1` (or NaN) disables growth, in which case polling stays at
    /// `initial_interval` and `max_interval` is not consulted.
    Backoff { initial_interval: time::Duration, factor: f32, max_interval: time::Duration },
}

/// Sleeps for up to `total_sleep`, but returns early if `stop` becomes true.
///
/// `stop` is evaluated once before the first sleep and again after every
/// sleep, so it is always consulted one last time before a timeout is
/// reported. Each individual sleep is capped to the time remaining, so the
/// call does not deliberately sleep past the deadline.
///
/// A zero poll interval (initial, uniform, or the backoff maximum) is replaced
/// by one millisecond so the loop can never degrade into a busy spin. Backoff
/// growth saturates at `max_interval` instead of panicking when `factor` is
/// huge or infinite. If `now + total_sleep` cannot be represented as an
/// `Instant`, the wait has no deadline and only `stop` can end it.
///
/// Returns `true` when `stop` triggered before timeout and `false` when the
/// full sleep window elapsed.
pub fn sleep_unless<F>(
    total_sleep: time::Duration,
    mut stop: F,
    poll_strategy: PollStrategy,
) -> bool
where
    F: FnMut() -> bool,
{
    // Substituted for any zero interval so every iteration yields the CPU.
    const MIN_POLL_INTERVAL: time::Duration = time::Duration::from_millis(1);

    let non_zero = |interval: time::Duration| {
        if interval.is_zero() {
            MIN_POLL_INTERVAL
        } else {
            interval
        }
    };

    // `growth` is `Some((factor, cap))` only when the interval should grow.
    let (first_interval, growth) = match poll_strategy {
        PollStrategy::Uniform { interval } => (interval, None),
        PollStrategy::Backoff { initial_interval, factor, max_interval } => {
            // NaN fails this comparison, so it disables growth as well.
            let growth = (factor > 1.0).then(|| (f64::from(factor), non_zero(max_interval)));
            (initial_interval, growth)
        }
    };
    let mut next_interval = non_zero(first_interval);

    // `None` means the deadline overflowed `Instant`; treat it as "no deadline"
    // rather than panicking on the addition.
    let deadline = time::Instant::now().checked_add(total_sleep);

    loop {
        if stop() {
            return true;
        }

        let nap = match deadline {
            Some(deadline) => {
                let remaining = deadline.saturating_duration_since(time::Instant::now());
                if remaining.is_zero() {
                    return false;
                }
                remaining.min(next_interval)
            }
            None => next_interval,
        };
        thread::sleep(nap);

        if let Some((factor, cap)) = growth {
            // Compare in floating point first: the conversion back to a
            // `Duration` only happens for values below the cap, which are
            // finite, non-negative and in range, so it cannot panic.
            let grown_secs = next_interval.as_secs_f64() * factor;
            next_interval = if grown_secs < cap.as_secs_f64() {
                time::Duration::from_secs_f64(grown_secs)
            } else {
                cap
            };
        }
    }
}
a/libshpool/src/daemon/shell.rs +++ b/libshpool/src/daemon/shell.rs @@ -32,7 +32,7 @@ use shpool_protocol::{Chunk, ChunkKind, TtySize}; use tracing::{debug, error, info, instrument, span, trace, warn, Level}; use crate::{ - consts, + common, consts, daemon::{config, exit_notify::ExitNotifier, keybindings, pager::PagerCtl, prompt, show_motd}, protocol::ChunkExt as _, session_restore, test_hooks, @@ -51,10 +51,9 @@ const SUPERVISOR_POLL_DUR: time::Duration = time::Duration::from_millis(300); // size. const REATTACH_RESIZE_DELAY: time::Duration = time::Duration::from_millis(50); -// The shell->client thread should wake up relatively frequently so it can -// detect reattach, but we don't need to go crazy since reattach is not part of -// the inner loop. -const SHELL_TO_CLIENT_POLL_MS: u16 = 100; +// The shell->client thread should poll frequently so detach/reattach control +// messages are noticed quickly without spinning the CPU. +const SHELL_TO_CLIENT_POLL_MS: u16 = 50; // How long to wait before giving up while trying to talk to the // shell->client thread. 
@@ -835,7 +834,11 @@ impl SessionInner { use keybindings::Action::*; match action { - Detach => self.action_detach()?, + Detach => { + self.action_detach()?; + debug!("exiting client->shell thread after detach"); + return Ok(()); + } NoOp => {} } } @@ -883,12 +886,15 @@ impl SessionInner { loop { trace!("checking stop_rx"); - if stop.load(Ordering::Relaxed) { + let stop_early = common::sleep_unless( + consts::HEARTBEAT_DURATION, + || stop.load(Ordering::Relaxed), + common::PollStrategy::Uniform { interval: consts::JOIN_POLL_DURATION }, + ); + if stop_early { info!("recvd stop msg"); return Ok(()); } - - thread::sleep(consts::HEARTBEAT_DURATION); { let shell_to_client_ctl = self.shell_to_client_ctl.lock().unwrap(); match shell_to_client_ctl diff --git a/libshpool/src/protocol.rs b/libshpool/src/protocol.rs index f96a444f..6051db8e 100644 --- a/libshpool/src/protocol.rs +++ b/libshpool/src/protocol.rs @@ -27,10 +27,14 @@ use serde::{Deserialize, Serialize}; use shpool_protocol::{Chunk, ChunkKind, ConnectHeader, VersionHeader}; use tracing::{debug, error, info, instrument, span, trace, warn, Level}; -use super::{consts, tty}; +use super::{common, consts, tty}; -const JOIN_POLL_DUR: time::Duration = time::Duration::from_millis(100); -const JOIN_HANGUP_DUR: time::Duration = time::Duration::from_millis(300); +const DETACH_DISCONNECT_FAST_WAIT_DUR: time::Duration = time::Duration::from_millis(10); +const MAX_DETACH_WAIT_DUR: time::Duration = time::Duration::from_millis(300); +const DETACH_BACKOFF_INITIAL_DUR: time::Duration = time::Duration::from_millis(1); +// Cap backoff steps so slow-path stays responsive while still avoiding busy +// waits. +const DETACH_BACKOFF_MAX_STEP_DUR: time::Duration = time::Duration::from_millis(25); /// The centralized encoding function that should be used for all protocol /// serialization. 
@@ -50,7 +54,7 @@ where Ok(()) } -/// The centralized decoding focuntion that should be used for all protocol +/// The centralized decoding function that should be used for all protocol /// deserialization. pub fn decode_from(r: R) -> anyhow::Result where @@ -333,18 +337,49 @@ impl Client { if sock_to_stdout_h.is_finished() { nfinished_threads += 1; } + if nfinished_threads > 0 { if nfinished_threads < 2 { - thread::sleep(JOIN_HANGUP_DUR); - nfinished_threads = 0; - if stdin_to_sock_h.is_finished() { - info!("recheck: stdin->sock thread done"); - nfinished_threads += 1; - } - if sock_to_stdout_h.is_finished() { - info!("recheck: sock->stdout thread done"); - nfinished_threads += 1; + // Fast-path: when sock->stdout already ended (detach/disconnect), + // stdin->sock can stay blocked on stdin. In that case, do a very + // short grace wait and then exit quickly. This is independent + // of stdin being a TTY or a pipe. + // Slow-path: for other shutdown orders, keep compatibility by + // waiting up to 300ms with backoff. + let mut stdin_done = stdin_to_sock_h.is_finished(); + let mut stdout_done = sock_to_stdout_h.is_finished(); + + // Keep max_wait fixed for this detach sequence. Recomputing it inside + // the loop could accidentally switch paths mid-cleanup. + let max_wait = if stdout_done && !stdin_done { + DETACH_DISCONNECT_FAST_WAIT_DUR + } else { + MAX_DETACH_WAIT_DUR + }; + + let finished_waiting = common::sleep_unless( + max_wait, + || { + stdin_done = stdin_to_sock_h.is_finished(); + stdout_done = sock_to_stdout_h.is_finished(); + nfinished_threads = (stdin_done as usize) + (stdout_done as usize); + nfinished_threads >= 2 + }, + common::PollStrategy::Backoff { + initial_interval: DETACH_BACKOFF_INITIAL_DUR, + factor: 2.0, + max_interval: DETACH_BACKOFF_MAX_STEP_DUR, + }, + ); + + if !finished_waiting { + // Re-probe after timeout because thread state can change + // during the final sleep inside sleep_unless. 
+ stdin_done = stdin_to_sock_h.is_finished(); + stdout_done = sock_to_stdout_h.is_finished(); + nfinished_threads = (stdin_done as usize) + (stdout_done as usize); } + if nfinished_threads < 2 { // If one of the worker threads is done and the // other is not exiting, we are likely blocked on @@ -355,8 +390,8 @@ impl Client { // us to use simple blocking IO. warn!( "exiting due to a stuck IO thread stdin_to_sock_finished={} sock_to_stdout_finished={}", - stdin_to_sock_h.is_finished(), - sock_to_stdout_h.is_finished() + stdin_done, + stdout_done ); // make sure that we restore the tty flags on the input // tty before exiting the process. @@ -367,7 +402,8 @@ impl Client { } break; } - thread::sleep(JOIN_POLL_DUR); + + thread::sleep(consts::JOIN_POLL_DURATION); } match stdin_to_sock_h.join() {