Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Predicate stack-safe using Eval #283

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions bench/src/main/scala/PredicateBenchmarks.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package cats.collections
package bench

import org.openjdk.jmh.annotations.{Benchmark, Param, Scope, Setup, State}
import cats._
import cats.implicits._

@State(Scope.Benchmark)
class ChainedPredicateBench {
@Param(Array("10", "100", "1000", "10000"))
var n: Int = _

var pred: Predicate[Int] = _

@Setup
def setup: Unit = {
pred = Predicate(_ == 0)
pred = Iterator.iterate(pred.negate)(_ - pred).drop(n).next()
}

@Benchmark
def catsCollectionsPredicateUnravel: Unit = {
pred.contains(0)
}
}
113 changes: 103 additions & 10 deletions core/src/main/scala/cats/collections/Predicate.scala
Original file line number Diff line number Diff line change
@@ -1,17 +1,40 @@
package cats.collections

import algebra.lattice.Bool
import cats._
import cats.data.Kleisli

/**
* An intensional set, which is a set which instead of enumerating its
* elements as a extensional set does, it is defined by a predicate
* which is a test for membership.
*
* All combinators in this class are implemented in a stack-safe way.
*/
abstract class Predicate[-A] extends scala.Function1[A, Boolean] { self =>
sealed abstract class Predicate[-A] extends scala.Function1[A, Boolean] { self =>
def run: Kleisli[Eval, A, Boolean]

/**
* Returns true if the value satisfies the predicate.
*/
def apply(a: A): Boolean = run(a).value

/**
* returns a predicate which is the union of this predicate and another
*/
def union[B <: A](other: Predicate[B]): Predicate[B] = Predicate(a => apply(a) || other(a))
def union[B <: A](other: Predicate[B]): Predicate[B] =
self match {
case Predicate.Empty => other
case Predicate.Everything => self
case _ =>
other match {
case Predicate.Empty => self
case Predicate.Everything => other
case _ => Predicate.Lift {
self.run.flatMap(if (_) Predicate.True else other.run)
}
}
}

/**
* returns a predicate which is the union of this predicate and another
Expand All @@ -21,7 +44,19 @@ abstract class Predicate[-A] extends scala.Function1[A, Boolean] { self =>
/**
* returns a predicate which is the intersection of this predicate and another
*/
def intersection[B <: A](other: Predicate[B]): Predicate[B] = Predicate(a => apply(a) && other(a))
def intersection[B <: A](other: Predicate[B]): Predicate[B] =
self match {
case Predicate.Empty => self
case Predicate.Everything => other
case _ =>
other match {
case Predicate.Empty => other
case Predicate.Everything => self
case _ => Predicate.Lift {
self.run.flatMap(if (_) other.run else Predicate.False)
}
}
}

/**
* returns a predicate which is the intersection of this predicate and another
Expand All @@ -36,7 +71,7 @@ abstract class Predicate[-A] extends scala.Function1[A, Boolean] { self =>
/**
* Returns the predicate which is the the difference of another predicate removed from this predicate
*/
def diff[B <: A](remove: Predicate[B]): Predicate[B] = Predicate(a => apply(a) && !remove(a))
def diff[B <: A](remove: Predicate[B]): Predicate[B] = self intersection remove.negate

/**
* Returns the predicate which is the the difference of another predicate removed from this predicate
Expand All @@ -46,28 +81,77 @@ abstract class Predicate[-A] extends scala.Function1[A, Boolean] { self =>
/**
* Return the opposite predicate
*/
def negate: Predicate[A] = Predicate(a => !apply(a))
def negate: Predicate[A]

/**
* Return the opposite predicate
*/
def unary_!(): Predicate[A] = negate

/**
* Compose the predicate with a function.
*
* A value is a member of the resulting predicate iff its image through f is a
* member of this predicate.
*/
def contramap[B](f: B => A): Predicate[B]

/**
* Alias for contramap.
*/
final override def compose[B](f: B => A): Predicate[B] = contramap(f)
}

object Predicate extends PredicateInstances {
def apply[A](f: A => Boolean): Predicate[A] = new Predicate[A] {
def apply(a: A) = f(a)
private val True = Kleisli.liftF(Eval.True)
private val False = Kleisli.liftF(Eval.False)

private[collections] case object Empty extends Predicate[Any] {
override def run: Kleisli[Eval, Any, Boolean] = Predicate.False
override def negate: Predicate[Any] = Everything
override def contramap[B](f: B => Any): Predicate[B] = this
}

def empty: Predicate[Any] = apply(_ => false)
private[collections] case object Everything extends Predicate[Any] {
override def run: Kleisli[Eval, Any, Boolean] = Predicate.True
override def negate: Predicate[Any] = Empty
override def contramap[B](f: B => Any): Predicate[B] = this
}

private[collections] final case class Lift[A](run: Kleisli[Eval, A, Boolean]) extends Predicate[A] {
override def negate: Predicate[A] = Negate(this)
override def contramap[B](f: B => A): Predicate[B] = Lift(run.compose(b => Eval.now(f(b))))
}

private[collections] final case class Negate[A](negate: Predicate[A]) extends Predicate[A] {
override def run: Kleisli[Eval, A, Boolean] = negate.run.map(!_)
override def contramap[B](f: B => A): Predicate[B] = Negate(negate contramap f)
}

/**
* build a set from a membership function.
*/
def apply[A](p: A => Boolean): Predicate[A] = Lift {
Kleisli(a => Eval.now(p(a)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a => if (p(a)) Eval.True else Eval.False may be marginally more efficient.

}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about a def fromKleisli[A](k: Kleisli[Eval, A, Boolean]): Predicate[A] = Lift(k)

/**
* The empty set: a predicate that rejects all values.
*/
def empty[A]: Predicate[A] = Empty

/**
* The set of everything: a predicate that accepts all values.
*/
def everything[A]: Predicate[A] = Everything
}

trait PredicateInstances {
implicit def predicateContravariantMonoidal: ContravariantMonoidal[Predicate] = new ContravariantMonoidal[Predicate] {
override def contramap[A, B](fb: Predicate[A])(f: B => A): Predicate[B] =
Predicate(f andThen fb.apply)
fb.contramap(f)
override def product[A, B](fa: Predicate[A], fb: Predicate[B]): Predicate[(A, B)] =
Predicate(v => fa(v._1) || fb(v._2))
fa.contramap[(A,B)](_._1) union fb.contramap(_._2)
override def unit: Predicate[Unit] = Predicate.empty
}

Expand All @@ -76,6 +160,15 @@ trait PredicateInstances {
override def combine(l: Predicate[A], r: Predicate[A]): Predicate[A] = l union r
}

implicit def predicateBool[A]: Bool[Predicate[A]] = new Bool[Predicate[A]] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we usually give longer names in cats to avoid the name aliasing problem with implicits.

e.g. implicit def catsCollectionsPredicateBool: Bool[Predicate[A]] = ...

override def one: Predicate[A] = Predicate.everything
override def zero: Predicate[A] = Predicate.empty
override def complement(x: Predicate[A]): Predicate[A] = x.negate
override def and(l: Predicate[A], r: Predicate[A]): Predicate[A] = l intersection r
override def or(l: Predicate[A], r: Predicate[A]): Predicate[A] = l union r

}

implicit val predicateMonoidK: MonoidK[Predicate] = new MonoidK[Predicate] {
override def empty[A]: Predicate[A] = Predicate.empty
override def combineK[A](l: Predicate[A], r: Predicate[A]): Predicate[A] = l union r
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/cats/collections/Set.scala
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,8 @@ object AvlSet extends AvlSetInstances {
private[collections] case object BTNil extends AvlSet[Nothing] {
override def isEmpty: Boolean = true

override def predicate(implicit order: cats.Order[Nothing]): Predicate[Nothing] = Predicate.empty

def apply[A](): AvlSet[A] = this.asInstanceOf[AvlSet[A]]

def unapply[A](a: AvlSet[A]): Boolean = a.isEmpty
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package cats.collections
package arbitrary

import org.scalacheck.{Gen, Arbitrary}
import cats.Order
import org.scalacheck.{Gen, Cogen, Arbitrary}

trait ArbitraryPredicate {
import set._
def predicateGen[A: Arbitrary: Order]: Gen[Predicate[A]] =
setGen.map(_.predicate)
import Gen._
def predicateGen[A: Arbitrary: Cogen]: Gen[Predicate[A]] =
oneOf(const(Predicate.empty), const(Predicate.everything), resultOf(Predicate(_: A => Boolean)))

implicit def predicateArbitrary[A: Arbitrary: Order]: Arbitrary[Predicate[A]] =
implicit def predicateArbitrary[A: Arbitrary: Cogen]: Arbitrary[Predicate[A]] =
Arbitrary(predicateGen[A])
}

39 changes: 39 additions & 0 deletions tests/src/test/scala/cats/collections/PredicateSpec.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package cats.collections

package tests

import algebra.laws.LogicLaws
import cats.collections.arbitrary.predicate._
import cats.laws.discipline.{ContravariantMonoidalTests, SerializableTests}
import cats._
Expand All @@ -22,6 +24,16 @@ class PredicateSpec extends CatsSuite {
checkAll("ContravariantMonoidal[Predicate]", ContravariantMonoidalTests[Predicate].contravariantMonoidal[Int, Int, Int])
}

{
implicit val eqForPredicateInt: Eq[Predicate[Int]] = new Eq[Predicate[Int]] {
val sample = -1 to 1 // need at least 2 elements to distinguish in-between values
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are just two enough? Why not (-100 to 100) or something?

Or, better yet, why not test with Predicate[Byte] and enumerate all 256 possibilities to check equality.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the domain is just one element, then every predicate is effectively equivalent to either Empty or Everything depending on whether it accepts the one element.

I don't think we need a 201 or 256 cardinality domain because I tried manually introducing some defects in Predicate and it caught them immediately with the 3 elements.

That being said, it doesn't negatively affect the performance of the testsuite on my machine at all so I'd be ok with either.

override def eqv(x: Predicate[Int], y: Predicate[Int]): Boolean =
sample.forall(a => x(a) == y(a))
}

checkAll("Bool[Predicate[Int]]", LogicLaws[Predicate[Int]].bool)
}

test("intersection works")(
forAll { (as: List[Int], bs: List[Int]) =>

Expand Down Expand Up @@ -72,4 +84,31 @@ class PredicateSpec extends CatsSuite {
bs.forall(b => (s1(b) != (as.contains(b) && (b % 2 != 0)))) should be(true)

})

{
def testStackSafety(name: String, deepSet: => Predicate[Int]) =
test(name) {
noException should be thrownBy {
deepSet.contains(0)
}
}
val Depth = 200000
val NonZero = Predicate[Int](_ != 0)
testStackSafety("union is stack safe on the left hand side",
Iterator.fill(Depth)(NonZero).reduceLeft(_ union _))
testStackSafety("union is stack safe on the right hand side",
Iterator.fill(Depth)(NonZero).reduceRight(_ union _))
testStackSafety("intersection is stack safe on the left hand side",
Iterator.fill(Depth)(!NonZero).reduceLeft(_ intersection _))
testStackSafety("intersection is stack safe on the right hand side",
Iterator.fill(Depth)(!NonZero).reduceRight(_ intersection _))
testStackSafety("negation is stack safe",
Iterator.iterate(NonZero)(_.negate).drop(Depth).next())
testStackSafety("contramap() is stack safe",
Iterator.iterate(NonZero)(_.contramap(identity _)).drop(Depth).next())
testStackSafety("diff is stack safe on the left hand side",
Iterator.fill(Depth)(!NonZero).reduceLeft(_ diff _))
testStackSafety("diff is stack safe on the right hand side",
Iterator.fill(Depth)(!NonZero).reduceRight(_ diff _))
}
}