From 006b3d7ee2bbd07dced9f91f6b9e001fe954f839 Mon Sep 17 00:00:00 2001 From: Josh McKinney Date: Fri, 3 Jan 2025 02:50:07 -0800 Subject: [PATCH 1/2] stream: add StreamExt::partition() method This allows filtering items that match or don't match into separate stream. It is analogous to std::iter::Iterator::partition() ```rust let stream = stream::iter(0..4); let (mut even, mut odd) = stream.partition(|v| v % 2 == 0); assert_eq!(Some(0), even.next().await); assert_eq!(Some(1), odd.next().await); ``` --- tokio-stream/src/lib.rs | 4 +- tokio-stream/src/stream_ext.rs | 33 +++++ tokio-stream/src/stream_ext/partition.rs | 153 +++++++++++++++++++++++ tokio-stream/tests/stream_partition.rs | 29 +++++ 4 files changed, 217 insertions(+), 2 deletions(-) create mode 100644 tokio-stream/src/stream_ext/partition.rs create mode 100644 tokio-stream/tests/stream_partition.rs diff --git a/tokio-stream/src/lib.rs b/tokio-stream/src/lib.rs index f2b463bcb9a..8c52140f6e5 100644 --- a/tokio-stream/src/lib.rs +++ b/tokio-stream/src/lib.rs @@ -81,8 +81,8 @@ pub use stream_ext::{collect::FromStream, StreamExt}; /// Adapters for [`Stream`]s created by methods in [`StreamExt`]. pub mod adapters { pub use crate::stream_ext::{ - Chain, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip, SkipWhile, Take, - TakeWhile, Then, + Chain, FalsePartition, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip, + SkipWhile, Take, TakeWhile, Then, TruePartition, }; cfg_time! { pub use crate::stream_ext::{ChunksTimeout, Timeout, TimeoutRepeating}; diff --git a/tokio-stream/src/stream_ext.rs b/tokio-stream/src/stream_ext.rs index cdbada30bc5..875b809e561 100644 --- a/tokio-stream/src/stream_ext.rs +++ b/tokio-stream/src/stream_ext.rs @@ -37,6 +37,10 @@ pub use merge::Merge; mod next; use next::Next; +mod partition; +use partition::Partition; +pub use partition::{FalsePartition, TruePartition}; + mod skip; pub use skip::Skip; @@ -841,6 +845,35 @@ pub trait StreamExt: Stream { FoldFuture::new(self, init, f) } + /// Partitions the stream into two streams based on the provided predicate. + /// + /// The first stream contains all elements for which `f` returns `true`, and + /// the second contains all elements for which `f` returns `false`. + /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio_stream::{self as stream, StreamExt}; + /// + /// let s = stream::iter(vec![1, 2, 3, 4, 5]); + /// let (even, odd) = s.partition(|x| x % 2 == 0); + /// + /// let even: Vec<_> = even.collect().await; + /// let odd: Vec<_> = odd.collect().await; + /// + /// assert_eq!(even, vec![2, 4]); + /// assert_eq!(odd, vec![1, 3, 5]); + /// # } + fn partition(self, f: F) -> (TruePartition, FalsePartition) + where + F: FnMut(&Self::Item) -> bool, + Self: Sized, + { + Partition::new(self, f).split() + } + /// Drain stream pushing all emitted values into a collection. /// /// Equivalent to: diff --git a/tokio-stream/src/stream_ext/partition.rs b/tokio-stream/src/stream_ext/partition.rs new file mode 100644 index 00000000000..17b23799a4e --- /dev/null +++ b/tokio-stream/src/stream_ext/partition.rs @@ -0,0 +1,153 @@ +#![allow(dead_code)] + +use crate::Stream; + +use core::fmt; +use std::{ + collections::VecDeque, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll, Waker}, +}; + +pub(super) struct Partition +where + St: Stream, +{ + stream: St, + f: F, + true_buffer: VecDeque, + false_buffer: VecDeque, + true_waker: Option, + false_waker: Option, +} + +impl Partition +where + St: Stream, +{ + pub(super) fn new(stream: St, f: F) -> Self { + Self { + stream, + f, + true_buffer: VecDeque::new(), + false_buffer: VecDeque::new(), + true_waker: None, + false_waker: None, + } + } + + pub(super) fn split(self) -> (TruePartition, FalsePartition) { + let partition = Arc::new(Mutex::new(self)); + let true_partition = TruePartition::new(partition.clone()); + let false_partition = FalsePartition::new(partition.clone()); + (true_partition, false_partition) + } +} + +/// A stream that only yields elements that satisfy a predicate. +/// +/// This stream is produced by the [`StreamExt::partition`] method. +pub struct TruePartition { + partition: Arc>>, +} + +impl fmt::Debug for TruePartition +where + St: Stream, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TruePartition").finish_non_exhaustive() + } +} + +impl TruePartition +where + St: Stream, +{ + fn new(partition: Arc>>) -> Self { + Self { partition } + } +} + +impl Stream for TruePartition +where + St: Stream + Unpin, + F: FnMut(&St::Item) -> bool, +{ + type Item = St::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut partition = self.partition.lock().unwrap(); + if let Some(item) = partition.true_buffer.pop_front() { + return Poll::Ready(Some(item)); + } + + match Pin::new(&mut partition.stream).poll_next(cx) { + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(item)) if (partition.f)(&item) => Poll::Ready(Some(item)), + Poll::Ready(Some(item)) => { + partition.false_buffer.push_back(item); + partition.false_waker = Some(cx.waker().clone()); + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Pending => { + partition.true_waker = Some(cx.waker().clone()); + Poll::Pending + } + } + } +} + +/// A stream that only yields elements that do not satisfy a predicate. +/// +/// This stream is produced by the [`StreamExt::partition`] method. +pub struct FalsePartition { + partition: Arc>>, +} + +impl fmt::Debug for FalsePartition +where + St: Stream, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FalsePartition").finish_non_exhaustive() + } +} + +impl FalsePartition { + fn new(partition: Arc>>) -> Self { + Self { partition } + } +} + +impl Stream for FalsePartition +where + St: Stream + Unpin, + F: FnMut(&St::Item) -> bool, +{ + type Item = St::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut partition = self.partition.lock().unwrap(); + if let Some(item) = partition.false_buffer.pop_front() { + return Poll::Ready(Some(item)); + } + + match Pin::new(&mut partition.stream).poll_next(cx) { + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(item)) if !(partition.f)(&item) => Poll::Ready(Some(item)), + Poll::Ready(Some(item)) => { + partition.true_buffer.push_back(item); + partition.true_waker = Some(cx.waker().clone()); + cx.waker().wake_by_ref(); + Poll::Pending + } + Poll::Pending => { + partition.false_waker = Some(cx.waker().clone()); + Poll::Pending + } + } + } +} diff --git a/tokio-stream/tests/stream_partition.rs b/tokio-stream/tests/stream_partition.rs new file mode 100644 index 00000000000..96a74bf750f --- /dev/null +++ b/tokio-stream/tests/stream_partition.rs @@ -0,0 +1,29 @@ +use tokio_stream::{self as stream, StreamExt}; + +mod support { + pub(crate) mod mpsc; +} + +#[tokio::test] +async fn partition() { + let stream = stream::iter(0..4); + let (mut even, mut odd) = stream.partition(|v| v % 2 == 0); + assert_eq!(Some(0), even.next().await); + assert_eq!(Some(1), odd.next().await); + assert_eq!(Some(2), even.next().await); + assert_eq!(Some(3), odd.next().await); + assert_eq!(None, even.next().await); + assert_eq!(None, odd.next().await); +} + +#[tokio::test] +async fn partition_buffers() { + let stream = stream::iter(0..4); + let (mut even, mut odd) = stream.partition(|v| v % 2 == 0); + assert_eq!(Some(1), odd.next().await); + assert_eq!(Some(3), odd.next().await); + assert_eq!(None, odd.next().await); + assert_eq!(Some(0), even.next().await); + assert_eq!(Some(2), even.next().await); + assert_eq!(None, odd.next().await); +} From ffcf7298de032da9150310d725654e76672f6fbf Mon Sep 17 00:00:00 2001 From: Josh McKinney Date: Fri, 3 Jan 2025 21:45:45 -0800 Subject: [PATCH 2/2] Change to single buffered value A larger buffer of values for each partition can be added on top of a single buffered value by wrapping each stream with code that prefetches the values. Thus this approach has a better composability than using an internal buffer. The tradeoff is that this may deadlock and calling code that needs to concurrently process items in both streams must be programmed in a way that ensures that waiting to consume from one stream never blocks the other stream. --- tokio-stream/src/lib.rs | 4 +- tokio-stream/src/stream_ext.rs | 44 ++++-- tokio-stream/src/stream_ext/partition.rs | 180 +++++++++++------------ tokio-stream/tests/stream_partition.rs | 55 ++++--- 4 files changed, 154 insertions(+), 129 deletions(-) diff --git a/tokio-stream/src/lib.rs b/tokio-stream/src/lib.rs index 8c52140f6e5..df0c2c3028f 100644 --- a/tokio-stream/src/lib.rs +++ b/tokio-stream/src/lib.rs @@ -81,8 +81,8 @@ pub use stream_ext::{collect::FromStream, StreamExt}; /// Adapters for [`Stream`]s created by methods in [`StreamExt`]. pub mod adapters { pub use crate::stream_ext::{ - Chain, FalsePartition, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip, - SkipWhile, Take, TakeWhile, Then, TruePartition, + Chain, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Partition, Peekable, Skip, SkipWhile, + Take, TakeWhile, Then, }; cfg_time! { pub use crate::stream_ext::{ChunksTimeout, Timeout, TimeoutRepeating}; diff --git a/tokio-stream/src/stream_ext.rs b/tokio-stream/src/stream_ext.rs index 875b809e561..857114f0715 100644 --- a/tokio-stream/src/stream_ext.rs +++ b/tokio-stream/src/stream_ext.rs @@ -38,8 +38,7 @@ mod next; use next::Next; mod partition; -use partition::Partition; -pub use partition::{FalsePartition, TruePartition}; +pub use partition::Partition; mod skip; pub use skip::Skip; @@ -847,8 +846,26 @@ pub trait StreamExt: Stream { /// Partitions the stream into two streams based on the provided predicate. /// - /// The first stream contains all elements for which `f` returns `true`, and - /// the second contains all elements for which `f` returns `false`. + /// + /// + /// As values of this stream are made available, the provided predicate `f` + /// will be run against them. The first stream contains all elements for + /// which `f` returns `true`, and the second contains all elements for which + /// `f` returns `false`. + /// + /// Note that this function consumes the stream passed into it and returns a + /// wrapped versions of it, similar to [`Iterator::partition`] method in the + /// standard library. + /// + /// # Deadlocks + /// + /// Polling the matching stream when the next value is not a match will not + /// succeed until the value is consumed by polling the non-matching stream. + /// Similarly, polling the non-matching stream when the next value is a + /// match will not succeed until the value is consumed by polling the + /// matching stream. This can lead to a deadlock if the streams are not + /// consumed in a way that allows the other stream to continue (e.g. by + /// using a task or tokio::select! to poll the streams concurrently). /// /// # Examples /// @@ -857,21 +874,22 @@ pub trait StreamExt: Stream { /// # async fn main() { /// use tokio_stream::{self as stream, StreamExt}; /// - /// let s = stream::iter(vec![1, 2, 3, 4, 5]); + /// let s = stream::iter(0..4); /// let (even, odd) = s.partition(|x| x % 2 == 0); /// - /// let even: Vec<_> = even.collect().await; - /// let odd: Vec<_> = odd.collect().await; - /// - /// assert_eq!(even, vec![2, 4]); - /// assert_eq!(odd, vec![1, 3, 5]); + /// assert_eq!(even.next().await, Some(0)); + /// assert_eq!(odd.next().await, Some(1)); + /// assert_eq!(even.next().await, Some(2)); + /// assert_eq!(odd.next().await, Some(3)); + /// assert_eq!(even.next().await, None); + /// assert_eq!(odd.next().await, None); /// # } - fn partition(self, f: F) -> (TruePartition, FalsePartition) + fn partition(self, f: F) -> (Partition, Partition) where F: FnMut(&Self::Item) -> bool, - Self: Sized, + Self: Sized + Unpin, { - Partition::new(self, f).split() + Partition::new(self, f) } /// Drain stream pushing all emitted values into a collection. diff --git a/tokio-stream/src/stream_ext/partition.rs b/tokio-stream/src/stream_ext/partition.rs index 17b23799a4e..159e99eeb3e 100644 --- a/tokio-stream/src/stream_ext/partition.rs +++ b/tokio-stream/src/stream_ext/partition.rs @@ -2,75 +2,53 @@ use crate::Stream; -use core::fmt; use std::{ - collections::VecDeque, + fmt, pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll, Waker}, }; -pub(super) struct Partition +/// A stream returned by the [`partition`](super::StreamExt::partition) method. +pub enum Partition where St: Stream, { - stream: St, - f: F, - true_buffer: VecDeque, - false_buffer: VecDeque, - true_waker: Option, - false_waker: Option, + /// A stream that yields items for which the predicate returns `true`. + Matches(Arc>>), + /// A stream that yields items for which the predicate returns `false`. + NonMatches(Arc>>), } -impl Partition +impl fmt::Debug for Partition where - St: Stream, -{ - pub(super) fn new(stream: St, f: F) -> Self { - Self { - stream, - f, - true_buffer: VecDeque::new(), - false_buffer: VecDeque::new(), - true_waker: None, - false_waker: None, - } - } - - pub(super) fn split(self) -> (TruePartition, FalsePartition) { - let partition = Arc::new(Mutex::new(self)); - let true_partition = TruePartition::new(partition.clone()); - let false_partition = FalsePartition::new(partition.clone()); - (true_partition, false_partition) - } -} - -/// A stream that only yields elements that satisfy a predicate. -/// -/// This stream is produced by the [`StreamExt::partition`] method. -pub struct TruePartition { - partition: Arc>>, -} - -impl fmt::Debug for TruePartition -where - St: Stream, + St: fmt::Debug + Stream, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TruePartition").finish_non_exhaustive() + match self { + Partition::Matches(inner) => f.debug_tuple("Partition::Matches").field(inner).finish(), + Partition::NonMatches(inner) => { + f.debug_tuple("Partition::NonMatches").field(inner).finish() + } + } } } -impl TruePartition +impl Partition where - St: Stream, + St: Stream + Unpin, + F: FnMut(&St::Item) -> bool, { - fn new(partition: Arc>>) -> Self { - Self { partition } + pub(super) fn new(stream: St, f: F) -> (Self, Self) { + let inner = Arc::new(Mutex::new(Inner::new(stream, f))); + ( + Partition::Matches(inner.clone()), + Partition::NonMatches(inner), + ) } } -impl Stream for TruePartition +impl Stream for Partition where St: Stream + Unpin, F: FnMut(&St::Item) -> bool, @@ -78,74 +56,86 @@ where type Item = St::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut partition = self.partition.lock().unwrap(); - if let Some(item) = partition.true_buffer.pop_front() { - return Poll::Ready(Some(item)); - } - - match Pin::new(&mut partition.stream).poll_next(cx) { - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(item)) if (partition.f)(&item) => Poll::Ready(Some(item)), - Poll::Ready(Some(item)) => { - partition.false_buffer.push_back(item); - partition.false_waker = Some(cx.waker().clone()); - cx.waker().wake_by_ref(); - Poll::Pending - } - Poll::Pending => { - partition.true_waker = Some(cx.waker().clone()); - Poll::Pending - } + match self.get_mut() { + Partition::Matches(inner) => inner.lock().unwrap().poll_next(cx, true), + Partition::NonMatches(inner) => inner.lock().unwrap().poll_next(cx, false), } } } -/// A stream that only yields elements that do not satisfy a predicate. -/// -/// This stream is produced by the [`StreamExt::partition`] method. -pub struct FalsePartition { - partition: Arc>>, -} - -impl fmt::Debug for FalsePartition +pub struct Inner where St: Stream, { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("FalsePartition").finish_non_exhaustive() - } + stream: St, + f: F, + buffered_value: Option>, + waker: Option, } -impl FalsePartition { - fn new(partition: Arc>>) -> Self { - Self { partition } +enum BufferedValue { + Match(T), + NonMatch(T), +} + +impl fmt::Debug for Inner +where + St: fmt::Debug + Stream, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Inner") + .field("stream", &self.stream) + .field("waker", &self.waker) + .finish_non_exhaustive() } } -impl Stream for FalsePartition +impl Inner where St: Stream + Unpin, F: FnMut(&St::Item) -> bool, { - type Item = St::Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut partition = self.partition.lock().unwrap(); - if let Some(item) = partition.false_buffer.pop_front() { - return Poll::Ready(Some(item)); + pub(super) fn new(stream: St, f: F) -> Self { + Self { + stream, + f, + buffered_value: None, + waker: None, } + } - match Pin::new(&mut partition.stream).poll_next(cx) { - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(item)) if !(partition.f)(&item) => Poll::Ready(Some(item)), - Poll::Ready(Some(item)) => { - partition.true_buffer.push_back(item); - partition.true_waker = Some(cx.waker().clone()); - cx.waker().wake_by_ref(); - Poll::Pending + fn poll_next(&mut self, cx: &mut Context<'_>, matches: bool) -> Poll> { + // Check if there is a buffered value + match self.buffered_value.take() { + Some(BufferedValue::Match(value)) if matches => return Poll::Ready(Some(value)), + Some(BufferedValue::NonMatch(value)) if !matches => return Poll::Ready(Some(value)), + Some(value) => { + self.buffered_value = Some(value); + self.waker = Some(cx.waker().clone()); + return Poll::Pending; } + None => {} + } + + // Poll the underlying stream + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(Some(value)) => match (self.f)(&value) { + result if result == matches => Poll::Ready(Some(value)), + true => { + self.buffered_value = Some(BufferedValue::Match(value)); + self.waker = Some(cx.waker().clone()); + Poll::Pending + } + false => { + self.buffered_value = Some(BufferedValue::NonMatch(value)); + self.waker = Some(cx.waker().clone()); + Poll::Pending + } + }, + Poll::Ready(None) => Poll::Ready(None), // Stream is exhausted Poll::Pending => { - partition.false_waker = Some(cx.waker().clone()); + self.waker = Some(cx.waker().clone()); + cx.waker().wake_by_ref(); Poll::Pending } } diff --git a/tokio-stream/tests/stream_partition.rs b/tokio-stream/tests/stream_partition.rs index 96a74bf750f..b304704fedd 100644 --- a/tokio-stream/tests/stream_partition.rs +++ b/tokio-stream/tests/stream_partition.rs @@ -1,4 +1,5 @@ use tokio_stream::{self as stream, StreamExt}; +use tokio_test::{assert_pending, assert_ready_eq, task}; mod support { pub(crate) mod mpsc; @@ -6,24 +7,40 @@ mod support { #[tokio::test] async fn partition() { - let stream = stream::iter(0..4); - let (mut even, mut odd) = stream.partition(|v| v % 2 == 0); - assert_eq!(Some(0), even.next().await); - assert_eq!(Some(1), odd.next().await); - assert_eq!(Some(2), even.next().await); - assert_eq!(Some(3), odd.next().await); - assert_eq!(None, even.next().await); - assert_eq!(None, odd.next().await); -} + let stream = stream::iter(0..6); + let (matches, non_matches) = stream.partition(|v| v % 2 == 0); + let mut matches = task::spawn(matches); + let mut non_matches = task::spawn(non_matches); -#[tokio::test] -async fn partition_buffers() { - let stream = stream::iter(0..4); - let (mut even, mut odd) = stream.partition(|v| v % 2 == 0); - assert_eq!(Some(1), odd.next().await); - assert_eq!(Some(3), odd.next().await); - assert_eq!(None, odd.next().await); - assert_eq!(Some(0), even.next().await); - assert_eq!(Some(2), even.next().await); - assert_eq!(None, odd.next().await); + // polling matches when the next item matches returns the item from the stream. + assert_ready_eq!(matches.poll_next(), Some(0)); + + // polling non_matches when the next item doesn't match returns the item from the stream. + assert_ready_eq!(non_matches.poll_next(), Some(1)); + + // polling non_matches when the next item matches buffers the item. + assert_pending!(non_matches.poll_next()); + + // polling matches when there is a bufferred match returns the buffered item. + assert_ready_eq!(matches.poll_next(), Some(2)); + + // polling matches when the next item doesn't match buffers the item. + assert_pending!(matches.poll_next()); + + // polling non_matches when there is a bufferred non-match returns the buffered item. + assert_ready_eq!(non_matches.poll_next(), Some(3)); + + // polling non_matches twice when the next item matches buffers the item only once. + assert_pending!(non_matches.poll_next()); + assert_pending!(non_matches.poll_next()); + assert_ready_eq!(matches.poll_next(), Some(4)); + + // polling matches twice when the next item doesn't match buffers the item only once. + assert_pending!(matches.poll_next()); + assert_pending!(matches.poll_next()); + assert_ready_eq!(non_matches.poll_next(), Some(5)); + + // polling matches and non_matches when the stream is exhausted returns None. + assert_ready_eq!(matches.poll_next(), None); + assert_ready_eq!(non_matches.poll_next(), None); }