Skip to content

Commit

Permalink
Merge pull request #4498 from TimWSpence/optimize-traverse
Browse files Browse the repository at this point in the history
Optimize traverse
  • Loading branch information
johnynek authored May 1, 2024
2 parents fa61d34 + 6cb787d commit ffb1df6
Show file tree
Hide file tree
Showing 18 changed files with 371 additions and 114 deletions.
87 changes: 87 additions & 0 deletions bench/src/main/scala/cats/bench/TraverseBench.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,19 @@ package cats.bench
import cats.{Eval, Traverse, TraverseFilter}
import org.openjdk.jmh.annotations.{Benchmark, Param, Scope, Setup, State}
import org.openjdk.jmh.infra.Blackhole
import cats.data.Chain

@State(Scope.Benchmark)
class TraverseBench {
// Type class instances resolved once as fields and reused by the benchmark methods below.
val listT: Traverse[List] = Traverse[List]
val listTFilter: TraverseFilter[List] = TraverseFilter[List]
val chainTFilter: TraverseFilter[Chain] = TraverseFilter[Chain]

val vectorT: Traverse[Vector] = Traverse[Vector]
val vectorTFilter: TraverseFilter[Vector] = TraverseFilter[Vector]

val chainT: Traverse[Chain] = Traverse[Chain]

// The unit of CPU work per iteration: the token count handed to `Blackhole.consumeCPU`
// inside each element's effect.
private[this] val Work: Long = 10

Expand All @@ -43,11 +47,13 @@ class TraverseBench {

// Input collections for the benchmarks; default-initialised here and assigned in `setup()`.
var list: List[Int] = _
var vector: Vector[Int] = _
var chain: Chain[Int] = _

/** Populates the `List`, `Vector` and `Chain` fixtures with the integers `0` until `length`. */
@Setup
def setup(): Unit = {
  // Build the range once; each fixture is an independent copy of the same elements.
  val indices = 0.until(length)
  list = indices.toList
  vector = indices.toVector
  chain = Chain.fromSeq(indices)
}

@Benchmark
Expand Down Expand Up @@ -83,6 +89,18 @@ class TraverseBench {
}
}

/**
 * Benchmarks `traverse_` over the `List` fixture, where every element yields a lazy `Eval`
 * that burns `Work` CPU tokens before producing a value.
 */
@Benchmark
def traverse_List(bh: Blackhole) = {
  val effect = listT.traverse_(list) { elem =>
    Eval.later {
      Blackhole.consumeCPU(Work)
      elem * 2
    }
  }

  // Forcing `.value` runs the deferred effects; the Blackhole consumes the result.
  bh.consume(effect.value)
}

@Benchmark
def traverseFilterList(bh: Blackhole) = {
val result = listTFilter.traverseFilter(list) { i =>
Expand Down Expand Up @@ -137,6 +155,18 @@ class TraverseBench {
bh.consume(result.value)
}

/**
 * Benchmarks `traverse_` over the `Vector` fixture, where every element yields a lazy `Eval`
 * that burns `Work` CPU tokens before producing a value.
 */
@Benchmark
def traverse_Vector(bh: Blackhole) = {
  val effect = vectorT.traverse_(vector) { elem =>
    Eval.later {
      Blackhole.consumeCPU(Work)
      elem * 2
    }
  }

  // Forcing `.value` runs the deferred effects; the Blackhole consumes the result.
  bh.consume(effect.value)
}

@Benchmark
def traverseVectorError(bh: Blackhole) = {
val result = vectorT.traverse(vector) { i =>
Expand Down Expand Up @@ -199,4 +229,61 @@ class TraverseBench {

bh.consume(results)
}

/**
 * Benchmarks `traverse` over the `Chain` fixture, doubling each element inside a lazy `Eval`
 * that first burns `Work` CPU tokens.
 */
@Benchmark
def traverseChain(bh: Blackhole) = {
  val doubled = chainT.traverse(chain) { elem =>
    Eval.later {
      Blackhole.consumeCPU(Work)
      elem * 2
    }
  }

  // Forcing `.value` runs the deferred effects; the Blackhole consumes the result.
  bh.consume(doubled.value)
}

/**
 * Benchmarks `traverse_` over the `Chain` fixture, where every element yields a lazy `Eval`
 * that burns `Work` CPU tokens before producing a value.
 */
@Benchmark
def traverse_Chain(bh: Blackhole) = {
  val effect = chainT.traverse_(chain) { elem =>
    Eval.later {
      Blackhole.consumeCPU(Work)
      elem * 2
    }
  }

  // Forcing `.value` runs the deferred effects; the Blackhole consumes the result.
  bh.consume(effect.value)
}

/**
 * Benchmarks `traverse` over the `Chain` fixture when one element's effect throws.
 *
 * The effect for the element equal to `length * 0.3` throws `Failure`, which is caught and
 * discarded after evaluation. The check compares an `Int` element against a `Double` product,
 * so it can only match when that product is a whole number.
 */
@Benchmark
def traverseChainError(bh: Blackhole) = {
  val failing = chainT.traverse(chain) { elem =>
    Eval.later {
      Blackhole.consumeCPU(Work)

      if (elem == length * 0.3) throw Failure

      elem * 2
    }
  }

  // The throw only surfaces once the deferred computation is forced.
  try bh.consume(failing.value)
  catch { case Failure => () }
}

/**
 * Benchmarks `traverseFilter` over the `Chain` fixture: even elements are kept (doubled),
 * odd elements are dropped, each decision made inside a lazy `Eval` that burns `Work` tokens.
 */
@Benchmark
def traverseFilterChain(bh: Blackhole) = {
  val kept = chainTFilter.traverseFilter(chain) { elem =>
    Eval.later {
      Blackhole.consumeCPU(Work)
      if (elem % 2 != 0) None else Some(elem * 2)
    }
  }

  // Forcing `.value` runs the deferred effects; the Blackhole consumes the result.
  bh.consume(kept.value)
}
}
33 changes: 29 additions & 4 deletions core/src/main/scala-2.13+/cats/instances/arraySeq.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,24 @@ private[cats] object ArraySeqInstances {
B.combineAll(fa.iterator.map(f))

/**
 * Applies `f` to every element of `fa` and collects the results, inside `G`, into a new
 * `ArraySeq`.
 *
 * Dispatches on the runtime `Applicative`:
 *  - for a `StackSafeMonad`, the results are accumulated by `Traverse.traverseDirectly`
 *    over the iterator;
 *  - otherwise it falls back to `Chain.traverseViaChain`.
 *
 * Both branches produce a `Chain[B]` that is then copied into an untagged `ArraySeq`.
 */
def traverse[G[_], A, B](fa: ArraySeq[A])(f: A => G[B])(implicit G: Applicative[G]): G[ArraySeq[B]] =
  G match {
    case x: StackSafeMonad[G] =>
      x.map(Traverse.traverseDirectly(fa.iterator)(f)(x))(_.iterator.to(ArraySeq.untagged))
    case _ =>
      G.map(Chain.traverseViaChain(fa)(f))(_.iterator.to(ArraySeq.untagged))
  }

/**
 * Runs `f` on every element of `fa` for its effect only, discarding the produced values.
 *
 * A `StackSafeMonad` is handled by `Traverse.traverse_Directly`; any other `Applicative`
 * is folded from the right with `map2Eval`, starting from `G.unit`.
 */
override def traverse_[G[_], A, B](fa: ArraySeq[A])(f: A => G[B])(implicit G: Applicative[G]): G[Unit] =
  G match {
    case stackSafe: StackSafeMonad[G] =>
      Traverse.traverse_Directly(fa)(f)(stackSafe)
    case _ =>
      val done: Eval[G[Unit]] = Eval.now(G.unit)
      foldRight(fa, done)((elem, rest) => G.map2Eval(f(elem), rest)((_, _) => ())).value
  }

// Delegates to the shared `StaticMethods.mapAccumulateFromStrictFunctor` helper,
// passing this instance as its final argument.
override def mapAccumulate[S, A, B](init: S, fa: ArraySeq[A])(f: (S, A) => (S, B)): (S, ArraySeq[B]) =
StaticMethods.mapAccumulateFromStrictFunctor(init, fa, f)(this)
Expand Down Expand Up @@ -214,9 +231,17 @@ private[cats] object ArraySeqInstances {
/**
 * Applies `f` to every element of `fa` and keeps only the `Some` results, collected inside
 * `G` into a new `ArraySeq`.
 *
 * Dispatches on the runtime `Applicative`:
 *  - for a `StackSafeMonad`, the results are accumulated by
 *    `TraverseFilter.traverseFilterDirectly` and copied into an untagged `ArraySeq`;
 *  - otherwise the sequence is folded from the right with `map2Eval`, prepending each
 *    retained value to the accumulated `ArraySeq`.
 */
def traverseFilter[G[_], A, B](
  fa: ArraySeq[A]
)(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[ArraySeq[B]] =
  G match {
    case x: StackSafeMonad[G] =>
      x.map(TraverseFilter.traverseFilterDirectly(fa.iterator)(f)(x))(
        _.iterator.to(ArraySeq.untagged)
      )
    case _ =>
      // Lambda parameters are named so they do not shadow the `x` matched above.
      fa.foldRight(Eval.now(G.pure(ArraySeq.untagged.empty[B]))) { case (a, acc) =>
        G.map2Eval(f(a), acc)((i, o) => i.fold(o)(_ +: o))
      }.value
  }

override def filterA[G[_], A](fa: ArraySeq[A])(f: (A) => G[Boolean])(implicit G: Applicative[G]): G[ArraySeq[A]] =
fa.foldRight(Eval.now(G.pure(ArraySeq.untagged.empty[A]))) { case (x, xse) =>
Expand Down
24 changes: 24 additions & 0 deletions core/src/main/scala/cats/Traverse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@

package cats

import cats.data.Chain
import cats.data.State
import cats.data.StateT
import cats.kernel.compat.scalaVersionSpecific._

/**
* Traverse, also known as Traversable.
Expand Down Expand Up @@ -284,4 +286,26 @@ object Traverse {
// Deprecated holder for the `ToTraverseOps` syntax; the annotation directs users to the
// `cats.syntax` object imports instead.
@deprecated("Use cats.syntax object imports", "2.2.0")
object nonInheritedOps extends ToTraverseOps

/**
 * Traverses `fa` from left to right, appending each result of `f` to a `Chain` inside `G`.
 *
 * Starts from `G.pure(Chain.empty)` and combines the running accumulator with `f(a)` via
 * `G.map2` for every element produced by the iterator.
 */
private[cats] def traverseDirectly[G[_], A, B](
  fa: IterableOnce[A]
)(f: A => G[B])(implicit G: StackSafeMonad[G]): G[Chain[B]] = {
  val empty: G[Chain[B]] = G.pure(Chain.empty[B])
  fa.iterator.foldLeft(empty) { (soFar, elem) =>
    G.map2(soFar, f(elem))((acc, b) => acc :+ b)
  }
}

/**
 * Runs `f` on every element of `fa` from left to right, keeping only the effects.
 *
 * An empty input yields `G.unit`. Otherwise the effect of the first element seeds a left
 * fold that chains each following effect with `G.productR`, and the final value is
 * discarded with `G.void`.
 */
private[cats] def traverse_Directly[G[_], A, B](
  fa: IterableOnce[A]
)(f: A => G[B])(implicit G: StackSafeMonad[G]): G[Unit] = {
  val remaining = fa.iterator
  if (!remaining.hasNext) G.unit
  else {
    val head = f(remaining.next())
    val sequenced = remaining.foldLeft(head) { (effects, elem) =>
      G.productR(effects)(f(elem))
    }
    G.void(sequenced)
  }
}

}
14 changes: 13 additions & 1 deletion core/src/main/scala/cats/TraverseFilter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

package cats

import cats.data.State
import cats.data.{Chain, State}
import cats.kernel.compat.scalaVersionSpecific._

import scala.collection.immutable.{IntMap, TreeSet}

Expand Down Expand Up @@ -203,4 +204,15 @@ object TraverseFilter {
// Deprecated holder for the `ToTraverseFilterOps` syntax; the annotation directs users to
// the `cats.syntax` object imports instead.
@deprecated("Use cats.syntax object imports", "2.2.0")
object nonInheritedOps extends ToTraverseFilterOps

/**
 * Traverses `fa` from left to right, appending every `Some` result of `f` to a `Chain`
 * inside `G` and skipping every `None`.
 *
 * Starts from `G.pure(Chain.empty)` and combines the running accumulator with `f(a)` via
 * `G.map2` for every element produced by the iterator.
 */
private[cats] def traverseFilterDirectly[G[_], A, B](
  fa: IterableOnce[A]
)(f: A => G[Option[B]])(implicit G: StackSafeMonad[G]): G[Chain[B]] = {
  val empty: G[Chain[B]] = G.pure(Chain.empty[B])
  fa.iterator.foldLeft(empty) { (soFar, elem) =>
    G.map2(soFar, f(elem)) { (acc, maybeB) =>
      maybeB.fold(acc)(b => acc :+ b)
    }
  }
}

}
43 changes: 32 additions & 11 deletions core/src/main/scala/cats/data/Chain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1243,11 +1243,27 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {
/**
 * Applies `f` to every element of `fa` and collects the results, inside `G`, into a new
 * `Chain`.
 *
 * An empty chain short-circuits to `G.pure(Chain.nil)`. Otherwise it dispatches on the
 * runtime `Applicative`:
 *  - for a `StackSafeMonad`, the results are accumulated by `Traverse.traverseDirectly`
 *    over the iterator;
 *  - otherwise the elements are first copied into an `ArrayBuffer`, wrapped as an indexed
 *    sequence, and handed to `traverseViaChain`.
 */
def traverse[G[_], A, B](fa: Chain[A])(f: A => G[B])(implicit G: Applicative[G]): G[Chain[B]] =
  if (fa.isEmpty) G.pure(Chain.nil)
  else
    G match {
      case x: StackSafeMonad[G] =>
        Traverse.traverseDirectly(fa.iterator)(f)(x)
      case _ =>
        traverseViaChain {
          val as = collection.mutable.ArrayBuffer[A]()
          as ++= fa.iterator
          KernelStaticMethods.wrapMutableIndexedSeq(as)
        }(f)
    }

/**
 * Runs `f` on every element of `fa` for its effect only, discarding the produced values.
 *
 * A `StackSafeMonad` is handled by `Traverse.traverse_Directly` over the iterator; any
 * other `Applicative` is folded from the right with `map2Eval`, starting from `G.unit`.
 */
override def traverse_[G[_], A, B](fa: Chain[A])(f: A => G[B])(implicit G: Applicative[G]): G[Unit] =
  G match {
    case stackSafe: StackSafeMonad[G] =>
      Traverse.traverse_Directly(fa.iterator)(f)(stackSafe)
    case _ =>
      val done: Eval[G[Unit]] = Eval.now(G.unit)
      foldRight(fa, done)((elem, rest) => G.map2Eval(f(elem), rest)((_, _) => ())).value
  }

// Delegates to the shared `StaticMethods.mapAccumulateFromStrictFunctor` helper,
// passing this instance as its final argument.
override def mapAccumulate[S, A, B](init: S, fa: Chain[A])(f: (S, A) => (S, B)): (S, Chain[B]) =
StaticMethods.mapAccumulateFromStrictFunctor(init, fa, f)(this)
Expand Down Expand Up @@ -1341,7 +1357,7 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {
}

implicit val catsDataTraverseFilterForChain: TraverseFilter[Chain] = new TraverseFilter[Chain] {
// The underlying `Chain` instance, exposed with its `Alternative` capability as well as
// `Traverse`. Defined once: a second `traverse` with the narrower type would be a
// duplicate member.
def traverse: Traverse[Chain] with Alternative[Chain] = Chain.catsDataInstancesForChain

// Delegates directly to `Chain`'s own `filter` method.
override def filter[A](fa: Chain[A])(f: A => Boolean): Chain[A] = fa.filter(f)

Expand All @@ -1356,11 +1372,16 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {
/**
 * Applies `f` to every element of `fa` and keeps only the `Some` results, collected inside
 * `G` into a new `Chain`.
 *
 * An empty chain short-circuits to `G.pure(Chain.nil)`. Otherwise it dispatches on the
 * runtime `Applicative`:
 *  - for a `StackSafeMonad`, the results are accumulated by
 *    `TraverseFilter.traverseFilterDirectly` over the iterator;
 *  - otherwise the elements are first copied into an `ArrayBuffer`, wrapped as an indexed
 *    sequence, and handed to `traverseFilterViaChain`.
 */
def traverseFilter[G[_], A, B](fa: Chain[A])(f: A => G[Option[B]])(implicit G: Applicative[G]): G[Chain[B]] =
  if (fa.isEmpty) G.pure(Chain.nil)
  else
    G match {
      case x: StackSafeMonad[G] =>
        TraverseFilter.traverseFilterDirectly(fa.iterator)(f)(x)
      case _ =>
        traverseFilterViaChain {
          val as = collection.mutable.ArrayBuffer[A]()
          as ++= fa.iterator
          KernelStaticMethods.wrapMutableIndexedSeq(as)
        }(f)
    }

override def filterA[G[_], A](fa: Chain[A])(f: A => G[Boolean])(implicit G: Applicative[G]): G[Chain[A]] =
traverse
Expand Down
Loading

0 comments on commit ffb1df6

Please sign in to comment.