Skip to content

Commit

Permalink
sync: extend documentation for watch::Receiver::wait_for (#7038)
Browse files Browse the repository at this point in the history
  • Loading branch information
cip999 authored and joshka committed Jan 3, 2025
1 parent 2052938 commit 8961a53
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 20 deletions.
4 changes: 2 additions & 2 deletions tokio-stream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ pub use stream_ext::{collect::FromStream, StreamExt};
/// Adapters for [`Stream`]s created by methods in [`StreamExt`].
pub mod adapters {
pub use crate::stream_ext::{
Chain, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip, SkipWhile, Take,
TakeWhile, Then,
Chain, FalsePartition, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip,
SkipWhile, Take, TakeWhile, Then, TruePartition,
};
cfg_time! {
pub use crate::stream_ext::{ChunksTimeout, Timeout, TimeoutRepeating};
Expand Down
33 changes: 33 additions & 0 deletions tokio-stream/src/stream_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ pub use merge::Merge;
mod next;
use next::Next;

mod partition;
use partition::Partition;
pub use partition::{FalsePartition, TruePartition};

mod skip;
pub use skip::Skip;

Expand Down Expand Up @@ -841,6 +845,35 @@ pub trait StreamExt: Stream {
FoldFuture::new(self, init, f)
}

/// Partitions the stream into two streams based on the provided predicate.
///
/// The first stream contains all elements for which `f` returns `true`, and
/// the second contains all elements for which `f` returns `false`.
///
/// # Examples
///
/// ```
/// # #[tokio::main]
/// # async fn main() {
/// use tokio_stream::{self as stream, StreamExt};
///
/// let s = stream::iter(vec![1, 2, 3, 4, 5]);
/// let (even, odd) = s.partition(|x| x % 2 == 0);
///
/// let even: Vec<_> = even.collect().await;
/// let odd: Vec<_> = odd.collect().await;
///
/// assert_eq!(even, vec![2, 4]);
/// assert_eq!(odd, vec![1, 3, 5]);
/// # }
fn partition<F>(self, f: F) -> (TruePartition<Self, F>, FalsePartition<Self, F>)
where
F: FnMut(&Self::Item) -> bool,
Self: Sized,
{
Partition::new(self, f).split()
}

/// Drain stream pushing all emitted values into a collection.
///
/// Equivalent to:
Expand Down
153 changes: 153 additions & 0 deletions tokio-stream/src/stream_ext/partition.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#![allow(dead_code)]

use crate::Stream;

use core::fmt;
use std::{
collections::VecDeque,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};

pub(super) struct Partition<St, F>
where
St: Stream,
{
stream: St,
f: F,
true_buffer: VecDeque<St::Item>,
false_buffer: VecDeque<St::Item>,
true_waker: Option<Waker>,
false_waker: Option<Waker>,
}

impl<St, F> Partition<St, F>
where
St: Stream,
{
pub(super) fn new(stream: St, f: F) -> Self {
Self {
stream,
f,
true_buffer: VecDeque::new(),
false_buffer: VecDeque::new(),
true_waker: None,
false_waker: None,
}
}

pub(super) fn split(self) -> (TruePartition<St, F>, FalsePartition<St, F>) {
let partition = Arc::new(Mutex::new(self));
let true_partition = TruePartition::new(partition.clone());
let false_partition = FalsePartition::new(partition.clone());
(true_partition, false_partition)
}
}

/// A stream that only yields elements that satisfy a predicate.
///
/// This stream is produced by the [`StreamExt::partition`] method.
pub struct TruePartition<St: Stream, F> {
partition: Arc<Mutex<Partition<St, F>>>,
}

impl<St, F> fmt::Debug for TruePartition<St, F>
where
St: Stream,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TruePartition").finish_non_exhaustive()
}
}

impl<St, F> TruePartition<St, F>
where
St: Stream,
{
fn new(partition: Arc<Mutex<Partition<St, F>>>) -> Self {
Self { partition }
}
}

impl<St, F> Stream for TruePartition<St, F>
where
St: Stream + Unpin,
F: FnMut(&St::Item) -> bool,
{
type Item = St::Item;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut partition = self.partition.lock().unwrap();
if let Some(item) = partition.true_buffer.pop_front() {
return Poll::Ready(Some(item));
}

match Pin::new(&mut partition.stream).poll_next(cx) {
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(item)) if (partition.f)(&item) => Poll::Ready(Some(item)),
Poll::Ready(Some(item)) => {
partition.false_buffer.push_back(item);
partition.false_waker = Some(cx.waker().clone());
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Pending => {
partition.true_waker = Some(cx.waker().clone());
Poll::Pending
}
}
}
}

/// A stream that only yields elements that do not satisfy a predicate.
///
/// This stream is produced by the [`StreamExt::partition`] method.
pub struct FalsePartition<St: Stream, F> {
partition: Arc<Mutex<Partition<St, F>>>,
}

