Skip to content

Commit

Permalink
task: drop the join waker of a task eagerly (#6986)
Browse files Browse the repository at this point in the history
  • Loading branch information
tglane authored Dec 29, 2024
1 parent 4ca13e6 commit 970d880
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 20 deletions.
1 change: 1 addition & 0 deletions spellcheck.dic
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ unparks
Unparks
unreceived
unsafety
unsets
Unsets
unsynchronized
untrusted
Expand Down
41 changes: 37 additions & 4 deletions tokio/src/runtime/task/harness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,11 @@ where
}

pub(super) fn drop_join_handle_slow(self) {
// Try to unset `JOIN_INTEREST`. This must be done as a first step in
// Try to unset `JOIN_INTEREST` and `JOIN_WAKER`. This must be done as a first step in
// case the task concurrently completed.
if self.state().unset_join_interested().is_err() {
let transition = self.state().transition_to_join_handle_dropped();

if transition.drop_output {
// It is our responsibility to drop the output. This is critical as
// the task output may not be `Send` and as such must remain with
// the scheduler or `JoinHandle`. i.e. if the output remains in the
Expand All @@ -301,6 +303,23 @@ where
}));
}

if transition.drop_waker {
// If the JOIN_WAKER flag is unset at this point, the task is either
// already terminal or not complete so the `JoinHandle` is responsible
// for dropping the waker.
// Safety:
// If the JOIN_WAKER bit is not set the join handle has exclusive
// access to the waker as per rule 2 in task/mod.rs.
// This can only be the case at this point in two scenarios:
// 1. The task completed and the runtime unset `JOIN_WAKER` flag
// after accessing the waker during task completion. So the
// `JoinHandle` is the only one to access the join waker here.
// 2. The task is not completed so the `JoinHandle` was able to unset
// `JOIN_WAKER` bit itself to get mutable access to the waker.
// The runtime will not access the waker when this flag is unset.
unsafe { self.trailer().set_waker(None) };
}

// Drop the `JoinHandle` reference, possibly deallocating the task
self.drop_reference();
}
Expand All @@ -311,7 +330,6 @@ where
fn complete(self) {
// The future has completed and its output has been written to the task
// stage. We transition from running to complete.

let snapshot = self.state().transition_to_complete();

// We catch panics here in case dropping the future or waking the
Expand All @@ -320,13 +338,28 @@ where
if !snapshot.is_join_interested() {
// The `JoinHandle` is not interested in the output of
// this task. It is our responsibility to drop the
// output.
// output. The join waker was already dropped by the
// `JoinHandle` before.
self.core().drop_future_or_output();
} else if snapshot.is_join_waker_set() {
// Notify the waker. Reading the waker field is safe per rule 4
// in task/mod.rs, since the JOIN_WAKER bit is set and the call
// to transition_to_complete() above set the COMPLETE bit.
self.trailer().wake_join();

// Inform the `JoinHandle` that we are done waking the waker by
// unsetting the `JOIN_WAKER` bit. If the `JoinHandle` has
// already been dropped and `JOIN_INTEREST` is unset, then we must
// drop the waker ourselves.
if !self
.state()
.unset_waker_after_complete()
.is_join_interested()
{
// SAFETY: We have COMPLETE=1 and JOIN_INTEREST=0, so
// we have exclusive access to the waker.
unsafe { self.trailer().set_waker(None) };
}
}
}));

