diff --git a/tokio-stream/src/lib.rs b/tokio-stream/src/lib.rs
index f2b463bcb9a..df0c2c3028f 100644
--- a/tokio-stream/src/lib.rs
+++ b/tokio-stream/src/lib.rs
@@ -81,8 +81,8 @@ pub use stream_ext::{collect::FromStream, StreamExt};
 /// Adapters for [`Stream`]s created by methods in [`StreamExt`].
 pub mod adapters {
     pub use crate::stream_ext::{
-        Chain, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip, SkipWhile, Take,
-        TakeWhile, Then,
+        Chain, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Partition, Peekable, Skip, SkipWhile,
+        Take, TakeWhile, Then,
     };
     cfg_time! {
         pub use crate::stream_ext::{ChunksTimeout, Timeout, TimeoutRepeating};
diff --git a/tokio-stream/src/stream_ext.rs b/tokio-stream/src/stream_ext.rs
index cdbada30bc5..857114f0715 100644
--- a/tokio-stream/src/stream_ext.rs
+++ b/tokio-stream/src/stream_ext.rs
@@ -37,6 +37,9 @@ pub use merge::Merge;
 mod next;
 use next::Next;
 
+mod partition;
+pub use partition::Partition;
+
 mod skip;
 pub use skip::Skip;
 
@@ -841,6 +844,56 @@ pub trait StreamExt: Stream {
         FoldFuture::new(self, init, f)
     }
 
+    /// Partitions the stream into two streams based on the provided predicate.
+    ///
+    /// As values of this stream are made available, the provided predicate `f`
+    /// will be run against them. The first stream contains all elements for
+    /// which `f` returns `true`, and the second contains all elements for which
+    /// `f` returns `false`.
+    ///
+    /// Note that this function consumes the stream passed into it and returns
+    /// wrapped versions of it, similar to the [`Iterator::partition`] method in
+    /// the standard library.
+    ///
+    /// # Deadlocks
+    ///
+    /// Polling the matching stream when the next value is not a match will not
+    /// succeed until the value is consumed by polling the non-matching stream.
+    /// Similarly, polling the non-matching stream when the next value is a
+    /// match will not succeed until the value is consumed by polling the
+    /// matching stream. This can lead to a deadlock if the streams are not
+    /// consumed in a way that allows the other stream to continue (e.g. by
+    /// using a task or `tokio::select!` to poll the streams concurrently).
+    ///
+    /// The same applies if one half is dropped or never polled: the other half
+    /// stalls at the first value that belongs to the abandoned half.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # #[tokio::main]
+    /// # async fn main() {
+    /// use tokio_stream::{self as stream, StreamExt};
+    ///
+    /// let s = stream::iter(0..4);
+    /// let (mut even, mut odd) = s.partition(|x| x % 2 == 0);
+    ///
+    /// assert_eq!(even.next().await, Some(0));
+    /// assert_eq!(odd.next().await, Some(1));
+    /// assert_eq!(even.next().await, Some(2));
+    /// assert_eq!(odd.next().await, Some(3));
+    /// assert_eq!(even.next().await, None);
+    /// assert_eq!(odd.next().await, None);
+    /// # }
+    /// ```
+    fn partition<F>(self, f: F) -> (Partition<Self, F>, Partition<Self, F>)
+    where
+        F: FnMut(&Self::Item) -> bool,
+        Self: Sized + Unpin,
+    {
+        Partition::new(self, f)
+    }
+
     /// Drain stream pushing all emitted values into a collection.
     ///
     /// Equivalent to:
diff --git a/tokio-stream/src/stream_ext/partition.rs b/tokio-stream/src/stream_ext/partition.rs
new file mode 100644
index 00000000000..159e99eeb3e
--- /dev/null
+++ b/tokio-stream/src/stream_ext/partition.rs
@@ -0,0 +1,187 @@
+use crate::Stream;
+
+use std::{
+    fmt,
+    pin::Pin,
+    sync::{Arc, Mutex},
+    task::{Context, Poll, Waker},
+};
+
+/// A stream returned by the [`partition`](super::StreamExt::partition) method.
+///
+/// Both halves share the underlying stream and a one-item buffer. An item
+/// pulled by the half it does not belong to is kept in that buffer, and the
+/// other half is woken so that it can collect it.
+pub enum Partition<St, F>
+where
+    St: Stream,
+{
+    /// A stream that yields items for which the predicate returns `true`.
+    Matches(Arc<Mutex<Inner<St, F>>>),
+    /// A stream that yields items for which the predicate returns `false`.
+    NonMatches(Arc<Mutex<Inner<St, F>>>),
+}
+
+impl<St, F> fmt::Debug for Partition<St, F>
+where
+    St: fmt::Debug + Stream,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Partition::Matches(inner) => f.debug_tuple("Partition::Matches").field(inner).finish(),
+            Partition::NonMatches(inner) => {
+                f.debug_tuple("Partition::NonMatches").field(inner).finish()
+            }
+        }
+    }
+}
+
+impl<St, F> Partition<St, F>
+where
+    St: Stream + Unpin,
+    F: FnMut(&St::Item) -> bool,
+{
+    pub(super) fn new(stream: St, f: F) -> (Self, Self) {
+        let inner = Arc::new(Mutex::new(Inner::new(stream, f)));
+        (
+            Partition::Matches(Arc::clone(&inner)),
+            Partition::NonMatches(inner),
+        )
+    }
+}
+
+impl<St, F> Stream for Partition<St, F>
+where
+    St: Stream + Unpin,
+    F: FnMut(&St::Item) -> bool,
+{
+    type Item = St::Item;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let (inner, matches) = match self.get_mut() {
+            Partition::Matches(inner) => (inner, true),
+            Partition::NonMatches(inner) => (inner, false),
+        };
+
+        // The guard is a temporary of this statement, so the lock is released
+        // before the peer is woken below.
+        let (poll, peer) = inner.lock().unwrap().poll_next(cx, matches);
+        if let Some(waker) = peer {
+            waker.wake();
+        }
+        poll
+    }
+}
+
+/// State shared by the two halves of a partitioned stream.
+pub struct Inner<St, F>
+where
+    St: Stream,
+{
+    stream: St,
+    f: F,
+    /// An item pulled by one half that belongs to the other half.
+    buffered_value: Option<BufferedValue<St::Item>>,
+    /// Parked wakers, indexed by `usize::from(matches)` of the waiting half.
+    wakers: [Option<Waker>; 2],
+    /// Set once the underlying stream has returned `None`.
+    exhausted: bool,
+}
+
+enum BufferedValue<T> {
+    Match(T),
+    NonMatch(T),
+}
+
+impl<St, F> fmt::Debug for Inner<St, F>
+where
+    St: fmt::Debug + Stream,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("Inner")
+            .field("stream", &self.stream)
+            .field("wakers", &self.wakers)
+            .field("exhausted", &self.exhausted)
+            .finish_non_exhaustive()
+    }
+}
+
+impl<St, F> Inner<St, F>
+where
+    St: Stream + Unpin,
+    F: FnMut(&St::Item) -> bool,
+{
+    pub(super) fn new(stream: St, f: F) -> Self {
+        Self {
+            stream,
+            f,
+            buffered_value: None,
+            wakers: [None, None],
+            exhausted: false,
+        }
+    }
+
+    /// Polls on behalf of one half (`matches` selects which).
+    ///
+    /// Returns the poll result together with the peer's waker when the peer
+    /// should be woken. The caller must wake it only after releasing the lock.
+    fn poll_next(
+        &mut self,
+        cx: &mut Context<'_>,
+        matches: bool,
+    ) -> (Poll<Option<St::Item>>, Option<Waker>) {
+        let (own, peer) = (usize::from(matches), usize::from(!matches));
+
+        // Check if there is a buffered value
+        match self.buffered_value.take() {
+            Some(BufferedValue::Match(value)) if matches => {
+                return (Poll::Ready(Some(value)), self.wakers[peer].take());
+            }
+            Some(BufferedValue::NonMatch(value)) if !matches => {
+                return (Poll::Ready(Some(value)), self.wakers[peer].take());
+            }
+            Some(value) => {
+                // The item belongs to the peer, which was woken when the item
+                // was buffered. Wait until the peer has taken it.
+                self.buffered_value = Some(value);
+                self.wakers[own] = Some(cx.waker().clone());
+                return (Poll::Pending, None);
+            }
+            None => {}
+        }
+
+        if self.exhausted {
+            // Never poll the underlying stream again after it returned `None`.
+            return (Poll::Ready(None), None);
+        }
+
+        // Poll the underlying stream
+        match Pin::new(&mut self.stream).poll_next(cx) {
+            Poll::Ready(Some(value)) => {
+                let is_match = (self.f)(&value);
+                if is_match == matches {
+                    // Hand the stream over: a parked peer polls again and
+                    // registers its own waker with the underlying stream.
+                    return (Poll::Ready(Some(value)), self.wakers[peer].take());
+                }
+                self.buffered_value = Some(if is_match {
+                    BufferedValue::Match(value)
+                } else {
+                    BufferedValue::NonMatch(value)
+                });
+                self.wakers[own] = Some(cx.waker().clone());
+                (Poll::Pending, self.wakers[peer].take())
+            }
+            Poll::Ready(None) => {
+                self.exhausted = true;
+                (Poll::Ready(None), self.wakers[peer].take())
+            }
+            Poll::Pending => {
+                // The underlying stream has registered `cx`'s waker; remember
+                // it so that the peer can wake this half as well.
+                self.wakers[own] = Some(cx.waker().clone());
+                (Poll::Pending, None)
+            }
+        }
+    }
+}
diff --git a/tokio-stream/tests/stream_partition.rs b/tokio-stream/tests/stream_partition.rs
new file mode 100644
index 00000000000..b304704fedd
--- /dev/null
+++ b/tokio-stream/tests/stream_partition.rs
@@ -0,0 +1,63 @@
+use tokio_stream::{self as stream, StreamExt};
+use tokio_test::{assert_pending, assert_ready_eq, task};
+
+mod support {
+    pub(crate) mod mpsc;
+}
+
+#[tokio::test]
+async fn partition() {
+    let stream = stream::iter(0..6);
+    let (matches, non_matches) = stream.partition(|v| v % 2 == 0);
+    let mut matches = task::spawn(matches);
+    let mut non_matches = task::spawn(non_matches);
+
+    // polling matches when the next item matches returns the item from the stream.
+    assert_ready_eq!(matches.poll_next(), Some(0));
+
+    // polling non_matches when the next item doesn't match returns the item from the stream.
+    assert_ready_eq!(non_matches.poll_next(), Some(1));
+
+    // polling non_matches when the next item matches buffers the item.
+    assert_pending!(non_matches.poll_next());
+
+    // polling matches when there is a buffered match returns the buffered item.
+    assert_ready_eq!(matches.poll_next(), Some(2));
+
+    // polling matches when the next item doesn't match buffers the item.
+    assert_pending!(matches.poll_next());
+
+    // polling non_matches when there is a buffered non-match returns the buffered item.
+    assert_ready_eq!(non_matches.poll_next(), Some(3));
+
+    // polling non_matches twice when the next item matches buffers the item only once.
+    assert_pending!(non_matches.poll_next());
+    assert_pending!(non_matches.poll_next());
+    assert_ready_eq!(matches.poll_next(), Some(4));
+
+    // polling matches twice when the next item doesn't match buffers the item only once.
+    assert_pending!(matches.poll_next());
+    assert_pending!(matches.poll_next());
+    assert_ready_eq!(non_matches.poll_next(), Some(5));
+
+    // polling matches and non_matches when the stream is exhausted returns None.
+    assert_ready_eq!(matches.poll_next(), None);
+    assert_ready_eq!(non_matches.poll_next(), None);
+}
+
+#[tokio::test]
+async fn partition_wakes_parked_half() {
+    let stream = stream::iter(0..2);
+    let (matches, non_matches) = stream.partition(|v| v % 2 == 0);
+    let mut matches = task::spawn(matches);
+    let mut non_matches = task::spawn(non_matches);
+
+    // polling non_matches when the next item matches buffers the item and parks.
+    assert_pending!(non_matches.poll_next());
+    assert!(!non_matches.is_woken());
+
+    // taking the buffered item wakes the parked half so that it can make progress.
+    assert_ready_eq!(matches.poll_next(), Some(0));
+    assert!(non_matches.is_woken());
+    assert_ready_eq!(non_matches.poll_next(), Some(1));
+}