Skip to content

Commit

Permalink
mergePreferred
Browse files Browse the repository at this point in the history
  • Loading branch information
shagoon committed Jun 2, 2023
1 parent 2efbede commit f52547a
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 0 deletions.
218 changes: 218 additions & 0 deletions core/shared/src/main/scala/fs2/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2015,6 +2015,224 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F,
): Stream[F2, O2] =
that.mergeHaltL(this)

/** Merges two streams with priority given to the first stream.
*
* Internally, this uses two bounded queues (of one element).
* Queue has tryTake() which allows to try a non-blocking read.
* This is used to check for elements on the prioritized queue,
* before a blocking read through racePair() is tried on both
* queues, if no data is available on the prioritized queue.
*/
def mergePreferred[F2[x] >: F[x], O2 >: O](
that: Stream[F2, O2]
)(implicit F: Concurrent[F2]): Stream[F2, O2] = {
val fstream: F2[Stream[F2, O2]] =
for {
interrupt <- F.deferred[Unit]
resultL <- F.deferred[Either[Throwable, Unit]]
resultR <- F.deferred[Either[Throwable, Unit]]
resultQL <- Queue.bounded[F2, Option[Stream[F2, O2]]](1)
resultQR <- Queue.bounded[F2, Option[Stream[F2, O2]]](1)
} yield {

def watchInterrupted(str: Stream[F2, O2]): Stream[F2, O2] =
str.interruptWhen(interrupt.get.attempt)

// action to signal that one stream is finished (by putting a None in it)
def doneAndClose(q: Queue[F2, Option[Stream[F2, O2]]]): F2[Unit] = q.offer(None).void

// action to interrupt the processing of both streams by completing interrupt
val signalInterruption: F2[Unit] = interrupt.complete(()).void

// Read from a stream and (possibly blocking) write to the bounded queue for that stream
def go(s: Stream[F2, O2], q: Queue[F2, Option[Stream[F2, O2]]]): Pull[F2, Nothing, Unit] =
s.pull.uncons
.flatMap {
case Some((hd, tl)) =>
val send = q.offer(Some(Stream.chunk(hd)))
Pull.eval(send) >> go(tl, q)
case None =>
Pull.done
}

def runStream(
s: Stream[F2, O2],
whenDone: Deferred[F2, Either[Throwable, Unit]],
q: Queue[F2, Option[Stream[F2, O2]]]
): F2[Unit] = {
val str = watchInterrupted(go(s, q).stream)
str.compile.drain.attempt
.flatMap {
// signal completion of our side before we will signal interruption,
// to make sure our result is always available to others
case r @ Left(_) =>
whenDone.complete(r) >> signalInterruption
case r @ Right(_) =>
whenDone.complete(r) >> doneAndClose(q)
}
}

// Typedef for the fibres that read from the queues.
// That's contained in the Either returned by racePair()
type FBR = Fiber[F2, Throwable, Option[Stream[F2, O2]]]

// An ADT for tracking state of the two queues.
// The types describe the state, starting with BothActive.
// Next state is either LeftDone or RightDone.
// Final state is BothDone.
// The members of those states store the loosing fibre
// of a racePair()-call, which will be reused during the
// next read.
sealed trait QueuesState
final case class BothActive(v: Option[Either[FBR, FBR]]) extends QueuesState
final case class LeftDone(rFbr: Option[FBR]) extends QueuesState
final case class RightDone(lFbr: Option[FBR]) extends QueuesState
final case object BothDone extends QueuesState

// Race the given effects, returning the result of the winner
// plus the still active fibre of the looser
def raceQueues(
lq: F2[Option[Stream[F2, O2]]],
rq: F2[Option[Stream[F2, O2]]]
): F2[(Option[Stream[F2, O2]], Either[FBR, FBR])] =
F.racePair(lq, rq)
.flatMap {
case Left((result, fiber)) =>
result.embedError.map(_ -> fiber.asRight[FBR])
case Right((fiber, result)) =>
result.embedError.map(_ -> fiber.asLeft[FBR])
}

// stream that is generated from pumping out the elements of the queue.
val pumpFromQueue: Stream[F2, O2] =
Stream
.unfoldEval[F2, QueuesState, Stream[F2, O2]](BothActive(None)) { s =>
// Returning None from unfoldEval will stop the stream. If we read a None
// from any queue, we cannot return that but must continue reading on the
// other queue. Thus, we need a method which can be called recursively to
// continue reading in case of None.
def readNext(s: QueuesState): F2[(Option[Stream[F2, O2]], QueuesState)] =
s match {
// The initial state, both queues are active and there are no fibres left over
case BothActive(None) =>
// check available data on left, which would be prioritized
resultQL.tryTake
.flatMap {
_.fold(
// no data available on prioritized queue, race both queues
raceQueues(resultQL.take, resultQR.take)
.flatMap[(Option[Stream[F2, O2]], QueuesState)] {
case (None, Left(fbr)) =>
readNext(RightDone(fbr.some))
case (None, Right(fbr)) =>
readNext(LeftDone(fbr.some))
case (Some(s), fbr) =>
F.pure(s.some -> BothActive(fbr.some))
}
)(os =>
// we read data from the prioritized queue, however, this sill could be a None,
// signalling that queue is done. Handle that:
os.fold(readNext(LeftDone(None)))(ls =>
F.pure(ls.some -> BothActive(None))
)
)
}

// right was looser during the last run
case BothActive(Some(Right(fbr))) =>
// anyway, check for available data on left first, ignoring the incoming fibre for right
resultQL.tryTake
.flatMap(
_.fold(
// use the incoming fibre to read from right queue
raceQueues(resultQL.take, fbr.joinWithNever)
.flatMap[(Option[Stream[F2, O2]], QueuesState)] {
case (None, Left(fbr)) =>
readNext(RightDone(fbr.some))
case (None, Right(fbr)) =>
readNext(LeftDone(fbr.some))
case (Some(s), fbr) =>
F.pure(s.some -> BothActive(fbr.some))
}
)(os =>
// important to reuse the incoming fibre here!
os.fold(readNext(LeftDone(fbr.some)))(ls =>
F.pure(ls.some -> BothActive(fbr.asRight[FBR].some))
)
)
)

// left was looser during the last run
case BothActive(Some(Left(fbr))) =>
// Can't check for available data on left this time,
// because there's an active fibre reading from the left queue.
// Start a race and reuse that fibre for left.
raceQueues(fbr.joinWithNever, resultQR.take)
.flatMap[(Option[Stream[F2, O2]], QueuesState)] {
case (None, Left(fbr)) =>
readNext(RightDone(fbr.some))
case (None, Right(fbr)) =>
readNext(LeftDone(fbr.some))
case (Some(s), fbr) =>
F.pure(s.some -> BothActive(fbr.some))
}

// Left queue is done, but, it's possible we retrieve an active fibre for right.
case LeftDone(fbr) =>
fbr
.map(_.joinWithNever) // join the incoming fibre if given
.getOrElse(resultQR.take) // ordinary take() if no fibre has been given
.map {
case None =>
None -> BothDone
case os =>
os -> LeftDone(None)
}

// mirror case of above
case RightDone(fbr) =>
fbr
.map(_.joinWithNever)
.getOrElse(resultQL.take)
.map {
case None =>
None -> BothDone
case os =>
os -> RightDone(None)
}

// this should never happen, but we need to make the compiler happy
case BothDone =>
F.pure(None -> BothDone)
}

// readNext() returns None in _1 if and only if both queues are done
readNext(s).map {
case (None, _) =>
None // finish the stream (unfoldEval)
case (Some(s), st) =>
(s -> st).some // emit element and new state (unfoldEval)
}
}
.flatten // we have Stream[F2, Stream[F2, O2]] and flatten that to Stream[F2, O2]

val atRunEnd: F2[Unit] =
for {
_ <- signalInterruption // interrupt so the upstreams have chance to complete
left <- resultL.get
right <- resultR.get
r <- F.fromEither(CompositeFailure.fromResults(left, right))
} yield r

val runStreams =
runStream(this, resultL, resultQL).start >> runStream(that, resultR, resultQR).start

Stream.bracket(runStreams)(_ => atRunEnd) >> watchInterrupted(pumpFromQueue)
}
Stream.eval(fstream).flatten

}

/** Given two sorted streams emits a single sorted stream, like in merge-sort.
* For entries that are considered equal by the Order, left stream element is emitted first.
* Note: both this and another streams MUST BE ORDERED already
Expand Down
6 changes: 6 additions & 0 deletions core/shared/src/test/scala/fs2/StreamSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,12 @@ class StreamSuite extends Fs2Suite {
}
}

test("mergePreferred") {
testCancelation {
constantStream.mergePreferred(constantStream)
}
}

test("parJoin") {
testCancelation {
Stream(constantStream, constantStream).parJoin(2)
Expand Down

0 comments on commit f52547a

Please sign in to comment.