Skip to content

Commit

Permalink
stream: add StreamExt::partition() method
Browse files Browse the repository at this point in the history
This allows filtering items that match or don't match into separate
stream. It is analogous to std::iter::Iterator::partition()

```rust
let stream = stream::iter(0..4);
let (mut even, mut odd) = stream.partition(|v| v % 2 == 0);
assert_eq!(Some(0), even.next().await);
assert_eq!(Some(1), odd.next().await);
```
  • Loading branch information
joshka committed Jan 3, 2025
1 parent e066431 commit 006b3d7
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tokio-stream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ pub use stream_ext::{collect::FromStream, StreamExt};
/// Adapters for [`Stream`]s created by methods in [`StreamExt`].
pub mod adapters {
pub use crate::stream_ext::{
Chain, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip, SkipWhile, Take,
TakeWhile, Then,
Chain, FalsePartition, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip,
SkipWhile, Take, TakeWhile, Then, TruePartition,
};
cfg_time! {
pub use crate::stream_ext::{ChunksTimeout, Timeout, TimeoutRepeating};
Expand Down
33 changes: 33 additions & 0 deletions tokio-stream/src/stream_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ pub use merge::Merge;
mod next;
use next::Next;

mod partition;
use partition::Partition;
pub use partition::{FalsePartition, TruePartition};

mod skip;
pub use skip::Skip;

Expand Down Expand Up @@ -841,6 +845,35 @@ pub trait StreamExt: Stream {
FoldFuture::new(self, init, f)
}

/// Partitions the stream into two streams based on the provided predicate.
///
/// The first stream contains all elements for which `f` returns `true`, and
/// the second contains all elements for which `f` returns `false`.
///
/// # Examples
///
/// ```
/// # #[tokio::main]
/// # async fn main() {
/// use tokio_stream::{self as stream, StreamExt};
///
/// let s = stream::iter(vec![1, 2, 3, 4, 5]);
/// let (even, odd) = s.partition(|x| x % 2 == 0);
///
/// let even: Vec<_> = even.collect().await;
/// let odd: Vec<_> = odd.collect().await;
///
/// assert_eq!(even, vec![2, 4]);
/// assert_eq!(odd, vec![1, 3, 5]);
/// # }
fn partition<F>(self, f: F) -> (TruePartition<Self, F>, FalsePartition<Self, F>)
where
F: FnMut(&Self::Item) -> bool,
Self: Sized,
{
Partition::new(self, f).split()
}

/// Drain stream pushing all emitted values into a collection.
///
/// Equivalent to:
Expand Down
153 changes: 153 additions & 0 deletions tokio-stream/src/stream_ext/partition.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#![allow(dead_code)]

use crate::Stream;

use core::fmt;
use std::{
collections::VecDeque,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};

pub(super) struct Partition<St, F>
where
St: Stream,
{
stream: St,
f: F,
true_buffer: VecDeque<St::Item>,
false_buffer: VecDeque<St::Item>,
true_waker: Option<Waker>,
false_waker: Option<Waker>,
}

impl<St, F> Partition<St, F>
where
St: Stream,
{
pub(super) fn new(stream: St, f: F) -> Self {
Self {
stream,
f,
true_buffer: VecDeque::new(),
false_buffer: VecDeque::new(),
true_waker: None,
false_waker: None,
}
}

pub(super) fn split(self) -> (TruePartition<St, F>, FalsePartition<St, F>) {
let partition = Arc::new(Mutex::new(self));
let true_partition = TruePartition::new(partition.clone());
let false_partition = FalsePartition::new(partition.clone());
(true_partition, false_partition)
}
}

/// A stream that only yields elements that satisfy a predicate.
///
/// This stream is produced by the [`StreamExt::partition`] method.
pub struct TruePartition<St: Stream, F> {
partition: Arc<Mutex<Partition<St, F>>>,
}

impl<St, F> fmt::Debug for TruePartition<St, F>
where
St: Stream,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TruePartition").finish_non_exhaustive()
}
}

impl<St, F> TruePartition<St, F>
where
St: Stream,
{
fn new(partition: Arc<Mutex<Partition<St, F>>>) -> Self {
Self { partition }
}
}

impl<St, F> Stream for TruePartition<St, F>
where
St: Stream + Unpin,
F: FnMut(&St::Item) -> bool,
{
type Item = St::Item;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut partition = self.partition.lock().unwrap();
if let Some(item) = partition.true_buffer.pop_front() {
return Poll::Ready(Some(item));
}

match Pin::new(&mut partition.stream).poll_next(cx) {
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(item)) if (partition.f)(&item) => Poll::Ready(Some(item)),
Poll::Ready(Some(item)) => {
partition.false_buffer.push_back(item);
partition.false_waker = Some(cx.waker().clone());
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Pending => {
partition.true_waker = Some(cx.waker().clone());
Poll::Pending
}
}
}
}

/// A stream that only yields elements that do not satisfy a predicate.
///
/// This stream is produced by the [`StreamExt::partition`] method.
pub struct FalsePartition<St: Stream, F> {
partition: Arc<Mutex<Partition<St, F>>>,
}

impl<St, F> fmt::Debug for FalsePartition<St, F>
where
St: Stream,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FalsePartition").finish_non_exhaustive()
}
}

impl<St: Stream, F> FalsePartition<St, F> {
fn new(partition: Arc<Mutex<Partition<St, F>>>) -> Self {
Self { partition }
}
}

impl<St, F> Stream for FalsePartition<St, F>
where
St: Stream + Unpin,
F: FnMut(&St::Item) -> bool,
{
type Item = St::Item;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut partition = self.partition.lock().unwrap();
if let Some(item) = partition.false_buffer.pop_front() {
return Poll::Ready(Some(item));
}

match Pin::new(&mut partition.stream).poll_next(cx) {
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(item)) if !(partition.f)(&item) => Poll::Ready(Some(item)),
Poll::Ready(Some(item)) => {
partition.true_buffer.push_back(item);
partition.true_waker = Some(cx.waker().clone());
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Pending => {
partition.false_waker = Some(cx.waker().clone());
Poll::Pending
}
}
}
}
29 changes: 29 additions & 0 deletions tokio-stream/tests/stream_partition.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use tokio_stream::{self as stream, StreamExt};

mod support {
pub(crate) mod mpsc;
}

#[tokio::test]
async fn partition() {
let stream = stream::iter(0..4);
let (mut even, mut odd) = stream.partition(|v| v % 2 == 0);
assert_eq!(Some(0), even.next().await);
assert_eq!(Some(1), odd.next().await);
assert_eq!(Some(2), even.next().await);
assert_eq!(Some(3), odd.next().await);
assert_eq!(None, even.next().await);
assert_eq!(None, odd.next().await);
}

#[tokio::test]
async fn partition_buffers() {
let stream = stream::iter(0..4);
let (mut even, mut odd) = stream.partition(|v| v % 2 == 0);
assert_eq!(Some(1), odd.next().await);
assert_eq!(Some(3), odd.next().await);
assert_eq!(None, odd.next().await);
assert_eq!(Some(0), even.next().await);
assert_eq!(Some(2), even.next().await);
assert_eq!(None, odd.next().await);
}

0 comments on commit 006b3d7

Please sign in to comment.