diff --git a/tokio-stream/src/lib.rs b/tokio-stream/src/lib.rs index f2b463bcb9a..8c52140f6e5 100644 --- a/tokio-stream/src/lib.rs +++ b/tokio-stream/src/lib.rs @@ -81,8 +81,8 @@ pub use stream_ext::{collect::FromStream, StreamExt}; /// Adapters for [`Stream`]s created by methods in [`StreamExt`]. pub mod adapters { pub use crate::stream_ext::{ - Chain, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip, SkipWhile, Take, - TakeWhile, Then, + Chain, FalsePartition, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip, + SkipWhile, Take, TakeWhile, Then, TruePartition, }; cfg_time! { pub use crate::stream_ext::{ChunksTimeout, Timeout, TimeoutRepeating}; diff --git a/tokio-stream/src/stream_ext.rs b/tokio-stream/src/stream_ext.rs index cdbada30bc5..875b809e561 100644 --- a/tokio-stream/src/stream_ext.rs +++ b/tokio-stream/src/stream_ext.rs @@ -37,6 +37,10 @@ pub use merge::Merge; mod next; use next::Next; +mod partition; +use partition::Partition; +pub use partition::{FalsePartition, TruePartition}; + mod skip; pub use skip::Skip; @@ -841,6 +845,35 @@ pub trait StreamExt: Stream { FoldFuture::new(self, init, f) } + /// Partitions the stream into two streams based on the provided predicate. + /// + /// The first stream contains all elements for which `f` returns `true`, and + /// the second contains all elements for which `f` returns `false`. + /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio_stream::{self as stream, StreamExt}; + /// + /// let s = stream::iter(vec![1, 2, 3, 4, 5]); + /// let (even, odd) = s.partition(|x| x % 2 == 0); + /// + /// let even: Vec<_> = even.collect().await; + /// let odd: Vec<_> = odd.collect().await; + /// + /// assert_eq!(even, vec![2, 4]); + /// assert_eq!(odd, vec![1, 3, 5]); + /// # } + fn partition(self, f: F) -> (TruePartition, FalsePartition) + where + F: FnMut(&Self::Item) -> bool, + Self: Sized, + { + Partition::new(self, f).split() + } + /// Drain stream pushing all emitted values into a collection. /// /// Equivalent to: diff --git a/tokio-stream/src/stream_ext/partition.rs b/tokio-stream/src/stream_ext/partition.rs new file mode 100644 index 00000000000..17b23799a4e --- /dev/null +++ b/tokio-stream/src/stream_ext/partition.rs @@ -0,0 +1,153 @@ +#![allow(dead_code)] + +use crate::Stream; + +use core::fmt; +use std::{ + collections::VecDeque, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll, Waker}, +}; + +pub(super) struct Partition +where + St: Stream, +{ + stream: St, + f: F, + true_buffer: VecDeque, + false_buffer: VecDeque, + true_waker: Option, + false_waker: Option, +} + +impl Partition +where + St: Stream, +{ + pub(super) fn new(stream: St, f: F) -> Self { + Self { + stream, + f, + true_buffer: VecDeque::new(), + false_buffer: VecDeque::new(), + true_waker: None, + false_waker: None, + } + } + + pub(super) fn split(self) -> (TruePartition, FalsePartition) { + let partition = Arc::new(Mutex::new(self)); + let true_partition = TruePartition::new(partition.clone()); + let false_partition = FalsePartition::new(partition.clone()); + (true_partition, false_partition) + } +} + +/// A stream that only yields elements that satisfy a predicate. +/// +/// This stream is produced by the [`StreamExt::partition`] method. +pub struct TruePartition { + partition: Arc>>, +} + +impl fmt::Debug for TruePartition +where + St: Stream, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TruePartition").finish_non_exhaustive() + } +} + +impl TruePartition +where + St: Stream, +{ + fn new(partition: Arc>>) -> Self { + Self { partition } + } +} + +impl Stream for TruePartition +where + St: Stream + Unpin, + F: FnMut(&St::Item) -> bool, +{ + type Item = St::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut partition = self.partition.lock().unwrap(); + if let Some(item) = partition.true_buffer.pop_front() { + return Poll::Ready(Some(item)); + } + + match Pin::new(&mut partition.stream).poll_next(cx) { + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(item)) if (partition.f)(&item) => Poll::Ready(Some(item)), + Poll::Ready(Some(item)) => { + partition.false_buffer.push_back(item); + partition.false_waker = Some(cx.waker().clone()); + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Pending => { + partition.true_waker = Some(cx.waker().clone()); + Poll::Pending + } + } + } +} + +/// A stream that only yields elements that do not satisfy a predicate. +/// +/// This stream is produced by the [`StreamExt::partition`] method. +pub struct FalsePartition { + partition: Arc>>, +} + +impl fmt::Debug for FalsePartition +where + St: Stream, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FalsePartition").finish_non_exhaustive() + } +} + +impl FalsePartition { + fn new(partition: Arc>>) -> Self { + Self { partition } + } +} + +impl Stream for FalsePartition +where + St: Stream + Unpin, + F: FnMut(&St::Item) -> bool, +{ + type Item = St::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut partition = self.partition.lock().unwrap(); + if let Some(item) = partition.false_buffer.pop_front() { + return Poll::Ready(Some(item)); + } + + match Pin::new(&mut partition.stream).poll_next(cx) { + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(item)) if !(partition.f)(&item) => Poll::Ready(Some(item)), + Poll::Ready(Some(item)) => { + partition.true_buffer.push_back(item); + partition.true_waker = Some(cx.waker().clone()); + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Pending => { + partition.false_waker = Some(cx.waker().clone()); + Poll::Pending + } + } + } +} diff --git a/tokio-stream/tests/stream_partition.rs b/tokio-stream/tests/stream_partition.rs new file mode 100644 index 00000000000..96a74bf750f --- /dev/null +++ b/tokio-stream/tests/stream_partition.rs @@ -0,0 +1,29 @@ +use tokio_stream::{self as stream, StreamExt}; + +mod support { + pub(crate) mod mpsc; +} + +#[tokio::test] +async fn partition() { + let stream = stream::iter(0..4); + let (mut even, mut odd) = stream.partition(|v| v % 2 == 0); + assert_eq!(Some(0), even.next().await); + assert_eq!(Some(1), odd.next().await); + assert_eq!(Some(2), even.next().await); + assert_eq!(Some(3), odd.next().await); + assert_eq!(None, even.next().await); + assert_eq!(None, odd.next().await); +} + +#[tokio::test] +async fn partition_buffers() { + let stream = stream::iter(0..4); + let (mut even, mut odd) = stream.partition(|v| v % 2 == 0); + assert_eq!(Some(1), odd.next().await); + assert_eq!(Some(3), odd.next().await); + assert_eq!(None, odd.next().await); + assert_eq!(Some(0), even.next().await); + assert_eq!(Some(2), even.next().await); + assert_eq!(None, odd.next().await); +}