Expand Down
18 changes: 16 additions & 2 deletions tokio/src/runtime/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,30 @@
//! `JoinHandle` needs to (i) successfully set `JOIN_WAKER` to zero if it is
//! not already zero to gain exclusive access to the waker field per rule
//! 2, (ii) write a waker, and (iii) successfully set `JOIN_WAKER` to one.
//! If the `JoinHandle` unsets `JOIN_WAKER` in the process of being dropped
//! to clear the waker field, only steps (i) and (ii) are relevant.
//!
//! 6. The `JoinHandle` can change `JOIN_WAKER` only if COMPLETE is zero (i.e.
//! the task hasn't yet completed).
//! the task hasn't yet completed). The runtime can change `JOIN_WAKER` only
//! if COMPLETE is one.
//!
//! 7. If `JOIN_INTEREST` is zero and COMPLETE is one, then the runtime has
//! exclusive (mutable) access to the waker field. This might happen if the
//! `JoinHandle` gets dropped right after the task completes and the runtime
//! sets the `COMPLETE` bit. In this case the runtime needs the mutable access
//! to the waker field to drop it.
//!
//! Rule 6 implies that the steps (i) or (iii) of rule 5 may fail due to a
//! race. If step (i) fails, then the attempt to write a waker is aborted. If
//! step (iii) fails because COMPLETE is set to one by another thread after
//! step (i), then the waker field is cleared. Once COMPLETE is one (i.e.
//! task has completed), the `JoinHandle` will not modify `JOIN_WAKER`. After the
//! runtime sets COMPLETE to one, it invokes the waker if there is one.
//! runtime sets COMPLETE to one, it invokes the waker if there is one so in this
//! case when a task completes the `JOIN_WAKER` bit implicates to the runtime
//! whether it should invoke the waker or not. After the runtime is done with
//! using the waker during task completion, it unsets the `JOIN_WAKER` bit to give
//! the `JoinHandle` exclusive access again so that it is able to drop the waker
//! at a later point.
//!
//! All other fields are immutable and can be accessed immutably without
//! synchronization by anyone.
Expand Down
63 changes: 51 additions & 12 deletions tokio/src/runtime/task/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ pub(crate) enum TransitionToNotifiedByRef {
Submit,
}

#[must_use]
pub(super) struct TransitionToJoinHandleDrop {
pub(super) drop_waker: bool,
pub(super) drop_output: bool,
}

