Skip to content

Commit

Permalink
Change to single buffered value
Browse files Browse the repository at this point in the history
A larger buffer of values for each partition can be added on top of a
single buffered value by wrapping each stream with code that prefetches
the values. Thus this approach has a better composability than using an
internal buffer. The tradeoff is that this may deadlock and calling
code that needs to concurrently process items in both streams must be
programmed in a way that ensures that waiting to consume from one stream
never blocks the other stream.
  • Loading branch information
joshka committed Jan 4, 2025
1 parent 006b3d7 commit ffcf729
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 129 deletions.
4 changes: 2 additions & 2 deletions tokio-stream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ pub use stream_ext::{collect::FromStream, StreamExt};
/// Adapters for [`Stream`]s created by methods in [`StreamExt`].
pub mod adapters {
pub use crate::stream_ext::{
Chain, FalsePartition, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Peekable, Skip,
SkipWhile, Take, TakeWhile, Then, TruePartition,
Chain, Filter, FilterMap, Fuse, Map, MapWhile, Merge, Partition, Peekable, Skip, SkipWhile,
Take, TakeWhile, Then,
};
cfg_time! {
pub use crate::stream_ext::{ChunksTimeout, Timeout, TimeoutRepeating};
Expand Down
44 changes: 31 additions & 13 deletions tokio-stream/src/stream_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ mod next;
use next::Next;

mod partition;
use partition::Partition;
pub use partition::{FalsePartition, TruePartition};
pub use partition::Partition;

mod skip;
pub use skip::Skip;
Expand Down Expand Up @@ -847,8 +846,26 @@ pub trait StreamExt: Stream {

/// Partitions the stream into two streams based on the provided predicate.
///
/// The first stream contains all elements for which `f` returns `true`, and
/// the second contains all elements for which `f` returns `false`.
///
///
/// As values of this stream are made available, the provided predicate `f`
/// will be run against them. The first stream contains all elements for
/// which `f` returns `true`, and the second contains all elements for which
/// `f` returns `false`.
///
/// Note that this function consumes the stream passed into it and returns a
/// wrapped versions of it, similar to [`Iterator::partition`] method in the
/// standard library.
///
/// # Deadlocks
///
/// Polling the matching stream when the next value is not a match will not
/// succeed until the value is consumed by polling the non-matching stream.
/// Similarly, polling the non-matching stream when the next value is a
/// match will not succeed until the value is consumed by polling the
/// matching stream. This can lead to a deadlock if the streams are not
/// consumed in a way that allows the other stream to continue (e.g. by
/// using a task or tokio::select! to poll the streams concurrently).
///
/// # Examples
///
Expand All @@ -857,21 +874,22 @@ pub trait StreamExt: Stream {
/// # async fn main() {
/// use tokio_stream::{self as stream, StreamExt};
///
/// let s = stream::iter(vec![1, 2, 3, 4, 5]);
/// let s = stream::iter(0..4);
/// let (even, odd) = s.partition(|x| x % 2 == 0);
///
/// let even: Vec<_> = even.collect().await;
/// let odd: Vec<_> = odd.collect().await;
///
/// assert_eq!(even, vec![2, 4]);
/// assert_eq!(odd, vec![1, 3, 5]);
/// assert_eq!(even.next().await, Some(0));
/// assert_eq!(odd.next().await, Some(1));
/// assert_eq!(even.next().await, Some(2));
/// assert_eq!(odd.next().await, Some(3));
/// assert_eq!(even.next().await, None);
/// assert_eq!(odd.next().await, None);
/// # }
fn partition<F>(self, f: F) -> (TruePartition<Self, F>, FalsePartition<Self, F>)
fn partition<F>(self, f: F) -> (Partition<Self, F>, Partition<Self, F>)
where
F: FnMut(&Self::Item) -> bool,
Self: Sized,
Self: Sized + Unpin,
{
Partition::new(self, f).split()
Partition::new(self, f)
}

/// Drain stream pushing all emitted values into a collection.
Expand Down
180 changes: 85 additions & 95 deletions tokio-stream/src/stream_ext/partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,150 +2,140 @@

use crate::Stream;

use core::fmt;
use std::{
collections::VecDeque,
fmt,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};

pub(super) struct Partition<St, F>
/// A stream returned by the [`partition`](super::StreamExt::partition) method.
pub enum Partition<St, F>
where
St: Stream,
{
stream: St,
f: F,
true_buffer: VecDeque<St::Item>,
false_buffer: VecDeque<St::Item>,
true_waker: Option<Waker>,
false_waker: Option<Waker>,
/// A stream that yields items for which the predicate returns `true`.
Matches(Arc<Mutex<Inner<St, F>>>),
/// A stream that yields items for which the predicate returns `false`.
NonMatches(Arc<Mutex<Inner<St, F>>>),
}

impl<St, F> Partition<St, F>
impl<St, F> fmt::Debug for Partition<St, F>
where
St: Stream,
{
pub(super) fn new(stream: St, f: F) -> Self {
Self {
stream,
f,
true_buffer: VecDeque::new(),
false_buffer: VecDeque::new(),
true_waker: None,
false_waker: None,
}
}

pub(super) fn split(self) -> (TruePartition<St, F>, FalsePartition<St, F>) {
let partition = Arc::new(Mutex::new(self));
let true_partition = TruePartition::new(partition.clone());
let false_partition = FalsePartition::new(partition.clone());
(true_partition, false_partition)
}
}

/// A stream that only yields elements that satisfy a predicate.
///
/// This stream is produced by the [`StreamExt::partition`] method.
pub struct TruePartition<St: Stream, F> {
partition: Arc<Mutex<Partition<St, F>>>,
}

impl<St, F> fmt::Debug for TruePartition<St, F>
where
St: Stream,
St: fmt::Debug + Stream,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TruePartition").finish_non_exhaustive()
match self {
Partition::Matches(inner) => f.debug_tuple("Partition::Matches").field(inner).finish(),
Partition::NonMatches(inner) => {
f.debug_tuple("Partition::NonMatches").field(inner).finish()
}
}
}
}

impl<St, F> TruePartition<St, F>
impl<St, F> Partition<St, F>
where
St: Stream,
St: Stream + Unpin,
F: FnMut(&St::Item) -> bool,
{
fn new(partition: Arc<Mutex<Partition<St, F>>>) -> Self {
Self { partition }
pub(super) fn new(stream: St, f: F) -> (Self, Self) {
let inner = Arc::new(Mutex::new(Inner::new(stream, f)));
(
Partition::Matches(inner.clone()),
Partition::NonMatches(inner),
)
}
}

impl<St, F> Stream for TruePartition<St, F>
impl<St, F> Stream for Partition<St, F>
where
St: Stream + Unpin,
F: FnMut(&St::Item) -> bool,
{
type Item = St::Item;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut partition = self.partition.lock().unwrap();
if let Some(item) = partition.true_buffer.pop_front() {
return Poll::Ready(Some(item));
}

match Pin::new(&mut partition.stream).poll_next(cx) {
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(item)) if (partition.f)(&item) => Poll::Ready(Some(item)),
Poll::Ready(Some(item)) => {
partition.false_buffer.push_back(item);
partition.false_waker = Some(cx.waker().clone());
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Pending => {
partition.true_waker = Some(cx.waker().clone());
Poll::Pending
}
match self.get_mut() {
Partition::Matches(inner) => inner.lock().unwrap().poll_next(cx, true),
Partition::NonMatches(inner) => inner.lock().unwrap().poll_next(cx, false),
}
}
}

/// A stream that only yields elements that do not satisfy a predicate.
///
/// This stream is produced by the [`StreamExt::partition`] method.
pub struct FalsePartition<St: Stream, F> {
partition: Arc<Mutex<Partition<St, F>>>,
}

impl<St, F> fmt::Debug for FalsePartition<St, F>
pub struct Inner<St, F>
where
St: Stream,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FalsePartition").finish_non_exhaustive()
}
stream: St,
f: F,
buffered_value: Option<BufferedValue<St::Item>>,
waker: Option<Waker>,
}

impl<St: Stream, F> FalsePartition<St, F> {
fn new(partition: Arc<Mutex<Partition<St, F>>>) -> Self {
Self { partition }
enum BufferedValue<T> {
Match(T),
NonMatch(T),
}

impl<St, F> fmt::Debug for Inner<St, F>
where
St: fmt::Debug + Stream,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Inner")
.field("stream", &self.stream)
.field("waker", &self.waker)
.finish_non_exhaustive()
}
}

impl<St, F> Stream for FalsePartition<St, F>
impl<St, F> Inner<St, F>
where
St: Stream + Unpin,
F: FnMut(&St::Item) -> bool,
{
type Item = St::Item;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut partition = self.partition.lock().unwrap();
if let Some(item) = partition.false_buffer.pop_front() {
return Poll::Ready(Some(item));
pub(super) fn new(stream: St, f: F) -> Self {
Self {
stream,
f,
buffered_value: None,
waker: None,
}
}

match Pin::new(&mut partition.stream).poll_next(cx) {
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(item)) if !(partition.f)(&item) => Poll::Ready(Some(item)),
Poll::Ready(Some(item)) => {
partition.true_buffer.push_back(item);
partition.true_waker = Some(cx.waker().clone());
cx.waker().wake_by_ref();
Poll::Pending
fn poll_next(&mut self, cx: &mut Context<'_>, matches: bool) -> Poll<Option<St::Item>> {
// Check if there is a buffered value
match self.buffered_value.take() {
Some(BufferedValue::Match(value)) if matches => return Poll::Ready(Some(value)),
Some(BufferedValue::NonMatch(value)) if !matches => return Poll::Ready(Some(value)),
Some(value) => {
self.buffered_value = Some(value);
self.waker = Some(cx.waker().clone());
return Poll::Pending;
}
None => {}
}

// Poll the underlying stream
match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Ready(Some(value)) => match (self.f)(&value) {
result if result == matches => Poll::Ready(Some(value)),
true => {
self.buffered_value = Some(BufferedValue::Match(value));
self.waker = Some(cx.waker().clone());
Poll::Pending
}
false => {
self.buffered_value = Some(BufferedValue::NonMatch(value));
self.waker = Some(cx.waker().clone());
Poll::Pending
}
},
Poll::Ready(None) => Poll::Ready(None), // Stream is exhausted
Poll::Pending => {
partition.false_waker = Some(cx.waker().clone());
self.waker = Some(cx.waker().clone());
cx.waker().wake_by_ref();
Poll::Pending
}
}
Expand Down
55 changes: 36 additions & 19 deletions tokio-stream/tests/stream_partition.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,46 @@
use tokio_stream::{self as stream, StreamExt};
use tokio_test::{assert_pending, assert_ready_eq, task};

mod support {
pub(crate) mod mpsc;
}

#[tokio::test]
async fn partition() {
let stream = stream::iter(0..4);
let (mut even, mut odd) = stream.partition(|v| v % 2 == 0);
assert_eq!(Some(0), even.next().await);
assert_eq!(Some(1), odd.next().await);
assert_eq!(Some(2), even.next().await);
assert_eq!(Some(3), odd.next().await);
assert_eq!(None, even.next().await);
assert_eq!(None, odd.next().await);
}
let stream = stream::iter(0..6);
let (matches, non_matches) = stream.partition(|v| v % 2 == 0);
let mut matches = task::spawn(matches);
let mut non_matches = task::spawn(non_matches);

#[tokio::test]
async fn partition_buffers() {
let stream = stream::iter(0..4);
let (mut even, mut odd) = stream.partition(|v| v % 2 == 0);
assert_eq!(Some(1), odd.next().await);
assert_eq!(Some(3), odd.next().await);
assert_eq!(None, odd.next().await);
assert_eq!(Some(0), even.next().await);
assert_eq!(Some(2), even.next().await);
assert_eq!(None, odd.next().await);
// polling matches when the next item matches returns the item from the stream.
assert_ready_eq!(matches.poll_next(), Some(0));

// polling non_matches when the next item doesn't match returns the item from the stream.
assert_ready_eq!(non_matches.poll_next(), Some(1));

// polling non_matches when the next item matches buffers the item.
assert_pending!(non_matches.poll_next());

// polling matches when there is a bufferred match returns the buffered item.
assert_ready_eq!(matches.poll_next(), Some(2));

// polling matches when the next item doesn't match buffers the item.
assert_pending!(matches.poll_next());

// polling non_matches when there is a bufferred non-match returns the buffered item.
assert_ready_eq!(non_matches.poll_next(), Some(3));

// polling non_matches twice when the next item matches buffers the item only once.
assert_pending!(non_matches.poll_next());
assert_pending!(non_matches.poll_next());
assert_ready_eq!(matches.poll_next(), Some(4));

// polling matches twice when the next item doesn't match buffers the item only once.
assert_pending!(matches.poll_next());
assert_pending!(matches.poll_next());
assert_ready_eq!(non_matches.poll_next(), Some(5));

// polling matches and non_matches when the stream is exhausted returns None.
assert_ready_eq!(matches.poll_next(), None);
assert_ready_eq!(non_matches.poll_next(), None);
}

0 comments on commit ffcf729

Please sign in to comment.