Skip to content

Commit

Permalink
sync: add Sender<T>::closed future
Browse files Browse the repository at this point in the history
  • Loading branch information
evanrittenhouse committed Dec 27, 2024
1 parent 0dbdd19 commit ddb1df2
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tokio/src/loom/std/mutex.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::{self, MutexGuard, TryLockError};

/// Adapter for `std::Mutex` that removes the poisoning aspects
/// from its api.
/// from its API.
#[derive(Debug)]
pub(crate) struct Mutex<T: ?Sized>(sync::Mutex<T>);

Expand Down
64 changes: 63 additions & 1 deletion tokio/src/sync/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ pub mod error {

use self::error::{RecvError, SendError, TryRecvError};

use super::Notify;

/// Data shared between senders and receivers.
struct Shared<T> {
/// slots in the channel.
Expand All @@ -314,6 +316,9 @@ struct Shared<T> {

/// Number of outstanding Sender handles.
num_tx: AtomicUsize,

/// Notify when the last subscribed [`Receiver`] drops.
notify_last_rx_drop: Notify,
}

/// Next position to write a value.
Expand Down Expand Up @@ -528,6 +533,7 @@ impl<T> Sender<T> {
waiters: LinkedList::new(),
}),
num_tx: AtomicUsize::new(1),
notify_last_rx_drop: Notify::new(),
});

Sender { shared }
Expand Down Expand Up @@ -805,6 +811,50 @@ impl<T> Sender<T> {
Arc::ptr_eq(&self.shared, &other.shared)
}

/// A future which completes when the number of [Receiver]s subscribed to this `Sender` reaches
/// zero.
///
/// # Examples
///
/// ```
/// use futures::FutureExt;
/// use tokio::sync::broadcast;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx1) = broadcast::channel::<u32>(16);
/// let mut rx2 = tx.subscribe();
///
/// tokio::spawn(async move {
/// assert_eq!(rx1.recv().await.unwrap(), 10);
/// });
///
/// let _ = tx.send(10);
/// assert!(tx.closed().now_or_never().is_none());
///
/// let _ = tokio::spawn(async move {
/// assert_eq!(rx2.recv().await.unwrap(), 10);
/// }).await;
///
/// assert!(tx.closed().now_or_never().is_some());
/// }
/// ```
pub async fn closed(&self) {
loop {
let notified = self.shared.notify_last_rx_drop.notified();

{
// Ensure the lock drops if the channel isn't closed
let tail = self.shared.tail.lock();
if tail.closed {
return;
}
}

notified.await;
}
}

fn close_channel(&self) {
let mut tail = self.shared.tail.lock();
tail.closed = true;
Expand All @@ -819,8 +869,14 @@ fn new_receiver<T>(shared: Arc<Shared<T>>) -> Receiver<T> {

assert!(tail.rx_cnt != MAX_RECEIVERS, "max receivers");

tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");
if tail.rx_cnt == 0 {
// Potentially need to re-open the channel, if a new receiver has been added between calls
// to poll(). Note that we use rx_cnt == 0 instead of is_closed since is_closed also
// applies if the sender has been dropped
tail.closed = false;
}

tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow");
let next = tail.pos;

drop(tail);
Expand Down Expand Up @@ -1346,6 +1402,12 @@ impl<T> Drop for Receiver<T> {

tail.rx_cnt -= 1;
let until = tail.pos;
let remaining_rx = tail.rx_cnt;

if remaining_rx == 0 {
self.shared.notify_last_rx_drop.notify_waiters();
tail.closed = true;
}

drop(tail);

Expand Down
51 changes: 51 additions & 0 deletions tokio/tests/sync_broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -656,3 +656,54 @@ async fn receiver_recv_is_cooperative() {
_ = tokio::task::yield_now() => {},
}
}

#[test]
fn broadcast_sender_closed() {
let (tx, rx) = broadcast::channel::<()>(1);
let rx2 = tx.subscribe();

let mut task = task::spawn(tx.closed());
assert_pending!(task.poll());

drop(rx);
assert!(!task.is_woken());
assert_pending!(task.poll());

drop(rx2);
assert!(task.is_woken());
assert_ready!(task.poll());
}

#[test]
fn broadcast_sender_closed_with_extra_subscribe() {
let (tx, rx) = broadcast::channel::<()>(1);
let rx2 = tx.subscribe();

let mut task = task::spawn(tx.closed());
assert_pending!(task.poll());

drop(rx);
assert!(!task.is_woken());
assert_pending!(task.poll());

drop(rx2);
assert!(task.is_woken());

let rx3 = tx.subscribe();
assert_pending!(task.poll());

drop(rx3);
assert!(task.is_woken());
assert_ready!(task.poll());

let mut task2 = task::spawn(tx.closed());
assert_ready!(task2.poll());

let rx4 = tx.subscribe();
let mut task3 = task::spawn(tx.closed());
assert_pending!(task3.poll());

drop(rx4);
assert!(task3.is_woken());
assert_ready!(task3.poll());
}

0 comments on commit ddb1df2

Please sign in to comment.