stream: add StreamExt::partition() method #7065

4 changes: 2 additions & 2 deletions tokio-stream/src/lib.rs
@@ -81,8 +81,8 @@ pub use stream_ext::{collect::FromStream, StreamExt};
/// Adapters for [`Stream`]s created by methods in [`StreamExt`].
pub mod adapters {
pub use crate::stream_ext::{
- Chain, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip, SkipWhile, Take,
- TakeWhile, Then,
+ Chain, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Partition, Peekable, Skip, SkipWhile,
+ Take, TakeWhile, Then,
};
cfg_time! {
pub use crate::stream_ext::{ChunksTimeout, Timeout, TimeoutRepeating};
51 changes: 51 additions & 0 deletions tokio-stream/src/stream_ext.rs
@@ -37,6 +37,9 @@ pub use merge::Merge;
mod next;
use next::Next;

mod partition;
pub use partition::Partition;

mod skip;
pub use skip::Skip;

@@ -841,6 +844,54 @@ pub trait StreamExt: Stream {
FoldFuture::new(self, init, f)
}

/// Partitions the stream into two streams based on the provided predicate.
///
/// As values of this stream are made available, the provided predicate `f`
/// will be run against them. The first stream contains all elements for
/// which `f` returns `true`, and the second contains all elements for which
/// `f` returns `false`.
///
/// Note that this function consumes the stream passed into it and returns
/// wrapped versions of it, similar to the [`Iterator::partition`] method in
/// the standard library.
///
/// # Deadlocks
///
/// Polling the matching stream when the next value is not a match will not
/// succeed until the value is consumed by polling the non-matching stream.
/// Similarly, polling the non-matching stream when the next value is a
/// match will not succeed until the value is consumed by polling the
/// matching stream. This can lead to a deadlock if the streams are not
/// consumed in a way that allows the other stream to continue (e.g. by
/// using a task or `tokio::select!` to poll the streams concurrently).
///
/// # Examples
///
/// ```
/// # #[tokio::main]
/// # async fn main() {
/// use tokio_stream::{self as stream, StreamExt};
///
/// let s = stream::iter(0..4);
/// let (mut even, mut odd) = s.partition(|x| x % 2 == 0);
///
/// assert_eq!(even.next().await, Some(0));
/// assert_eq!(odd.next().await, Some(1));
/// assert_eq!(even.next().await, Some(2));
/// assert_eq!(odd.next().await, Some(3));
/// assert_eq!(even.next().await, None);
/// assert_eq!(odd.next().await, None);
/// # }
/// ```
fn partition<F>(self, f: F) -> (Partition<Self, F>, Partition<Self, F>)
where
F: FnMut(&Self::Item) -> bool,
Contributor Author

I wonder if this should return a future instead of a bool. The futures StreamExt has methods which tend to do so, but the tokio StreamExt does not. I kept this consistent with tokio here.

Contributor

I think this is fine. People can embed async computation in the stream itself if they really need it.

Self: Sized + Unpin,
{
Partition::new(self, f)
}

/// Drain stream pushing all emitted values into a collection.
///
/// Equivalent to:
143 changes: 143 additions & 0 deletions tokio-stream/src/stream_ext/partition.rs
@@ -0,0 +1,143 @@
#![allow(dead_code)]

use crate::Stream;

use std::{
fmt,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};

/// A stream returned by the [`partition`](super::StreamExt::partition) method.
pub enum Partition<St, F>
where
St: Stream,
{
/// A stream that yields items for which the predicate returns `true`.
Matches(Arc<Mutex<Inner<St, F>>>),
/// A stream that yields items for which the predicate returns `false`.
NonMatches(Arc<Mutex<Inner<St, F>>>),
Comment on lines +17 to +20
Contributor

This can be a const generic boolean.
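
One possible reading of this suggestion, as a sketch only: replace the two enum variants with a single struct that carries a `const MATCHES: bool` parameter, so each half knows its side at compile time. `Shared`, `side`, and `new_pair` are hypothetical names that do not appear in the PR, and the real type would keep its `St` and `F` parameters.

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical stand-in for the PR's `Inner<St, F>`.
struct Shared {
    polls: u32,
}

/// One half of a partition; `MATCHES` replaces the enum discriminant.
struct Partition<const MATCHES: bool> {
    inner: Arc<Mutex<Shared>>,
}

impl<const MATCHES: bool> Partition<MATCHES> {
    /// Records a poll on the shared state and reports which side this half serves.
    fn side(&self) -> bool {
        self.inner.lock().unwrap().polls += 1;
        MATCHES
    }
}

/// Builds both halves over the same shared state.
fn new_pair() -> (Partition<true>, Partition<false>) {
    let inner = Arc::new(Mutex::new(Shared { polls: 0 }));
    (Partition { inner: inner.clone() }, Partition { inner })
}
```

With this shape, `poll_next` could pass `MATCHES` straight to the shared state instead of matching on a variant, at the cost of the two halves having different types.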

}

impl<St, F> fmt::Debug for Partition<St, F>
where
St: fmt::Debug + Stream,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Partition::Matches(inner) => f.debug_tuple("Partition::Matches").field(inner).finish(),
Partition::NonMatches(inner) => {
f.debug_tuple("Partition::NonMatches").field(inner).finish()
}
}
}
}

impl<St, F> Partition<St, F>
where
St: Stream + Unpin,
F: FnMut(&St::Item) -> bool,
{
pub(super) fn new(stream: St, f: F) -> (Self, Self) {
let inner = Arc::new(Mutex::new(Inner::new(stream, f)));
(
Partition::Matches(inner.clone()),
Partition::NonMatches(inner),
)
}
}

impl<St, F> Stream for Partition<St, F>
where
St: Stream + Unpin,
F: FnMut(&St::Item) -> bool,
{
type Item = St::Item;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.get_mut() {
Partition::Matches(inner) => inner.lock().unwrap().poll_next(cx, true),
Partition::NonMatches(inner) => inner.lock().unwrap().poll_next(cx, false),
}
}
}

pub struct Inner<St, F>
where
St: Stream,
{
stream: St,
f: F,
buffered_value: Option<BufferedValue<St::Item>>,
Contributor

If you go for allowing a user-specified limit, then you can let this be a `Vec<St::Item>` and have a boolean next to it that specifies which of the two streams the buffered values are for.
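
A minimal sketch of that bounded buffer, under stated assumptions: `SideBuffer`, `push`, `pop_for`, and `limit` are hypothetical names, and a `VecDeque` is swapped in for the suggested `Vec` so that the oldest item can be popped from the front cheaply. It relies on the observation that all queued items belong to the same side, because a half drains the buffer before polling the underlying stream again.

```rust
use std::collections::VecDeque;

/// Hypothetical bounded buffer: every queued item belongs to one side.
struct SideBuffer<T> {
    items: VecDeque<T>,
    /// `true` if the queued items are for the matching stream.
    for_matches: bool,
    /// User-specified cap on queued items (assumed parameter).
    limit: usize,
}

impl<T> SideBuffer<T> {
    fn new(limit: usize) -> Self {
        Self { items: VecDeque::new(), for_matches: false, limit }
    }

    /// Queues `item` for the side given by `is_match`.
    /// Hands the item back as `Err` when the buffer is full.
    fn push(&mut self, item: T, is_match: bool) -> Result<(), T> {
        if self.items.is_empty() {
            self.for_matches = is_match;
        }
        debug_assert_eq!(self.for_matches, is_match);
        if self.items.len() >= self.limit {
            return Err(item);
        }
        self.items.push_back(item);
        Ok(())
    }

    /// Pops the oldest item if the buffer is queued for the polling side.
    fn pop_for(&mut self, polling_matches: bool) -> Option<T> {
        if self.for_matches == polling_matches {
            self.items.pop_front()
        } else {
            None
        }
    }
}
```

When `push` returns `Err`, the caller would still need somewhere to hold the rejected item (or return `Poll::Pending` before polling the underlying stream); that part is left open here.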

waker: Option<Waker>,
Contributor Author

This would perhaps be nicer as a plain `Waker`, using `Waker::noop` as the default. But that's a Rust 1.85 feature, which is not yet released (and way past the current MSRV). I considered `AtomicWaker` here too, but I didn't want to add futures-util as a dependency of tokio-stream.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you use it? I see only an assignment self.waker=Some(...)

Contributor

I think using Option is perfectly fine.
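
To make the question about how the stored waker gets used concrete, here is a sketch of one common pattern with `Option<Waker>`: the half that has to wait stores its waker, and the other half takes and wakes it after changing the shared state. This is an assumption about the intended design, not code from the PR; `Parked`, `park`, `notify`, and `CountingWaker` are hypothetical names, and only `std::task` APIs are used.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Wake, Waker};

/// Test waker that counts how many times it was woken.
struct CountingWaker(AtomicUsize);

impl Wake for CountingWaker {
    fn wake(self: Arc<Self>) {
        self.0.fetch_add(1, Ordering::SeqCst);
    }
}

/// Hypothetical slot holding the waker of the half that is currently waiting.
struct Parked {
    waker: Option<Waker>,
}

impl Parked {
    /// Called by a half that must wait for the other side to make progress.
    fn park(&mut self, waker: &Waker) {
        self.waker = Some(waker.clone());
    }

    /// Called by the other half after it buffers or consumes a value:
    /// wakes the waiting task, if there is one.
    fn notify(&mut self) {
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }
}
```

Because `notify` uses `take()`, a stale waker is never woken twice, and the `None` state plays the role that `Waker::noop` would play as a default.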

}

enum BufferedValue<T> {
Contributor Author

I originally modeled this as two `Option`s (`match_value`, `non_match_value`), but I think using an enum here leads to simpler code.

Contributor

If you make both checks into a boolean, then you can do buffered_type == this_stream to check if the buffered value should be used or not.
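
A sketch of what that comparison might look like, assuming the enum is replaced by a value plus a boolean tag. `Buffered` and `take_for` are hypothetical names that do not appear in the PR.

```rust
/// Hypothetical replacement for `BufferedValue<T>`:
/// the item plus the side it belongs to.
struct Buffered<T> {
    value: T,
    is_match: bool,
}

/// Takes the buffered value if it belongs to `this_stream`;
/// otherwise puts it back and returns `None`.
fn take_for<T>(slot: &mut Option<Buffered<T>>, this_stream: bool) -> Option<T> {
    match slot.take() {
        Some(b) if b.is_match == this_stream => Some(b.value),
        other => {
            *slot = other;
            None
        }
    }
}
```

This collapses the two guarded arms at the top of `poll_next` into a single check, and the same tag can be set directly from the predicate result when a value is buffered.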

Match(T),
NonMatch(T),
}

impl<St, F> fmt::Debug for Inner<St, F>
where
St: fmt::Debug + Stream,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Inner")
.field("stream", &self.stream)
.field("waker", &self.waker)
.finish_non_exhaustive()
Contributor Author

I wasn't sure about what to do about the buffered value. I don't think I can write two implementations of Debug, one where the Item doesn't implement Debug, and one where it does. I'm not sure what scenarios you'd implement Debug on a stream impl of items that aren't Debug, but I'm not sure there's a way to express that in the type system neatly.

Contributor

You can just require St::Item: Debug.
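
A sketch of that bound on a trimmed-down type. The `Stream` trait here is a local stand-in reduced to its associated type, `Numbers` is a toy type used only for illustration, and the `f` and `waker` fields are omitted, so this is not the PR's `Inner`.

```rust
use std::fmt;

/// Stand-in for the `Stream` trait, reduced to its associated type.
trait Stream {
    type Item;
}

/// Trimmed-down `Inner`: only the fields relevant to `Debug`.
struct Inner<St: Stream> {
    stream: St,
    buffered_value: Option<St::Item>,
}

impl<St> fmt::Debug for Inner<St>
where
    St: fmt::Debug + Stream,
    // The extra bound that lets the buffered value be printed.
    St::Item: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Inner")
            .field("stream", &self.stream)
            .field("buffered_value", &self.buffered_value)
            .finish()
    }
}

/// Toy stream type used only to exercise the impl.
#[derive(Debug)]
struct Numbers;

impl Stream for Numbers {
    type Item = u32;
}
```

The trade-off is that `Partition` then implements `Debug` only for streams whose items are `Debug`; in the real type, `finish_non_exhaustive` would still be appropriate as long as the closure `f` stays out of the output.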

}
}

impl<St, F> Inner<St, F>
where
St: Stream + Unpin,
F: FnMut(&St::Item) -> bool,
{
pub(super) fn new(stream: St, f: F) -> Self {
Self {
stream,
f,
buffered_value: None,
waker: None,
}
}

fn poll_next(&mut self, cx: &mut Context<'_>, matches: bool) -> Poll<Option<St::Item>> {
Contributor Author

Usually I dislike boolean params, but I think this one makes sense as an exception. Not sure there's a good concise replacement idea for this.
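
For comparison, one possible replacement for the boolean parameter is a small two-variant enum. This is a sketch only; `Side` and `accepts` are hypothetical names, and whether the extra type is worth it over a plain `bool` is a judgment call.

```rust
/// Hypothetical enum naming the half that is polling.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Side {
    Matches,
    NonMatches,
}

impl Side {
    /// Returns `true` if an item with this predicate result belongs to this side.
    fn accepts(self, predicate_result: bool) -> bool {
        match self {
            Side::Matches => predicate_result,
            Side::NonMatches => !predicate_result,
        }
    }
}
```

Call sites would then read `poll_next(cx, Side::Matches)` rather than `poll_next(cx, true)`.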

// Check if there is a buffered value
match self.buffered_value.take() {
Some(BufferedValue::Match(value)) if matches => return Poll::Ready(Some(value)),
Some(BufferedValue::NonMatch(value)) if !matches => return Poll::Ready(Some(value)),
Some(value) => {
self.buffered_value = Some(value);
self.waker = Some(cx.waker().clone());
return Poll::Pending;
}
None => {}
}

// Poll the underlying stream
match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Ready(Some(value)) => match (self.f)(&value) {
result if result == matches => Poll::Ready(Some(value)),
true => {
self.buffered_value = Some(BufferedValue::Match(value));
self.waker = Some(cx.waker().clone());
Poll::Pending
}
false => {
self.buffered_value = Some(BufferedValue::NonMatch(value));
self.waker = Some(cx.waker().clone());
Poll::Pending
}
},
Poll::Ready(None) => Poll::Ready(None), // Stream is exhausted
Poll::Pending => {
self.waker = Some(cx.waker().clone());
cx.waker().wake_by_ref();
Poll::Pending
}
}
}
}
46 changes: 46 additions & 0 deletions tokio-stream/tests/stream_partition.rs
@@ -0,0 +1,46 @@
use tokio_stream::{self as stream, StreamExt};
use tokio_test::{assert_pending, assert_ready_eq, task};

mod support {
pub(crate) mod mpsc;
}

#[tokio::test]
async fn partition() {
let stream = stream::iter(0..6);
Contributor Author

This test only tests a situation where the items are ready. A test for when the values are not ready should be added.

let (matches, non_matches) = stream.partition(|v| v % 2 == 0);
let mut matches = task::spawn(matches);
let mut non_matches = task::spawn(non_matches);

// polling matches when the next item matches returns the item from the stream.
assert_ready_eq!(matches.poll_next(), Some(0));

// polling non_matches when the next item doesn't match returns the item from the stream.
assert_ready_eq!(non_matches.poll_next(), Some(1));

// polling non_matches when the next item matches buffers the item.
assert_pending!(non_matches.poll_next());

// polling matches when there is a buffered match returns the buffered item.
assert_ready_eq!(matches.poll_next(), Some(2));

// polling matches when the next item doesn't match buffers the item.
assert_pending!(matches.poll_next());

// polling non_matches when there is a buffered non-match returns the buffered item.
assert_ready_eq!(non_matches.poll_next(), Some(3));

// polling non_matches twice when the next item matches buffers the item only once.
assert_pending!(non_matches.poll_next());
assert_pending!(non_matches.poll_next());
assert_ready_eq!(matches.poll_next(), Some(4));

// polling matches twice when the next item doesn't match buffers the item only once.
assert_pending!(matches.poll_next());
assert_pending!(matches.poll_next());
assert_ready_eq!(non_matches.poll_next(), Some(5));

// polling matches and non_matches when the stream is exhausted returns None.
assert_ready_eq!(matches.poll_next(), None);
assert_ready_eq!(non_matches.poll_next(), None);
}