diff --git a/tokio-stream/src/lib.rs b/tokio-stream/src/lib.rs index f2b463bcb9a..8c52140f6e5 100644 --- a/tokio-stream/src/lib.rs +++ b/tokio-stream/src/lib.rs @@ -81,8 +81,8 @@ pub use stream_ext::{collect::FromStream, StreamExt}; /// Adapters for [`Stream`]s created by methods in [`StreamExt`]. pub mod adapters { pub use crate::stream_ext::{ - Chain, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip, SkipWhile, Take, - TakeWhile, Then, + Chain, FalsePartition, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip, + SkipWhile, Take, TakeWhile, Then, TruePartition, }; cfg_time! { pub use crate::stream_ext::{ChunksTimeout, Timeout, TimeoutRepeating}; diff --git a/tokio-stream/src/stream_ext.rs b/tokio-stream/src/stream_ext.rs index cdbada30bc5..875b809e561 100644 --- a/tokio-stream/src/stream_ext.rs +++ b/tokio-stream/src/stream_ext.rs @@ -37,6 +37,10 @@ pub use merge::Merge; mod next; use next::Next; +mod partition; +use partition::Partition; +pub use partition::{FalsePartition, TruePartition}; + mod skip; pub use skip::Skip; @@ -841,6 +845,35 @@ pub trait StreamExt: Stream { FoldFuture::new(self, init, f) } + /// Partitions the stream into two streams based on the provided predicate. + /// + /// The first stream contains all elements for which `f` returns `true`, and + /// the second contains all elements for which `f` returns `false`. + /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio_stream::{self as stream, StreamExt}; + /// + /// let s = stream::iter(vec![1, 2, 3, 4, 5]); + /// let (even, odd) = s.partition(|x| x % 2 == 0); + /// + /// let even: Vec<_> = even.collect().await; + /// let odd: Vec<_> = odd.collect().await; + /// + /// assert_eq!(even, vec![2, 4]); + /// assert_eq!(odd, vec![1, 3, 5]); + /// # } + fn partition(self, f: F) -> (TruePartition, FalsePartition) + where + F: FnMut(&Self::Item) -> bool, + Self: Sized, + { + Partition::new(self, f).split() + } + /// Drain stream pushing all emitted values into a collection. /// /// Equivalent to: diff --git a/tokio-stream/src/stream_ext/partition.rs b/tokio-stream/src/stream_ext/partition.rs new file mode 100644 index 00000000000..17b23799a4e --- /dev/null +++ b/tokio-stream/src/stream_ext/partition.rs @@ -0,0 +1,153 @@ +#![allow(dead_code)] + +use crate::Stream; + +use core::fmt; +use std::{ + collections::VecDeque, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll, Waker}, +}; + +pub(super) struct Partition +where + St: Stream, +{ + stream: St, + f: F, + true_buffer: VecDeque, + false_buffer: VecDeque, + true_waker: Option, + false_waker: Option, +} + +impl Partition +where + St: Stream, +{ + pub(super) fn new(stream: St, f: F) -> Self { + Self { + stream, + f, + true_buffer: VecDeque::new(), + false_buffer: VecDeque::new(), + true_waker: None, + false_waker: None, + } + } + + pub(super) fn split(self) -> (TruePartition, FalsePartition) { + let partition = Arc::new(Mutex::new(self)); + let true_partition = TruePartition::new(partition.clone()); + let false_partition = FalsePartition::new(partition.clone()); + (true_partition, false_partition) + } +} + +/// A stream that only yields elements that satisfy a predicate. +/// +/// This stream is produced by the [`StreamExt::partition`] method. +pub struct TruePartition { + partition: Arc>>, +} + +impl fmt::Debug for TruePartition +where + St: Stream, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TruePartition").finish_non_exhaustive() + } +} + +impl TruePartition +where + St: Stream, +{ + fn new(partition: Arc>>) -> Self { + Self { partition } + } +} + +impl Stream for TruePartition +where + St: Stream + Unpin, + F: FnMut(&St::Item) -> bool, +{ + type Item = St::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut partition = self.partition.lock().unwrap(); + if let Some(item) = partition.true_buffer.pop_front() { + return Poll::Ready(Some(item)); + } + + match Pin::new(&mut partition.stream).poll_next(cx) { + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(item)) if (partition.f)(&item) => Poll::Ready(Some(item)), + Poll::Ready(Some(item)) => { + partition.false_buffer.push_back(item); + partition.false_waker = Some(cx.waker().clone()); + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Pending => { + partition.true_waker = Some(cx.waker().clone()); + Poll::Pending + } + } + } +} + +/// A stream that only yields elements that do not satisfy a predicate. +/// +/// This stream is produced by the [`StreamExt::partition`] method. +pub struct FalsePartition { + partition: Arc>>, +} + +impl fmt::Debug for FalsePartition +where + St: Stream, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FalsePartition").finish_non_exhaustive() + } +} + +impl FalsePartition { + fn new(partition: Arc>>) -> Self { + Self { partition } + } +} + +impl Stream for FalsePartition +where + St: Stream + Unpin, + F: FnMut(&St::Item) -> bool, +{ + type Item = St::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut partition = self.partition.lock().unwrap(); + if let Some(item) = partition.false_buffer.pop_front() { + return Poll::Ready(Some(item)); + } + + match Pin::new(&mut partition.stream).poll_next(cx) { + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(item)) if !(partition.f)(&item) => Poll::Ready(Some(item)), + Poll::Ready(Some(item)) => { + partition.true_buffer.push_back(item); + partition.true_waker = Some(cx.waker().clone()); + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Pending => { + partition.false_waker = Some(cx.waker().clone()); + Poll::Pending + } + } + } +} diff --git a/tokio-stream/tests/stream_partition.rs b/tokio-stream/tests/stream_partition.rs new file mode 100644 index 00000000000..96a74bf750f --- /dev/null +++ b/tokio-stream/tests/stream_partition.rs @@ -0,0 +1,29 @@ +use tokio_stream::{self as stream, StreamExt}; + +mod support { + pub(crate) mod mpsc; +} + +#[tokio::test] +async fn partition() { + let stream = stream::iter(0..4); + let (mut even, mut odd) = stream.partition(|v| v % 2 == 0); + assert_eq!(Some(0), even.next().await); + assert_eq!(Some(1), odd.next().await); + assert_eq!(Some(2), even.next().await); + assert_eq!(Some(3), odd.next().await); + assert_eq!(None, even.next().await); + assert_eq!(None, odd.next().await); +} + +#[tokio::test] +async fn partition_buffers() { + let stream = stream::iter(0..4); + let (mut even, mut odd) = stream.partition(|v| v % 2 == 0); + assert_eq!(Some(1), odd.next().await); + assert_eq!(Some(3), odd.next().await); + assert_eq!(None, odd.next().await); + assert_eq!(Some(0), even.next().await); + assert_eq!(Some(2), even.next().await); + assert_eq!(None, odd.next().await); +} diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 7d042a6f950..0f3bafff889 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -767,44 +767,51 @@ impl Receiver { /// When this function returns, the value that was passed to the closure /// when it returned `true` will be considered seen. /// - /// If the channel is closed, then `wait_for` will return a `RecvError`. + /// If the channel is closed, then `wait_for` will return a [`RecvError`]. /// Once this happens, no more messages can ever be sent on the channel. /// When an error is returned, it is guaranteed that the closure has been /// called on the last value, and that it returned `false` for that value. /// (If the closure returned `true`, then the last value would have been /// returned instead of the error.) /// - /// Like the `borrow` method, the returned borrow holds a read lock on the + /// Like the [`borrow`] method, the returned borrow holds a read lock on the /// inner value. This means that long-lived borrows could cause the producer /// half to block. It is recommended to keep the borrow as short-lived as /// possible. See the documentation of `borrow` for more information on /// this. /// - /// [`Receiver::changed()`]: crate::sync::watch::Receiver::changed + /// [`borrow`]: Receiver::borrow + /// [`RecvError`]: error::RecvError + /// + /// # Cancel safety + /// + /// This method is cancel safe. If you use it as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that the last seen value `val` + /// (if any) satisfies `f(val) == false`. + /// + /// # Panics + /// + /// If and only if the closure `f` panics. In that case, no resource owned + /// or shared by this [`Receiver`] will be poisoned. /// /// # Examples /// /// ``` /// use tokio::sync::watch; + /// use tokio::time::{sleep, Duration}; /// - /// #[tokio::main] - /// + /// #[tokio::main(flavor = "current_thread", start_paused = true)] /// async fn main() { - /// let (tx, _rx) = watch::channel("hello"); + /// let (tx, mut rx) = watch::channel("hello"); /// - /// tx.send("goodbye").unwrap(); + /// tokio::spawn(async move { + /// sleep(Duration::from_secs(1)).await; + /// tx.send("goodbye").unwrap(); + /// }); /// - /// // here we subscribe to a second receiver - /// // now in case of using `changed` we would have - /// // to first check the current value and then wait - /// // for changes or else `changed` would hang. - /// let mut rx2 = tx.subscribe(); - /// - /// // in place of changed we have use `wait_for` - /// // which would automatically check the current value - /// // and wait for changes until the closure returns true. - /// assert!(rx2.wait_for(|val| *val == "goodbye").await.is_ok()); - /// assert_eq!(*rx2.borrow(), "goodbye"); + /// assert!(rx.wait_for(|val| *val == "goodbye").await.is_ok()); + /// assert_eq!(*rx.borrow(), "goodbye"); /// } /// ``` pub async fn wait_for(