impl<St, F> fmt::Debug for FalsePartition<St, F>
where
St: Stream,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FalsePartition").finish_non_exhaustive()
}
}

impl<St: Stream, F> FalsePartition<St, F> {
fn new(partition: Arc<Mutex<Partition<St, F>>>) -> Self {
Self { partition }
}
}

impl<St, F> Stream for FalsePartition<St, F>
where
St: Stream + Unpin,
F: FnMut(&St::Item) -> bool,
{
type Item = St::Item;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut partition = self.partition.lock().unwrap();
if let Some(item) = partition.false_buffer.pop_front() {
return Poll::Ready(Some(item));
}

match Pin::new(&mut partition.stream).poll_next(cx) {
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(item)) if !(partition.f)(&item) => Poll::Ready(Some(item)),
Poll::Ready(Some(item)) => {
partition.true_buffer.push_back(item);
partition.true_waker = Some(cx.waker().clone());
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Pending => {
partition.false_waker = Some(cx.waker().clone());
Poll::Pending
}
}
}
}
29 changes: 29 additions & 0 deletions tokio-stream/tests/stream_partition.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use tokio_stream::{self as stream, StreamExt};

mod support {
pub(crate) mod mpsc;
}

#[tokio::test]
async fn partition() {
let stream = stream::iter(0..4);
let (mut even, mut odd) = stream.partition(|v| v % 2 == 0);
assert_eq!(Some(0), even.next().await);
assert_eq!(Some(1), odd.next().await);
assert_eq!(Some(2), even.next().await);
assert_eq!(Some(3), odd.next().await);
assert_eq!(None, even.next().await);
assert_eq!(None, odd.next().await);
}

#[tokio::test]
async fn partition_buffers() {
let stream = stream::iter(0..4);
let (mut even, mut odd) = stream.partition(|v| v % 2 == 0);
assert_eq!(Some(1), odd.next().await);
assert_eq!(Some(3), odd.next().await);
assert_eq!(None, odd.next().await);
assert_eq!(Some(0), even.next().await);
assert_eq!(Some(2), even.next().await);
assert_eq!(None, odd.next().await);
}
43 changes: 25 additions & 18 deletions tokio/src/sync/watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -767,44 +767,51 @@ impl<T> Receiver<T> {
/// When this function returns, the value that was passed to the closure
/// when it returned `true` will be considered seen.
///
/// If the channel is closed, then `wait_for` will return a `RecvError`.
/// If the channel is closed, then `wait_for` will return a [`RecvError`].
/// Once this happens, no more messages can ever be sent on the channel.
/// When an error is returned, it is guaranteed that the closure has been
/// called on the last value, and that it returned `false` for that value.
/// (If the closure returned `true`, then the last value would have been
/// returned instead of the error.)
///
/// Like the `borrow` method, the returned borrow holds a read lock on the
/// Like the [`borrow`] method, the returned borrow holds a read lock on the
/// inner value. This means that long-lived borrows could cause the producer
/// half to block. It is recommended to keep the borrow as short-lived as
/// possible. See the documentation of `borrow` for more information on
/// this.
///
/// [`Receiver::changed()`]: crate::sync::watch::Receiver::changed
/// [`borrow`]: Receiver::borrow
/// [`RecvError`]: error::RecvError
///
/// # Cancel safety
///
/// This method is cancel safe. If you use it as the event in a
/// [`tokio::select!`](crate::select) statement and some other branch
/// completes first, then it is guaranteed that the last seen value `val`
/// (if any) satisfies `f(val) == false`.
///
/// # Panics
///
/// If and only if the closure `f` panics. In that case, no resource owned
/// or shared by this [`Receiver`] will be poisoned.
///
/// # Examples
///
/// ```
/// use tokio::sync::watch;
/// use tokio::time::{sleep, Duration};
///
/// #[tokio::main]
///
/// #[tokio::main(flavor = "current_thread", start_paused = true)]
/// async fn main() {
/// let (tx, _rx) = watch::channel("hello");
/// let (tx, mut rx) = watch::channel("hello");
///
/// tx.send("goodbye").unwrap();
/// tokio::spawn(async move {
/// sleep(Duration::from_secs(1)).await;
/// tx.send("goodbye").unwrap();
/// });
///
/// // here we subscribe to a second receiver
/// // now in case of using `changed` we would have
/// // to first check the current value and then wait
/// // for changes or else `changed` would hang.
/// let mut rx2 = tx.subscribe();
///
/// // in place of changed we have use `wait_for`
/// // which would automatically check the current value
/// // and wait for changes until the closure returns true.
/// assert!(rx2.wait_for(|val| *val == "goodbye").await.is_ok());
/// assert_eq!(*rx2.borrow(), "goodbye");
/// assert!(rx.wait_for(|val| *val == "goodbye").await.is_ok());
/// assert_eq!(*rx.borrow(), "goodbye");
/// }
/// ```
pub async fn wait_for(
Expand Down

0 comments on commit 8961a53

Please sign in to comment.