diff --git a/core/shared/src/main/scala/cats/effect/IO.scala b/core/shared/src/main/scala/cats/effect/IO.scala index 41fb79db63..8df5621453 100644 --- a/core/shared/src/main/scala/cats/effect/IO.scala +++ b/core/shared/src/main/scala/cats/effect/IO.scala @@ -1599,6 +1599,13 @@ object IO extends IOCompanionPlatform with IOLowPriorityImplicits with TuplePara def parTraverseN_[T[_]: Foldable, A, B](n: Int)(ta: T[A])(f: A => IO[B]): IO[Unit] = _asyncForIO.parTraverseN_(n)(ta)(f) + /** + * Like `Parallel.parFlatTraverse`, but limits the degree of parallelism. + */ + def parFlatTraverseN[T[_]: Traverse: cats.FlatMap, A, B](n: Int)(ta: T[A])( + f: A => IO[T[B]]): IO[T[B]] = + _asyncForIO.parFlatTraverseN(n)(ta)(f) + /** * Like `Parallel.parSequence` */ @@ -1623,6 +1630,12 @@ object IO extends IOCompanionPlatform with IOLowPriorityImplicits with TuplePara def parSequenceN_[T[_]: Foldable, A](n: Int)(tma: T[IO[A]]): IO[Unit] = _asyncForIO.parSequenceN_(n)(tma) + /** + * Like `Parallel.parFlatSequence`, but limits the degree of parallelism. + */ + def parFlatSequenceN[T[_]: Traverse: cats.FlatMap, A](n: Int)(tmta: T[IO[T[A]]]): IO[T[A]] = + _asyncForIO.parFlatSequenceN(n)(tmta) + /** * Like `Parallel.parReplicateA`, but limits the degree of parallelism. */ diff --git a/kernel/shared/src/main/scala/cats/effect/kernel/GenConcurrent.scala b/kernel/shared/src/main/scala/cats/effect/kernel/GenConcurrent.scala index 024f640197..bddd3cbaf9 100644 --- a/kernel/shared/src/main/scala/cats/effect/kernel/GenConcurrent.scala +++ b/kernel/shared/src/main/scala/cats/effect/kernel/GenConcurrent.scala @@ -16,7 +16,7 @@ package cats.effect.kernel -import cats.{Foldable, Monoid, Semigroup, Traverse} +import cats.{FlatMap, Foldable, Monoid, Semigroup, Traverse} import cats.data.{EitherT, IorT, Kleisli, OptionT, WriterT} import cats.effect.kernel.instances.spawn._ import cats.effect.kernel.syntax.all._ @@ -155,6 +155,26 @@ trait GenConcurrent[F[_], E] extends GenSpawn[F, E] { MiniSemaphore[F](n).flatMap { sem => ta.parTraverse_ { a => sem.withPermit(f(a)) } } } + /** + * Like `Parallel.parFlatSequence`, but limits the degree of parallelism. + */ + def parFlatSequenceN[T[_]: Traverse: FlatMap, A](n: Int)(tma: T[F[T[A]]]): F[T[A]] = + parFlatTraverseN(n)(tma)(identity) + + /** + * Like `Parallel.parFlatTraverse`, but limits the degree of parallelism. Note that the + * semantics of this operation aim to maximise fairness: when a spot to execute becomes + * available, every task has a chance to claim it, and not only the next `n` tasks in `ta` + */ + def parFlatTraverseN[T[_]: Traverse: FlatMap, A, B](n: Int)(ta: T[A])( + f: A => F[T[B]]): F[T[B]] = { + require(n >= 1, s"Concurrency limit should be at least 1, was: $n") + + implicit val F: GenConcurrent[F, E] = this + + MiniSemaphore[F](n).flatMap { sem => ta.parFlatTraverse { a => sem.withPermit(f(a)) } } + } + override def racePair[A, B](fa: F[A], fb: F[B]) : F[Either[(Outcome[F, E, A], Fiber[F, E, B]), (Fiber[F, E, A], Outcome[F, E, B])]] = { implicit val F: GenConcurrent[F, E] = this diff --git a/kernel/shared/src/main/scala/cats/effect/kernel/syntax/GenConcurrentSyntax.scala b/kernel/shared/src/main/scala/cats/effect/kernel/syntax/GenConcurrentSyntax.scala index 35b7604aac..164d83822b 100644 --- a/kernel/shared/src/main/scala/cats/effect/kernel/syntax/GenConcurrentSyntax.scala +++ b/kernel/shared/src/main/scala/cats/effect/kernel/syntax/GenConcurrentSyntax.scala @@ -16,7 +16,7 @@ package cats.effect.kernel.syntax -import cats.{Foldable, Traverse} +import cats.{FlatMap, Foldable, Traverse} import cats.effect.kernel.GenConcurrent trait GenConcurrentSyntax { @@ -34,6 +34,11 @@ trait GenConcurrentSyntax { ): ConcurrentParSequenceNOps[T, F, A] = new ConcurrentParSequenceNOps(wrapped) + implicit def concurrentParFlatSequenceOps[T[_], F[_], A]( + wrapped: T[F[T[A]]] + ): ConcurrentParFlatSequenceNOps[T, F, A] = + new ConcurrentParFlatSequenceNOps(wrapped) + } final class GenConcurrentOps_[F[_], A] private[syntax] (private val wrapped: F[A]) @@ -57,6 +62,11 @@ final class ConcurrentParTraverseNOps[T[_], A] private[syntax] ( f: A => F[B] )(implicit T: Foldable[T], F: GenConcurrent[F, ?]): F[Unit] = F.parTraverseN_(n)(wrapped)(f) + + def parFlatTraverseN[F[_], B](n: Int)( + f: A => F[T[B]] + )(implicit T: Traverse[T], FM: FlatMap[T], F: GenConcurrent[F, ?]): F[T[B]] = + F.parFlatTraverseN(n)(wrapped)(f) } final class ConcurrentParSequenceNOps[T[_], F[_], A] private[syntax] ( @@ -68,3 +78,11 @@ final class ConcurrentParSequenceNOps[T[_], F[_], A] private[syntax] ( def parSequenceN_(n: Int)(implicit T: Foldable[T], F: GenConcurrent[F, ?]): F[Unit] = F.parSequenceN_(n)(wrapped) } + +final class ConcurrentParFlatSequenceNOps[T[_], F[_], A] private[syntax] ( + private val wrapped: T[F[T[A]]] +) extends AnyVal { + def parFlatSequenceN( + n: Int)(implicit T: Traverse[T], FM: FlatMap[T], F: GenConcurrent[F, ?]): F[T[A]] = + F.parFlatSequenceN(n)(wrapped) +} diff --git a/kernel/shared/src/test/scala/cats/effect/kernel/SyntaxSuite.scala b/kernel/shared/src/test/scala/cats/effect/kernel/SyntaxSuite.scala index 6f920fa5ea..38b5ec8db6 100644 --- a/kernel/shared/src/test/scala/cats/effect/kernel/SyntaxSuite.scala +++ b/kernel/shared/src/test/scala/cats/effect/kernel/SyntaxSuite.scala @@ -57,6 +57,16 @@ class SyntaxSuite { result: F[Unit] } + { + val result = List(target).parFlatTraverseN(3)(t => t.map(List(_))) + result: F[List[A]] + } + + { + val result = List(target.map(List(_))).parFlatSequenceN(3) + result: F[List[A]] + } + { val result = target.parReplicateAN(3)(5) result: F[List[A]] diff --git a/tests/shared/src/test/scala/cats/effect/IOSuite.scala b/tests/shared/src/test/scala/cats/effect/IOSuite.scala index 73dedf6129..abea7b43a6 100644 --- a/tests/shared/src/test/scala/cats/effect/IOSuite.scala +++ b/tests/shared/src/test/scala/cats/effect/IOSuite.scala @@ -1702,6 +1702,53 @@ class IOSuite extends BaseScalaCheckSuite with DisciplineSuite with IOPlatformSu assertCompleteAs(p, true) } + real("parFlatTraverseN - throw when n < 1") { + IO.defer { + List.empty[Int].parFlatTraverseN(0)(List(_).pure[IO]) + }.mustFailWith[IllegalArgumentException] + } + + real("parFlatTraverseN - propagate errors") { + List(1, 2, 3) + .parFlatTraverseN(2) { (n: Int) => + if (n == 2) IO.raiseError(new RuntimeException) else List(n).pure[IO] + } + .mustFailWith[RuntimeException] + } + + ticked("parFlatTraverseN - be cancelable") { implicit ticker => + val p = for { + f <- List(1, 2, 3).parFlatTraverseN(2)(_ => IO.never[List[Int]]).start + _ <- IO.sleep(100.millis) + _ <- f.cancel + } yield true + + assertCompleteAs(p, true) + } + + real("parFlatSequenceN - throw when n < 1") { + IO.defer { + List.empty[IO[List[Int]]].parFlatSequenceN(0) + }.mustFailWith[IllegalArgumentException] + } + + real("parFlatSequenceN - propagate errors") { + List(1, 2, 3) + .map { (n: Int) => if (n == 2) IO.raiseError(new RuntimeException) else List(n).pure[IO] } + .parFlatSequenceN(2) + .mustFailWith[RuntimeException] + } + + ticked("parFlatSequenceN - be cancelable") { implicit ticker => + val p = for { + f <- List(1, 2, 3).map(_ => IO.never[List[IO[Int]]]).parFlatSequenceN(2).start + _ <- IO.sleep(100.millis) + _ <- f.cancel + } yield true + + assertCompleteAs(p, true) + } + real("parallel - run parallel actually in parallel") { val x = IO.sleep(2.seconds) >> IO.pure(1) val y = IO.sleep(2.seconds) >> IO.pure(2)