From ddb1df2d1b9afd94f9a99bd948c9a08efe4554ed Mon Sep 17 00:00:00 2001 From: Evan Rittenhouse Date: Sat, 13 Jul 2024 18:41:29 -0500 Subject: [PATCH] sync: add Sender::closed future --- tokio/src/loom/std/mutex.rs | 2 +- tokio/src/sync/broadcast.rs | 64 ++++++++++++++++++++++++++++++++++- tokio/tests/sync_broadcast.rs | 51 ++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 2 deletions(-) diff --git a/tokio/src/loom/std/mutex.rs b/tokio/src/loom/std/mutex.rs index 3ea8e1df861..9593ec487d1 100644 --- a/tokio/src/loom/std/mutex.rs +++ b/tokio/src/loom/std/mutex.rs @@ -1,7 +1,7 @@ use std::sync::{self, MutexGuard, TryLockError}; /// Adapter for `std::Mutex` that removes the poisoning aspects -/// from its api. +/// from its API. #[derive(Debug)] pub(crate) struct Mutex(sync::Mutex); diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 56c4cd6b92f..b44f72cf6f6 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -301,6 +301,8 @@ pub mod error { use self::error::{RecvError, SendError, TryRecvError}; +use super::Notify; + /// Data shared between senders and receivers. struct Shared { /// slots in the channel. @@ -314,6 +316,9 @@ struct Shared { /// Number of outstanding Sender handles. num_tx: AtomicUsize, + + /// Notify when the last subscribed [`Receiver`] drops. + notify_last_rx_drop: Notify, } /// Next position to write a value. @@ -528,6 +533,7 @@ impl Sender { waiters: LinkedList::new(), }), num_tx: AtomicUsize::new(1), + notify_last_rx_drop: Notify::new(), }); Sender { shared } @@ -805,6 +811,50 @@ impl Sender { Arc::ptr_eq(&self.shared, &other.shared) } + /// A future which completes when the number of [Receiver]s subscribed to this `Sender` reaches + /// zero. + /// + /// # Examples + /// + /// ``` + /// use futures::FutureExt; + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel::(16); + /// let mut rx2 = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx1.recv().await.unwrap(), 10); + /// }); + /// + /// let _ = tx.send(10); + /// assert!(tx.closed().now_or_never().is_none()); + /// + /// let _ = tokio::spawn(async move { + /// assert_eq!(rx2.recv().await.unwrap(), 10); + /// }).await; + /// + /// assert!(tx.closed().now_or_never().is_some()); + /// } + /// ``` + pub async fn closed(&self) { + loop { + let notified = self.shared.notify_last_rx_drop.notified(); + + { + // Ensure the lock drops if the channel isn't closed + let tail = self.shared.tail.lock(); + if tail.closed { + return; + } + } + + notified.await; + } + } + fn close_channel(&self) { let mut tail = self.shared.tail.lock(); tail.closed = true; @@ -819,8 +869,14 @@ fn new_receiver(shared: Arc>) -> Receiver { assert!(tail.rx_cnt != MAX_RECEIVERS, "max receivers"); - tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow"); + if tail.rx_cnt == 0 { + // Potentially need to re-open the channel, if a new receiver has been added between calls + // to poll(). Note that we use rx_cnt == 0 instead of is_closed since is_closed also + // applies if the sender has been dropped + tail.closed = false; + } + tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow"); let next = tail.pos; drop(tail); @@ -1346,6 +1402,12 @@ impl Drop for Receiver { tail.rx_cnt -= 1; let until = tail.pos; + let remaining_rx = tail.rx_cnt; + + if remaining_rx == 0 { + self.shared.notify_last_rx_drop.notify_waiters(); + tail.closed = true; + } drop(tail); diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs index 3af96bdb5d5..2153555694b 100644 --- a/tokio/tests/sync_broadcast.rs +++ b/tokio/tests/sync_broadcast.rs @@ -656,3 +656,54 @@ async fn receiver_recv_is_cooperative() { _ = tokio::task::yield_now() => {}, } } + +#[test] +fn broadcast_sender_closed() { + let (tx, rx) = broadcast::channel::<()>(1); + let rx2 = tx.subscribe(); + + let mut task = task::spawn(tx.closed()); + assert_pending!(task.poll()); + + drop(rx); + assert!(!task.is_woken()); + assert_pending!(task.poll()); + + drop(rx2); + assert!(task.is_woken()); + assert_ready!(task.poll()); +} + +#[test] +fn broadcast_sender_closed_with_extra_subscribe() { + let (tx, rx) = broadcast::channel::<()>(1); + let rx2 = tx.subscribe(); + + let mut task = task::spawn(tx.closed()); + assert_pending!(task.poll()); + + drop(rx); + assert!(!task.is_woken()); + assert_pending!(task.poll()); + + drop(rx2); + assert!(task.is_woken()); + + let rx3 = tx.subscribe(); + assert_pending!(task.poll()); + + drop(rx3); + assert!(task.is_woken()); + assert_ready!(task.poll()); + + let mut task2 = task::spawn(tx.closed()); + assert_ready!(task2.poll()); + + let rx4 = tx.subscribe(); + let mut task3 = task::spawn(tx.closed()); + assert_pending!(task3.poll()); + + drop(rx4); + assert!(task3.is_woken()); + assert_ready!(task3.poll()); +}