diff --git a/tokio-stream/src/lib.rs b/tokio-stream/src/lib.rs index 8c52140f6e5..df0c2c3028f 100644 --- a/tokio-stream/src/lib.rs +++ b/tokio-stream/src/lib.rs @@ -81,8 +81,8 @@ pub use stream_ext::{collect::FromStream, StreamExt}; /// Adapters for [`Stream`]s created by methods in [`StreamExt`]. pub mod adapters { pub use crate::stream_ext::{ - Chain, FalsePartition, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip, - SkipWhile, Take, TakeWhile, Then, TruePartition, + Chain, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Partition, Peekable, Skip, SkipWhile, + Take, TakeWhile, Then, }; cfg_time! { pub use crate::stream_ext::{ChunksTimeout, Timeout, TimeoutRepeating}; diff --git a/tokio-stream/src/stream_ext.rs b/tokio-stream/src/stream_ext.rs index 875b809e561..857114f0715 100644 --- a/tokio-stream/src/stream_ext.rs +++ b/tokio-stream/src/stream_ext.rs @@ -38,8 +38,7 @@ mod next; use next::Next; mod partition; -use partition::Partition; -pub use partition::{FalsePartition, TruePartition}; +pub use partition::Partition; mod skip; pub use skip::Skip; @@ -847,8 +846,26 @@ pub trait StreamExt: Stream { /// Partitions the stream into two streams based on the provided predicate. /// - /// The first stream contains all elements for which `f` returns `true`, and - /// the second contains all elements for which `f` returns `false`. + /// + /// + /// As values of this stream are made available, the provided predicate `f` + /// will be run against them. The first stream contains all elements for + /// which `f` returns `true`, and the second contains all elements for which + /// `f` returns `false`. + /// + /// Note that this function consumes the stream passed into it and returns a + /// wrapped versions of it, similar to [`Iterator::partition`] method in the + /// standard library. + /// + /// # Deadlocks + /// + /// Polling the matching stream when the next value is not a match will not + /// succeed until the value is consumed by polling the non-matching stream. + /// Similarly, polling the non-matching stream when the next value is a + /// match will not succeed until the value is consumed by polling the + /// matching stream. This can lead to a deadlock if the streams are not + /// consumed in a way that allows the other stream to continue (e.g. by + /// using a task or tokio::select! to poll the streams concurrently). /// /// # Examples /// @@ -857,21 +874,22 @@ pub trait StreamExt: Stream { /// # async fn main() { /// use tokio_stream::{self as stream, StreamExt}; /// - /// let s = stream::iter(vec![1, 2, 3, 4, 5]); + /// let s = stream::iter(0..4); /// let (even, odd) = s.partition(|x| x % 2 == 0); /// - /// let even: Vec<_> = even.collect().await; - /// let odd: Vec<_> = odd.collect().await; - /// - /// assert_eq!(even, vec![2, 4]); - /// assert_eq!(odd, vec![1, 3, 5]); + /// assert_eq!(even.next().await, Some(0)); + /// assert_eq!(odd.next().await, Some(1)); + /// assert_eq!(even.next().await, Some(2)); + /// assert_eq!(odd.next().await, Some(3)); + /// assert_eq!(even.next().await, None); + /// assert_eq!(odd.next().await, None); /// # } - fn partition(self, f: F) -> (TruePartition, FalsePartition) + fn partition(self, f: F) -> (Partition, Partition) where F: FnMut(&Self::Item) -> bool, - Self: Sized, + Self: Sized + Unpin, { - Partition::new(self, f).split() + Partition::new(self, f) } /// Drain stream pushing all emitted values into a collection. diff --git a/tokio-stream/src/stream_ext/partition.rs b/tokio-stream/src/stream_ext/partition.rs index 17b23799a4e..159e99eeb3e 100644 --- a/tokio-stream/src/stream_ext/partition.rs +++ b/tokio-stream/src/stream_ext/partition.rs @@ -2,75 +2,53 @@ use crate::Stream; -use core::fmt; use std::{ - collections::VecDeque, + fmt, pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll, Waker}, }; -pub(super) struct Partition +/// A stream returned by the [`partition`](super::StreamExt::partition) method. +pub enum Partition where St: Stream, { - stream: St, - f: F, - true_buffer: VecDeque, - false_buffer: VecDeque, - true_waker: Option, - false_waker: Option, + /// A stream that yields items for which the predicate returns `true`. + Matches(Arc>>), + /// A stream that yields items for which the predicate returns `false`. + NonMatches(Arc>>), } -impl Partition +impl fmt::Debug for Partition where - St: Stream, -{ - pub(super) fn new(stream: St, f: F) -> Self { - Self { - stream, - f, - true_buffer: VecDeque::new(), - false_buffer: VecDeque::new(), - true_waker: None, - false_waker: None, - } - } - - pub(super) fn split(self) -> (TruePartition, FalsePartition) { - let partition = Arc::new(Mutex::new(self)); - let true_partition = TruePartition::new(partition.clone()); - let false_partition = FalsePartition::new(partition.clone()); - (true_partition, false_partition) - } -} - -/// A stream that only yields elements that satisfy a predicate. -/// -/// This stream is produced by the [`StreamExt::partition`] method. -pub struct TruePartition { - partition: Arc>>, -} - -impl fmt::Debug for TruePartition -where - St: Stream, + St: fmt::Debug + Stream, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TruePartition").finish_non_exhaustive() + match self { + Partition::Matches(inner) => f.debug_tuple("Partition::Matches").field(inner).finish(), + Partition::NonMatches(inner) => { + f.debug_tuple("Partition::NonMatches").field(inner).finish() + } + } } } -impl TruePartition +impl Partition where - St: Stream, + St: Stream + Unpin, + F: FnMut(&St::Item) -> bool, { - fn new(partition: Arc>>) -> Self { - Self { partition } + pub(super) fn new(stream: St, f: F) -> (Self, Self) { + let inner = Arc::new(Mutex::new(Inner::new(stream, f))); + ( + Partition::Matches(inner.clone()), + Partition::NonMatches(inner), + ) } } -impl Stream for TruePartition +impl Stream for Partition where St: Stream + Unpin, F: FnMut(&St::Item) -> bool, @@ -78,74 +56,86 @@ where type Item = St::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut partition = self.partition.lock().unwrap(); - if let Some(item) = partition.true_buffer.pop_front() { - return Poll::Ready(Some(item)); - } - - match Pin::new(&mut partition.stream).poll_next(cx) { - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(item)) if (partition.f)(&item) => Poll::Ready(Some(item)), - Poll::Ready(Some(item)) => { - partition.false_buffer.push_back(item); - partition.false_waker = Some(cx.waker().clone()); - cx.waker().wake_by_ref(); - Poll::Pending - } - Poll::Pending => { - partition.true_waker = Some(cx.waker().clone()); - Poll::Pending - } + match self.get_mut() { + Partition::Matches(inner) => inner.lock().unwrap().poll_next(cx, true), + Partition::NonMatches(inner) => inner.lock().unwrap().poll_next(cx, false), } } } -/// A stream that only yields elements that do not satisfy a predicate. -/// -/// This stream is produced by the [`StreamExt::partition`] method. -pub struct FalsePartition { - partition: Arc>>, -} - -impl fmt::Debug for FalsePartition +pub struct Inner where St: Stream, { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("FalsePartition").finish_non_exhaustive() - } + stream: St, + f: F, + buffered_value: Option>, + waker: Option, } -impl FalsePartition { - fn new(partition: Arc>>) -> Self { - Self { partition } +enum BufferedValue { + Match(T), + NonMatch(T), +} + +impl fmt::Debug for Inner +where + St: fmt::Debug + Stream, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Inner") + .field("stream", &self.stream) + .field("waker", &self.waker) + .finish_non_exhaustive() } } -impl Stream for FalsePartition +impl Inner where St: Stream + Unpin, F: FnMut(&St::Item) -> bool, { - type Item = St::Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut partition = self.partition.lock().unwrap(); - if let Some(item) = partition.false_buffer.pop_front() { - return Poll::Ready(Some(item)); + pub(super) fn new(stream: St, f: F) -> Self { + Self { + stream, + f, + buffered_value: None, + waker: None, } + } - match Pin::new(&mut partition.stream).poll_next(cx) { - Poll::Ready(None) => Poll::Ready(None), - Poll::Ready(Some(item)) if !(partition.f)(&item) => Poll::Ready(Some(item)), - Poll::Ready(Some(item)) => { - partition.true_buffer.push_back(item); - partition.true_waker = Some(cx.waker().clone()); - cx.waker().wake_by_ref(); - Poll::Pending + fn poll_next(&mut self, cx: &mut Context<'_>, matches: bool) -> Poll> { + // Check if there is a buffered value + match self.buffered_value.take() { + Some(BufferedValue::Match(value)) if matches => return Poll::Ready(Some(value)), + Some(BufferedValue::NonMatch(value)) if !matches => return Poll::Ready(Some(value)), + Some(value) => { + self.buffered_value = Some(value); + self.waker = Some(cx.waker().clone()); + return Poll::Pending; } + None => {} + } + + // Poll the underlying stream + match Pin::new(&mut self.stream).poll_next(cx) { + Poll::Ready(Some(value)) => match (self.f)(&value) { + result if result == matches => Poll::Ready(Some(value)), + true => { + self.buffered_value = Some(BufferedValue::Match(value)); + self.waker = Some(cx.waker().clone()); + Poll::Pending + } + false => { + self.buffered_value = Some(BufferedValue::NonMatch(value)); + self.waker = Some(cx.waker().clone()); + Poll::Pending + } + }, + Poll::Ready(None) => Poll::Ready(None), // Stream is exhausted Poll::Pending => { - partition.false_waker = Some(cx.waker().clone()); + self.waker = Some(cx.waker().clone()); + cx.waker().wake_by_ref(); Poll::Pending } } diff --git a/tokio-stream/tests/stream_partition.rs b/tokio-stream/tests/stream_partition.rs index 96a74bf750f..b304704fedd 100644 --- a/tokio-stream/tests/stream_partition.rs +++ b/tokio-stream/tests/stream_partition.rs @@ -1,4 +1,5 @@ use tokio_stream::{self as stream, StreamExt}; +use tokio_test::{assert_pending, assert_ready_eq, task}; mod support { pub(crate) mod mpsc; @@ -6,24 +7,40 @@ mod support { #[tokio::test] async fn partition() { - let stream = stream::iter(0..4); - let (mut even, mut odd) = stream.partition(|v| v % 2 == 0); - assert_eq!(Some(0), even.next().await); - assert_eq!(Some(1), odd.next().await); - assert_eq!(Some(2), even.next().await); - assert_eq!(Some(3), odd.next().await); - assert_eq!(None, even.next().await); - assert_eq!(None, odd.next().await); -} + let stream = stream::iter(0..6); + let (matches, non_matches) = stream.partition(|v| v % 2 == 0); + let mut matches = task::spawn(matches); + let mut non_matches = task::spawn(non_matches); -#[tokio::test] -async fn partition_buffers() { - let stream = stream::iter(0..4); - let (mut even, mut odd) = stream.partition(|v| v % 2 == 0); - assert_eq!(Some(1), odd.next().await); - assert_eq!(Some(3), odd.next().await); - assert_eq!(None, odd.next().await); - assert_eq!(Some(0), even.next().await); - assert_eq!(Some(2), even.next().await); - assert_eq!(None, odd.next().await); + // polling matches when the next item matches returns the item from the stream. + assert_ready_eq!(matches.poll_next(), Some(0)); + + // polling non_matches when the next item doesn't match returns the item from the stream. + assert_ready_eq!(non_matches.poll_next(), Some(1)); + + // polling non_matches when the next item matches buffers the item. + assert_pending!(non_matches.poll_next()); + + // polling matches when there is a bufferred match returns the buffered item. + assert_ready_eq!(matches.poll_next(), Some(2)); + + // polling matches when the next item doesn't match buffers the item. + assert_pending!(matches.poll_next()); + + // polling non_matches when there is a bufferred non-match returns the buffered item. + assert_ready_eq!(non_matches.poll_next(), Some(3)); + + // polling non_matches twice when the next item matches buffers the item only once. + assert_pending!(non_matches.poll_next()); + assert_pending!(non_matches.poll_next()); + assert_ready_eq!(matches.poll_next(), Some(4)); + + // polling matches twice when the next item doesn't match buffers the item only once. + assert_pending!(matches.poll_next()); + assert_pending!(matches.poll_next()); + assert_ready_eq!(non_matches.poll_next(), Some(5)); + + // polling matches and non_matches when the stream is exhausted returns None. + assert_ready_eq!(matches.poll_next(), None); + assert_ready_eq!(non_matches.poll_next(), None); }