Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions libshpool/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,64 @@
//! The common module is a grab bag of shared utility functions.

use std::env;
use std::{thread, time};

use anyhow::anyhow;

/// Controls how often `sleep_unless` re-checks its stop predicate.
#[derive(Clone, Copy, Debug)]
pub enum PollStrategy {
/// Poll at a fixed interval.
Uniform { interval: time::Duration },
/// Poll with exponential backoff up to a maximum interval.
Backoff { initial_interval: time::Duration, factor: u32, max_interval: time::Duration },
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use a float for the factor. The most common backoff factors are 2 and 1.5, and if we use an int we can't handle 1.5.

}

/// Sleeps for up to `total_sleep`, but returns early if `stop` becomes true.
///
/// Returns `true` when `stop` triggered before timeout and `false` when the
/// full sleep window elapsed.
pub fn sleep_unless<F>(
total_sleep: time::Duration,
mut stop: F,
poll_strategy: PollStrategy,
) -> bool
where
F: FnMut() -> bool,
{
let deadline = time::Instant::now() + total_sleep;
let mut next_interval = match poll_strategy {
PollStrategy::Uniform { interval } => interval,
PollStrategy::Backoff { initial_interval, .. } => initial_interval,
};

if next_interval.is_zero() {
// Avoid a tight spin-loop if a zero interval is accidentally configured.
next_interval = time::Duration::from_millis(1);
}

loop {
if stop() {
return true;
}

let now = time::Instant::now();
if now >= deadline {
return false;
}

let remaining = deadline.saturating_duration_since(now);
thread::sleep(remaining.min(next_interval));

if let PollStrategy::Backoff { factor, max_interval, .. } = poll_strategy {
if factor > 1 {
let grown = next_interval.checked_mul(factor).unwrap_or(max_interval);
next_interval = grown.min(max_interval);
}
}
}
}

