Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 107 additions & 1 deletion libshpool/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,65 @@

//! The common module is a grab bag of shared utility functions.

use std::env;
use std::{env, thread, time};

use anyhow::anyhow;

/// Controls how often `sleep_unless` re-checks its stop predicate.
#[derive(Clone, Copy, Debug)]
pub enum PollStrategy {
/// Poll at a fixed interval.
Uniform { interval: time::Duration },
/// Poll with exponential backoff up to a maximum interval.
///
/// Values <= 1 disable growth and behave like uniform polling.
Backoff { initial_interval: time::Duration, factor: f32, max_interval: time::Duration },
}

/// Sleeps for up to `total_sleep`, but returns early if `stop` becomes true.
///
/// Returns `true` when `stop` triggered before timeout and `false` when the
/// full sleep window elapsed.
pub fn sleep_unless<F>(
total_sleep: time::Duration,
mut stop: F,
poll_strategy: PollStrategy,
) -> bool
where
F: FnMut() -> bool,
{
let deadline = time::Instant::now() + total_sleep;
let mut next_interval = match poll_strategy {
PollStrategy::Uniform { interval } => interval,
PollStrategy::Backoff { initial_interval, .. } => initial_interval,
};

if next_interval.is_zero() {
// Avoid a tight spin-loop if a zero interval is accidentally configured.
next_interval = time::Duration::from_millis(1);
}

loop {
if stop() {
return true;
}

let now = time::Instant::now();
if now >= deadline {
return false;
}

let remaining = deadline.saturating_duration_since(now);
thread::sleep(remaining.min(next_interval));

if let PollStrategy::Backoff { factor, max_interval, .. } = poll_strategy {
if factor > 1.0 {
next_interval = std::cmp::min(next_interval.mul_f32(factor), max_interval);
}
}
}
}

