diff --git a/core/shared/src/main/scala/fs2/Stream.scala b/core/shared/src/main/scala/fs2/Stream.scala index 9fd98ef5c4..8e9758141d 100644 --- a/core/shared/src/main/scala/fs2/Stream.scala +++ b/core/shared/src/main/scala/fs2/Stream.scala @@ -2015,6 +2015,224 @@ final class Stream[+F[_], +O] private[fs2] (private[fs2] val underlying: Pull[F, ): Stream[F2, O2] = that.mergeHaltL(this) + /** Merges two streams with priority given to the first stream. + * + * Internally, this uses two bounded queues (of one element). + * Queue has tryTake() which allows to try a non-blocking read. + * This is used to check for elements on the prioritized queue, + * before a blocking read through racePair() is tried on both + * queues, if no data is available on the prioritized queue. + */ + def mergePreferred[F2[x] >: F[x], O2 >: O]( + that: Stream[F2, O2] + )(implicit F: Concurrent[F2]): Stream[F2, O2] = { + val fstream: F2[Stream[F2, O2]] = + for { + interrupt <- F.deferred[Unit] + resultL <- F.deferred[Either[Throwable, Unit]] + resultR <- F.deferred[Either[Throwable, Unit]] + resultQL <- Queue.bounded[F2, Option[Stream[F2, O2]]](1) + resultQR <- Queue.bounded[F2, Option[Stream[F2, O2]]](1) + } yield { + + def watchInterrupted(str: Stream[F2, O2]): Stream[F2, O2] = + str.interruptWhen(interrupt.get.attempt) + + // action to signal that one stream is finished (by putting a None in it) + def doneAndClose(q: Queue[F2, Option[Stream[F2, O2]]]): F2[Unit] = q.offer(None).void + + // action to interrupt the processing of both streams by completing interrupt + val signalInterruption: F2[Unit] = interrupt.complete(()).void + + // Read from a stream and (possibly blocking) write to the bounded queue for that stream + def go(s: Stream[F2, O2], q: Queue[F2, Option[Stream[F2, O2]]]): Pull[F2, Nothing, Unit] = + s.pull.uncons + .flatMap { + case Some((hd, tl)) => + val send = q.offer(Some(Stream.chunk(hd))) + Pull.eval(send) >> go(tl, q) + case None => + Pull.done + } + + def runStream( + s: Stream[F2, O2], + whenDone: Deferred[F2, Either[Throwable, Unit]], + q: Queue[F2, Option[Stream[F2, O2]]] + ): F2[Unit] = { + val str = watchInterrupted(go(s, q).stream) + str.compile.drain.attempt + .flatMap { + // signal completion of our side before we will signal interruption, + // to make sure our result is always available to others + case r @ Left(_) => + whenDone.complete(r) >> signalInterruption + case r @ Right(_) => + whenDone.complete(r) >> doneAndClose(q) + } + } + + // Typedef for the fibres that read from the queues. + // That's contained in the Either returned by racePair() + type FBR = Fiber[F2, Throwable, Option[Stream[F2, O2]]] + + // An ADT for tracking state of the two queues. + // The types describe the state, starting with BothActive. + // Next state is either LeftDone or RightDone. + // Final state is BothDone. + // The members of those states store the loosing fibre + // of a racePair()-call, which will be reused during the + // next read. + sealed trait QueuesState + final case class BothActive(v: Option[Either[FBR, FBR]]) extends QueuesState + final case class LeftDone(rFbr: Option[FBR]) extends QueuesState + final case class RightDone(lFbr: Option[FBR]) extends QueuesState + case object BothDone extends QueuesState + + // Race the given effects, returning the result of the winner + // plus the still active fibre of the looser + def raceQueues( + lq: F2[Option[Stream[F2, O2]]], + rq: F2[Option[Stream[F2, O2]]] + ): F2[(Option[Stream[F2, O2]], Either[FBR, FBR])] = + F.racePair(lq, rq) + .flatMap { + case Left((result, fiber)) => + result.embedError.map(_ -> fiber.asRight[FBR]) + case Right((fiber, result)) => + result.embedError.map(_ -> fiber.asLeft[FBR]) + } + + // stream that is generated from pumping out the elements of the queue. + val pumpFromQueue: Stream[F2, O2] = + Stream + .unfoldEval[F2, QueuesState, Stream[F2, O2]](BothActive(None)) { s => + // Returning None from unfoldEval will stop the stream. If we read a None + // from any queue, we cannot return that but must continue reading on the + // other queue. Thus, we need a method which can be called recursively to + // continue reading in case of None. + def readNext(s: QueuesState): F2[(Option[Stream[F2, O2]], QueuesState)] = + s match { + // The initial state, both queues are active and there are no fibres left over + case BothActive(None) => + // check available data on left, which would be prioritized + resultQL.tryTake + .flatMap { + _.fold( + // no data available on prioritized queue, race both queues + raceQueues(resultQL.take, resultQR.take) + .flatMap[(Option[Stream[F2, O2]], QueuesState)] { + case (None, Left(fbr)) => + readNext(RightDone(fbr.some)) + case (None, Right(fbr)) => + readNext(LeftDone(fbr.some)) + case (Some(s), fbr) => + F.pure(s.some -> BothActive(fbr.some)) + } + )(os => + // we read data from the prioritized queue, however, this sill could be a None, + // signalling that queue is done. Handle that: + os.fold(readNext(LeftDone(None)))(ls => + F.pure(ls.some -> BothActive(None)) + ) + ) + } + + // right was looser during the last run + case BothActive(Some(Right(fbr))) => + // anyway, check for available data on left first, ignoring the incoming fibre for right + resultQL.tryTake + .flatMap( + _.fold( + // use the incoming fibre to read from right queue + raceQueues(resultQL.take, fbr.joinWithNever) + .flatMap[(Option[Stream[F2, O2]], QueuesState)] { + case (None, Left(fbr)) => + readNext(RightDone(fbr.some)) + case (None, Right(fbr)) => + readNext(LeftDone(fbr.some)) + case (Some(s), fbr) => + F.pure(s.some -> BothActive(fbr.some)) + } + )(os => + // important to reuse the incoming fibre here! + os.fold(readNext(LeftDone(fbr.some)))(ls => + F.pure(ls.some -> BothActive(fbr.asRight[FBR].some)) + ) + ) + ) + + // left was looser during the last run + case BothActive(Some(Left(fbr))) => + // Can't check for available data on left this time, + // because there's an active fibre reading from the left queue. + // Start a race and reuse that fibre for left. + raceQueues(fbr.joinWithNever, resultQR.take) + .flatMap[(Option[Stream[F2, O2]], QueuesState)] { + case (None, Left(fbr)) => + readNext(RightDone(fbr.some)) + case (None, Right(fbr)) => + readNext(LeftDone(fbr.some)) + case (Some(s), fbr) => + F.pure(s.some -> BothActive(fbr.some)) + } + + // Left queue is done, but, it's possible we retrieve an active fibre for right. + case LeftDone(fbr) => + fbr + .map(_.joinWithNever) // join the incoming fibre if given + .getOrElse(resultQR.take) // ordinary take() if no fibre has been given + .map { + case None => + None -> BothDone + case os => + os -> LeftDone(None) + } + + // mirror case of above + case RightDone(fbr) => + fbr + .map(_.joinWithNever) + .getOrElse(resultQL.take) + .map { + case None => + None -> BothDone + case os => + os -> RightDone(None) + } + + // this should never happen, but we need to make the compiler happy + case BothDone => + F.pure(None -> BothDone) + } + + // readNext() returns None in _1 if and only if both queues are done + readNext(s).map { + case (None, _) => + None // finish the stream (unfoldEval) + case (Some(s), st) => + (s -> st).some // emit element and new state (unfoldEval) + } + } + .flatten // we have Stream[F2, Stream[F2, O2]] and flatten that to Stream[F2, O2] + + val atRunEnd: F2[Unit] = + for { + _ <- signalInterruption // interrupt so the upstreams have chance to complete + left <- resultL.get + right <- resultR.get + r <- F.fromEither(CompositeFailure.fromResults(left, right)) + } yield r + + val runStreams = + runStream(this, resultL, resultQL).start >> runStream(that, resultR, resultQR).start + + Stream.bracket(runStreams)(_ => atRunEnd) >> watchInterrupted(pumpFromQueue) + } + Stream.eval(fstream).flatten + + } + /** Given two sorted streams emits a single sorted stream, like in merge-sort. * For entries that are considered equal by the Order, left stream element is emitted first. * Note: both this and another streams MUST BE ORDERED already diff --git a/core/shared/src/test/scala/fs2/StreamSuite.scala b/core/shared/src/test/scala/fs2/StreamSuite.scala index d1fb1ae82d..0e8a8131b3 100644 --- a/core/shared/src/test/scala/fs2/StreamSuite.scala +++ b/core/shared/src/test/scala/fs2/StreamSuite.scala @@ -401,6 +401,12 @@ class StreamSuite extends Fs2Suite { } } + test("mergePreferred") { + testCancelation { + constantStream.mergePreferred(constantStream) + } + } + test("parJoin") { testCancelation { Stream(constantStream, constantStream).parJoin(2)