Skip to content

Commit

Permalink
sync: add Sender<T>::closed future
Browse files Browse the repository at this point in the history
  • Loading branch information
evanrittenhouse committed Aug 19, 2024
1 parent c8f3539 commit d73005b
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tokio/src/loom/std/mutex.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::{self, MutexGuard, TryLockError};

/// Adapter for `std::Mutex` that removes the poisoning aspects
/// from its api.
/// from its API.
#[derive(Debug)]
pub(crate) struct Mutex<T: ?Sized>(sync::Mutex<T>);

Expand Down
68 changes: 66 additions & 2 deletions tokio/src/sync/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
//! }
//! ```
use crate::future::poll_fn;
use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::{AtomicBool, AtomicUsize};
use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
Expand Down Expand Up @@ -163,6 +164,7 @@ use std::task::{Context, Poll, Waker};
/// [`broadcast`]: crate::sync::broadcast
pub struct Sender<T> {
shared: Arc<Shared<T>>,
notify_rx_closed: Arc<Notify>,
}

/// Receiving-half of the [`broadcast`] channel.
Expand Down Expand Up @@ -300,6 +302,8 @@ pub mod error {

use self::error::{RecvError, SendError, TryRecvError};

use super::Notify;

/// Data shared between senders and receivers.
struct Shared<T> {
/// slots in the channel.
Expand All @@ -313,6 +317,9 @@ struct Shared<T> {

/// Number of outstanding Sender handles.
num_tx: AtomicUsize,

/// Notify when a subscribed [`Receiver`] is dropped.
notify_rx_drop: Notify,
}

/// Next position to write a value.
Expand Down Expand Up @@ -527,9 +534,15 @@ impl<T> Sender<T> {
waiters: LinkedList::new(),
}),
num_tx: AtomicUsize::new(1),
notify_rx_drop: Notify::new(),
});

Sender { shared }
let notify_rx_closed = Arc::new(Notify::new());

Sender {
shared,
notify_rx_closed,
}
}

/// Attempts to send a value to all active [`Receiver`] handles, returning
Expand Down Expand Up @@ -804,6 +817,50 @@ impl<T> Sender<T> {
Arc::ptr_eq(&self.shared, &other.shared)
}

/// A future which completes when the number of [Receiver]s subscribed to this `Sender` reaches
/// zero.
///
/// # Examples
///
/// ```
/// use futures::FutureExt;
/// use tokio::sync::broadcast;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx1) = broadcast::channel::<u32>(16);
/// let mut rx2 = tx.subscribe();
///
/// tokio::spawn(async move {
/// assert_eq!(rx1.recv().await.unwrap(), 10);
/// });
///
/// let _ = tx.send(10);
/// assert!(tx.closed().now_or_never().is_none());
///
/// let _ = tokio::spawn(async move {
/// assert_eq!(rx2.recv().await.unwrap(), 10);
/// }).await;
///
/// assert!(tx.closed().now_or_never().is_some());
/// }
/// ```
pub async fn closed(&self) {
self.shared.notify_rx_drop.notified().await;

poll_fn(|_| {
let tail = self.shared.tail.lock();

if tail.closed || tail.rx_cnt == 0 {
return Poll::Ready(());
}

drop(tail);
return Poll::Pending;
})
.await;
}

fn close_channel(&self) {
let mut tail = self.shared.tail.lock();
tail.closed = true;
Expand Down Expand Up @@ -946,7 +1003,12 @@ impl<T> Clone for Sender<T> {
let shared = self.shared.clone();
shared.num_tx.fetch_add(1, SeqCst);

Sender { shared }
let notify_rx_closed = Arc::clone(&self.notify_rx_closed);

Sender {
shared,
notify_rx_closed,
}
}
}

Expand Down Expand Up @@ -1349,6 +1411,8 @@ impl<T> Drop for Receiver<T> {

drop(tail);

self.shared.notify_rx_drop.notify_waiters();

while self.next < until {
match self.recv_ref(None) {
Ok(_) => {}
Expand Down
15 changes: 15 additions & 0 deletions tokio/tests/sync_broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,18 @@ fn send_in_waker_drop() {
// Shouldn't deadlock.
let _ = tx.send(());
}

#[test]
fn broadcast_sender_closed() {
let (tx, rx) = broadcast::channel::<()>(1);
let rx2 = tx.subscribe();

let mut task = task::spawn(tx.closed());
assert_pending!(task.poll());

drop(rx);
assert_pending!(task.poll());

drop(rx2);
assert_ready!(task.poll());
}

0 comments on commit d73005b

Please sign in to comment.