pub fn resolve_sessions(sessions: &mut Vec<String>, action: &str) -> anyhow::Result<()> {
if sessions.is_empty() {
if let Ok(current_session) = env::var("SHPOOL_SESSION_NAME") {
Expand All @@ -32,3 +87,54 @@ pub fn resolve_sessions(sessions: &mut Vec<String>, action: &str) -> anyhow::Res

Ok(())
}

#[cfg(test)]
mod tests {
use std::cell::Cell;
use std::time::Duration;

use super::{sleep_unless, PollStrategy};

#[test]
fn sleep_unless_returns_immediately_when_stop_is_true() {
let stopped = sleep_unless(
Duration::from_millis(10),
|| true,
PollStrategy::Uniform { interval: Duration::from_millis(1) },
);

assert!(stopped);
}

#[test]
fn sleep_unless_times_out_when_stop_is_false() {
let stopped = sleep_unless(
Duration::from_millis(3),
|| false,
PollStrategy::Uniform { interval: Duration::from_millis(1) },
);

assert!(!stopped);
}

#[test]
fn sleep_unless_rechecks_stop_with_backoff() {
let checks = Cell::new(0usize);
let stopped = sleep_unless(
Duration::from_millis(20),
|| {
let n = checks.get() + 1;
checks.set(n);
n >= 3
},
PollStrategy::Backoff {
initial_interval: Duration::from_millis(1),
factor: 2.0,
max_interval: Duration::from_millis(4),
},
);

assert!(stopped);
assert!(checks.get() >= 3);
}
}
2 changes: 1 addition & 1 deletion libshpool/src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use std::time;

pub const SOCK_STREAM_TIMEOUT: time::Duration = time::Duration::from_millis(200);
pub const JOIN_POLL_DURATION: time::Duration = time::Duration::from_millis(100);
pub const JOIN_POLL_DURATION: time::Duration = time::Duration::from_millis(50);

pub const BUF_SIZE: usize = 1024 * 16;

Expand Down
24 changes: 15 additions & 9 deletions libshpool/src/daemon/shell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use shpool_protocol::{Chunk, ChunkKind, TtySize};
use tracing::{debug, error, info, instrument, span, trace, warn, Level};

use crate::{
consts,
common, consts,
daemon::{config, exit_notify::ExitNotifier, keybindings, pager::PagerCtl, prompt, show_motd},
protocol::ChunkExt as _,
session_restore, test_hooks,
Expand All @@ -51,10 +51,9 @@ const SUPERVISOR_POLL_DUR: time::Duration = time::Duration::from_millis(300);
// size.
const REATTACH_RESIZE_DELAY: time::Duration = time::Duration::from_millis(50);

// The shell->client thread should wake up relatively frequently so it can
// detect reattach, but we don't need to go crazy since reattach is not part of
// the inner loop.
const SHELL_TO_CLIENT_POLL_MS: u16 = 100;
// The shell->client thread should poll frequently so detach/reattach control
// messages are noticed quickly without spinning the CPU.
const SHELL_TO_CLIENT_POLL_MS: u16 = 50;

// How long to wait before giving up while trying to talk to the
// shell->client thread.
Expand Down Expand Up @@ -835,7 +834,11 @@ impl SessionInner {

use keybindings::Action::*;
match action {
Detach => self.action_detach()?,
Detach => {
self.action_detach()?;
debug!("exiting client->shell thread after detach");
return Ok(());
}
NoOp => {}
}
}
Expand Down Expand Up @@ -883,12 +886,15 @@ impl SessionInner {

loop {
trace!("checking stop_rx");
if stop.load(Ordering::Relaxed) {
let stop_early = common::sleep_unless(
consts::HEARTBEAT_DURATION,
|| stop.load(Ordering::Relaxed),
common::PollStrategy::Uniform { interval: consts::JOIN_POLL_DURATION },
);
if stop_early {
info!("recvd stop msg");
return Ok(());
}

thread::sleep(consts::HEARTBEAT_DURATION);
{
let shell_to_client_ctl = self.shell_to_client_ctl.lock().unwrap();
match shell_to_client_ctl
Expand Down
68 changes: 52 additions & 16 deletions libshpool/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@ use serde::{Deserialize, Serialize};
use shpool_protocol::{Chunk, ChunkKind, ConnectHeader, VersionHeader};
use tracing::{debug, error, info, instrument, span, trace, warn, Level};

use super::{consts, tty};
use super::{common, consts, tty};

const JOIN_POLL_DUR: time::Duration = time::Duration::from_millis(100);
const JOIN_HANGUP_DUR: time::Duration = time::Duration::from_millis(300);
const DETACH_DISCONNECT_FAST_WAIT_DUR: time::Duration = time::Duration::from_millis(10);
const MAX_DETACH_WAIT_DUR: time::Duration = time::Duration::from_millis(300);
const DETACH_BACKOFF_INITIAL_DUR: time::Duration = time::Duration::from_millis(1);
// Cap backoff steps so slow-path stays responsive while still avoiding busy
// waits.
const DETACH_BACKOFF_MAX_STEP_DUR: time::Duration = time::Duration::from_millis(25);

/// The centralized encoding function that should be used for all protocol
/// serialization.
Expand All @@ -50,7 +54,7 @@ where
Ok(())
}

/// The centralized decoding focuntion that should be used for all protocol
/// The centralized decoding function that should be used for all protocol
/// deserialization.
pub fn decode_from<T, R>(r: R) -> anyhow::Result<T>
where
Expand Down Expand Up @@ -333,18 +337,49 @@ impl Client {
if sock_to_stdout_h.is_finished() {
nfinished_threads += 1;
}

if nfinished_threads > 0 {
if nfinished_threads < 2 {
thread::sleep(JOIN_HANGUP_DUR);
nfinished_threads = 0;
if stdin_to_sock_h.is_finished() {
info!("recheck: stdin->sock thread done");
nfinished_threads += 1;
}
if sock_to_stdout_h.is_finished() {
info!("recheck: sock->stdout thread done");
nfinished_threads += 1;
// Fast-path: when sock->stdout already ended (detach/disconnect),
// stdin->sock can stay blocked on stdin. In that case, do a very
// short grace wait and then exit quickly. This is independent
// of stdin being a TTY or a pipe.
// Slow-path: for other shutdown orders, keep compatibility by
// waiting up to 300ms with backoff.
let mut stdin_done = stdin_to_sock_h.is_finished();
let mut stdout_done = sock_to_stdout_h.is_finished();

// Keep max_wait fixed for this detach sequence. Recomputing it inside
// the loop could accidentally switch paths mid-cleanup.
let max_wait = if stdout_done && !stdin_done {
DETACH_DISCONNECT_FAST_WAIT_DUR
} else {
MAX_DETACH_WAIT_DUR
};

let finished_waiting = common::sleep_unless(
max_wait,
|| {
stdin_done = stdin_to_sock_h.is_finished();
stdout_done = sock_to_stdout_h.is_finished();
nfinished_threads = (stdin_done as usize) + (stdout_done as usize);
nfinished_threads >= 2
},
common::PollStrategy::Backoff {
initial_interval: DETACH_BACKOFF_INITIAL_DUR,
factor: 2.0,
max_interval: DETACH_BACKOFF_MAX_STEP_DUR,
},
);

if !finished_waiting {
// Re-probe after timeout because thread state can change
// during the final sleep inside sleep_unless.
stdin_done = stdin_to_sock_h.is_finished();
stdout_done = sock_to_stdout_h.is_finished();
nfinished_threads = (stdin_done as usize) + (stdout_done as usize);
}

if nfinished_threads < 2 {
// If one of the worker threads is done and the
// other is not exiting, we are likely blocked on
Expand All @@ -355,8 +390,8 @@ impl Client {
// us to use simple blocking IO.
warn!(
"exiting due to a stuck IO thread stdin_to_sock_finished={} sock_to_stdout_finished={}",
stdin_to_sock_h.is_finished(),
sock_to_stdout_h.is_finished()
stdin_done,
stdout_done
);
// make sure that we restore the tty flags on the input
// tty before exiting the process.
Expand All @@ -367,7 +402,8 @@ impl Client {
}
break;
}
thread::sleep(JOIN_POLL_DUR);

thread::sleep(consts::JOIN_POLL_DURATION);
}

match stdin_to_sock_h.join() {
Expand Down
Loading