pub fn resolve_sessions(sessions: &mut Vec<String>, action: &str) -> anyhow::Result<()> {
if sessions.is_empty() {
if let Ok(current_session) = env::var("SHPOOL_SESSION_NAME") {
Expand All @@ -32,3 +87,54 @@ pub fn resolve_sessions(sessions: &mut Vec<String>, action: &str) -> anyhow::Res

Ok(())
}

#[cfg(test)]
mod tests {
use std::cell::Cell;
use std::time::Duration;

use super::{sleep_unless, PollStrategy};

#[test]
fn sleep_unless_returns_immediately_when_stop_is_true() {
let stopped = sleep_unless(
Duration::from_millis(10),
|| true,
PollStrategy::Uniform { interval: Duration::from_millis(1) },
);

assert!(stopped);
}

#[test]
fn sleep_unless_times_out_when_stop_is_false() {
let stopped = sleep_unless(
Duration::from_millis(3),
|| false,
PollStrategy::Uniform { interval: Duration::from_millis(1) },
);

assert!(!stopped);
}

#[test]
fn sleep_unless_rechecks_stop_with_backoff() {
let checks = Cell::new(0usize);
let stopped = sleep_unless(
Duration::from_millis(20),
|| {
let n = checks.get() + 1;
checks.set(n);
n >= 3
},
PollStrategy::Backoff {
initial_interval: Duration::from_millis(1),
factor: 2,
max_interval: Duration::from_millis(4),
},
);

assert!(stopped);
assert!(checks.get() >= 3);
}
}
2 changes: 1 addition & 1 deletion libshpool/src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use std::time;

pub const SOCK_STREAM_TIMEOUT: time::Duration = time::Duration::from_millis(200);
pub const JOIN_POLL_DURATION: time::Duration = time::Duration::from_millis(100);
pub const JOIN_POLL_DURATION: time::Duration = time::Duration::from_millis(50);

pub const BUF_SIZE: usize = 1024 * 16;

Expand Down
24 changes: 15 additions & 9 deletions libshpool/src/daemon/shell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use shpool_protocol::{Chunk, ChunkKind, TtySize};
use tracing::{debug, error, info, instrument, span, trace, warn, Level};

use crate::{
consts,
common, consts,
daemon::{config, exit_notify::ExitNotifier, keybindings, pager::PagerCtl, prompt, show_motd},
protocol::ChunkExt as _,
session_restore, test_hooks,
Expand All @@ -51,10 +51,9 @@ const SUPERVISOR_POLL_DUR: time::Duration = time::Duration::from_millis(300);
// size.
const REATTACH_RESIZE_DELAY: time::Duration = time::Duration::from_millis(50);

// The shell->client thread should wake up relatively frequently so it can
// detect reattach, but we don't need to go crazy since reattach is not part of
// the inner loop.
const SHELL_TO_CLIENT_POLL_MS: u16 = 100;
// The shell->client thread should poll frequently so detach/reattach control
// messages are noticed quickly without spinning the CPU.
const SHELL_TO_CLIENT_POLL_MS: u16 = 50;

// How long to wait before giving up while trying to talk to the
// shell->client thread.
Expand Down Expand Up @@ -835,7 +834,11 @@ impl SessionInner {

use keybindings::Action::*;
match action {
Detach => self.action_detach()?,
Detach => {
self.action_detach()?;
debug!("exiting client->shell thread after detach");
return Ok(());
}
NoOp => {}
}
}
Expand Down Expand Up @@ -883,12 +886,15 @@ impl SessionInner {

loop {
trace!("checking stop_rx");
if stop.load(Ordering::Relaxed) {
let stop_early = common::sleep_unless(
consts::HEARTBEAT_DURATION,
|| stop.load(Ordering::Relaxed),
common::PollStrategy::Uniform { interval: consts::JOIN_POLL_DURATION },
);
if stop_early {
info!("recvd stop msg");
return Ok(());
}

thread::sleep(consts::HEARTBEAT_DURATION);
{
let shell_to_client_ctl = self.shell_to_client_ctl.lock().unwrap();
match shell_to_client_ctl
Expand Down
67 changes: 52 additions & 15 deletions libshpool/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,19 @@ use std::{

use anyhow::{anyhow, Context};
use byteorder::{LittleEndian, ReadBytesExt as _, WriteBytesExt as _};
use nix::unistd::isatty;
use serde::{Deserialize, Serialize};
use shpool_protocol::{Chunk, ChunkKind, ConnectHeader, VersionHeader};
use tracing::{debug, error, info, instrument, span, trace, warn, Level};

use super::{consts, tty};

const JOIN_POLL_DUR: time::Duration = time::Duration::from_millis(100);
const JOIN_HANGUP_DUR: time::Duration = time::Duration::from_millis(300);
const DETACH_TTY_FAST_WAIT_DUR: time::Duration = time::Duration::from_millis(10);
const MAX_DETACH_WAIT_DUR: time::Duration = time::Duration::from_millis(300);
const DETACH_BACKOFF_INITIAL_DUR: time::Duration = time::Duration::from_millis(1);
// Cap backoff steps so slow-path stays responsive while still avoiding busy
// waits.
const DETACH_BACKOFF_MAX_STEP_DUR: time::Duration = time::Duration::from_millis(25);

/// The centralized encoding function that should be used for all protocol
/// serialization.
Expand All @@ -50,7 +55,7 @@ where
Ok(())
}

/// The centralized decoding focuntion that should be used for all protocol
/// The centralized decoding function that should be used for all protocol
/// deserialization.
pub fn decode_from<T, R>(r: R) -> anyhow::Result<T>
where
Expand Down Expand Up @@ -333,18 +338,49 @@ impl Client {
if sock_to_stdout_h.is_finished() {
nfinished_threads += 1;
}

if nfinished_threads > 0 {
if nfinished_threads < 2 {
thread::sleep(JOIN_HANGUP_DUR);
nfinished_threads = 0;
if stdin_to_sock_h.is_finished() {
info!("recheck: stdin->sock thread done");
nfinished_threads += 1;
}
if sock_to_stdout_h.is_finished() {
info!("recheck: sock->stdout thread done");
nfinished_threads += 1;
// Fast-path: when server->client already ended (detach/disconnect),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets keep the thread names consistant and call this sock->stdout

// stdin->sock can stay blocked on stdin. In that case, do a very
// short grace wait and then exit quickly.
// Slow-path: for other shutdown orders, keep compatibility by
// waiting up to 300ms with backoff.
let mut total_wait = time::Duration::ZERO;
let mut next_sleep = DETACH_BACKOFF_INITIAL_DUR;
let mut stdin_done = stdin_to_sock_h.is_finished();
let mut stdout_done = sock_to_stdout_h.is_finished();

let stdin_is_tty = isatty(io::stdin()).unwrap_or(false);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I generally like to avoid isatty when possible since it makes it harder to predict how a tool will work when running under a script. Sometimes it is worthwhile, but in this case I don't think it is worth having divergant behavior. I don't see a reason we would need to wait around longer in a script context, so let's just always use the short timeout if the daemon hangs up on us. We should re-name the constant to avoid mentioning TTY when we do this.

// Keep max_wait fixed for this detach sequence. Recomputing it inside
// the loop could accidentally switch paths mid-cleanup.
let max_wait = if stdin_is_tty && stdout_done && !stdin_done {
DETACH_TTY_FAST_WAIT_DUR
} else {
MAX_DETACH_WAIT_DUR
};

loop {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loop can switch to using the helper I suggested above, with its stop predicate computing nfinished_threads and checking if the count is >= 2.

stdin_done = stdin_to_sock_h.is_finished();
stdout_done = sock_to_stdout_h.is_finished();
nfinished_threads = (stdin_done as usize) + (stdout_done as usize);

if nfinished_threads >= 2 {
break;
}
if total_wait >= max_wait {
break;
}

let remaining = max_wait - total_wait;
let sleep_for = cmp::min(next_sleep, remaining);
thread::sleep(sleep_for);
total_wait += sleep_for;
// Exponential backoff with a capped step to avoid busy waits.
next_sleep =
cmp::min(next_sleep + next_sleep, DETACH_BACKOFF_MAX_STEP_DUR);
}

if nfinished_threads < 2 {
// If one of the worker threads is done and the
// other is not exiting, we are likely blocked on
Expand All @@ -355,8 +391,8 @@ impl Client {
// us to use simple blocking IO.
warn!(
"exiting due to a stuck IO thread stdin_to_sock_finished={} sock_to_stdout_finished={}",
stdin_to_sock_h.is_finished(),
sock_to_stdout_h.is_finished()
stdin_done,
stdout_done
);
// make sure that we restore the tty flags on the input
// tty before exiting the process.
Expand All @@ -367,7 +403,8 @@ impl Client {
}
break;
}
thread::sleep(JOIN_POLL_DUR);

thread::sleep(consts::JOIN_POLL_DURATION);
}

match stdin_to_sock_h.join() {
Expand Down