/// All transitions are performed via RMW operations. This establishes an
/// unambiguous modification order.
impl State {
Expand Down Expand Up @@ -371,22 +377,45 @@ impl State {
.map_err(|_| ())
}

/// Tries to unset the `JOIN_INTEREST` flag.
///
/// Returns `Ok` if the operation happens before the task transitions to a
/// completed state, `Err` otherwise.
pub(super) fn unset_join_interested(&self) -> UpdateResult {
self.fetch_update(|curr| {
assert!(curr.is_join_interested());
/// Unsets the `JOIN_INTEREST` flag. If `COMPLETE` is not set, the `JOIN_WAKER`
/// flag is also unset.
/// The returned `TransitionToJoinHandleDrop` indicates whether the `JoinHandle` should drop
/// the output of the future or the join waker after the transition.
pub(super) fn transition_to_join_handle_dropped(&self) -> TransitionToJoinHandleDrop {
self.fetch_update_action(|mut snapshot| {
assert!(snapshot.is_join_interested());

if curr.is_complete() {
return None;
let mut transition = TransitionToJoinHandleDrop {
drop_waker: false,
drop_output: false,
};

snapshot.unset_join_interested();

if !snapshot.is_complete() {
// If `COMPLETE` is unset we also unset `JOIN_WAKER` to give the
// `JoinHandle` exclusive access to the waker following rule 6 in task/mod.rs.
// The `JoinHandle` will drop the waker if it has exclusive access
// to drop it.
snapshot.unset_join_waker();
} else {
// If `COMPLETE` is set the task is completed so the `JoinHandle` is responsible
// for dropping the output.
transition.drop_output = true;
}

let mut next = curr;
next.unset_join_interested();
if !snapshot.is_join_waker_set() {
// If the `JOIN_WAKER` bit is unset and the `JOIN_HANDLE` has exclusive access to
// the join waker and should drop it following this transition.
// This might happen in two situations:
// 1. The task is not completed and we just unset the `JOIN_WAKer` above in this
// function.
// 2. The task is completed. In that case the `JOIN_WAKER` bit was already unset
// by the runtime during completion.
transition.drop_waker = true;
}

Some(next)
(transition, Some(snapshot))
})
}

Expand Down Expand Up @@ -430,6 +459,16 @@ impl State {
})
}

/// Unsets the `JOIN_WAKER` bit unconditionally after task completion.
///
/// This operation requires the task to be completed.
pub(super) fn unset_waker_after_complete(&self) -> Snapshot {
let prev = Snapshot(self.val.fetch_and(!JOIN_WAKER, AcqRel));
assert!(prev.is_complete());
assert!(prev.is_join_waker_set());
Snapshot(prev.0 & !JOIN_WAKER)
}

pub(super) fn ref_inc(&self) {
use std::process;
use std::sync::atomic::Ordering::Relaxed;
Expand Down
58 changes: 56 additions & 2 deletions tokio/src/runtime/tests/loom_current_thread.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mod yield_now;

use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::atomic::{AtomicUsize, Ordering};
use crate::loom::sync::Arc;
use crate::loom::thread;
use crate::runtime::{Builder, Runtime};
Expand All @@ -9,7 +9,7 @@ use crate::task;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::Ordering::{Acquire, Release};
use std::task::{Context, Poll};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

fn assert_at_most_num_polls(rt: Arc<Runtime>, at_most_polls: usize) {
let (tx, rx) = oneshot::channel();
Expand Down Expand Up @@ -106,6 +106,60 @@ fn assert_no_unnecessary_polls() {
});
}

#[test]
fn drop_jh_during_schedule() {
unsafe fn waker_clone(ptr: *const ()) -> RawWaker {
let atomic = unsafe { &*(ptr as *const AtomicUsize) };
atomic.fetch_add(1, Ordering::Relaxed);
RawWaker::new(ptr, &VTABLE)
}
unsafe fn waker_drop(ptr: *const ()) {
let atomic = unsafe { &*(ptr as *const AtomicUsize) };
atomic.fetch_sub(1, Ordering::Relaxed);
}
unsafe fn waker_nop(_ptr: *const ()) {}

static VTABLE: RawWakerVTable =
RawWakerVTable::new(waker_clone, waker_drop, waker_nop, waker_drop);

loom::model(|| {
let rt = Builder::new_current_thread().build().unwrap();

let mut jh = rt.spawn(async {});
// Using AbortHandle to increment task refcount. This ensures that the waker is not
// destroyed due to the refcount hitting zero.
let task_refcnt = jh.abort_handle();

let waker_refcnt = AtomicUsize::new(1);
{
// Set up the join waker.
use std::future::Future;
use std::pin::Pin;

// SAFETY: Before `waker_refcnt` goes out of scope, this test asserts that the refcnt
// has dropped to zero.
let join_waker = unsafe {
Waker::from_raw(RawWaker::new(
(&waker_refcnt) as *const AtomicUsize as *const (),
&VTABLE,
))
};

assert!(Pin::new(&mut jh)
.poll(&mut Context::from_waker(&join_waker))
.is_pending());
}
assert_eq!(waker_refcnt.load(Ordering::Relaxed), 1);

let bg_thread = loom::thread::spawn(move || drop(jh));
rt.block_on(crate::task::yield_now());
bg_thread.join().unwrap();

assert_eq!(waker_refcnt.load(Ordering::Relaxed), 0);
drop(task_refcnt);
});
}

struct BlockedFuture {
rx: Receiver<()>,
num_polls: Arc<AtomicUsize>,
Expand Down
36 changes: 36 additions & 0 deletions tokio/tests/rt_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use std::sync::Arc;
use tokio::runtime::Runtime;
use tokio::sync::{mpsc, Barrier};

#[test]
#[cfg_attr(panic = "abort", ignore)]
Expand Down Expand Up @@ -65,6 +67,40 @@ fn interleave_then_enter() {
let _enter = rt3.enter();
}

// If the cycle causes a leak, then miri will catch it.
#[test]
fn drop_tasks_with_reference_cycle() {
rt().block_on(async {
let (tx, mut rx) = mpsc::channel(1);

let barrier = Arc::new(Barrier::new(3));
let barrier_a = barrier.clone();
let barrier_b = barrier.clone();

let a = tokio::spawn(async move {
let b = rx.recv().await.unwrap();

// Poll the JoinHandle once. This registers the waker.
// The other task cannot have finished at this point due to the barrier below.
futures::future::select(b, std::future::ready(())).await;

barrier_a.wait().await;
});

let b = tokio::spawn(async move {
// Poll the JoinHandle once. This registers the waker.
// The other task cannot have finished at this point due to the barrier below.
futures::future::select(a, std::future::ready(())).await;

barrier_b.wait().await;
});

tx.send(b).await.unwrap();

barrier.wait().await;
});
}

#[cfg(tokio_unstable)]
mod unstable {
use super::*;
Expand Down

0 comments on commit 970d880

Please sign in to comment.