Skip to content

Commit

Permalink
Revert "Stratified aggregation (#37)"
Browse files Browse the repository at this point in the history
This reverts commit d655b18.
  • Loading branch information
aherlihy committed Dec 14, 2023
1 parent d655b18 commit 108ea62
Show file tree
Hide file tree
Showing 44 changed files with 68 additions and 829 deletions.
32 changes: 0 additions & 32 deletions src/main/scala/datalog/dsl/DSL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,35 +58,3 @@ case class Relation[T <: Constant](id: Int, name: String)(using ee: ExecutionEng
def solve(): Set[Seq[Term]] = ee.solve(id).map(s => s.toSeq).toSet
def get(): Set[Seq[Term]] = ee.get(id)
}


/**
 * Aggregation operators that can be paired with a result variable in `groupBy`.
 * Every case carries `t`, the term the operator is applied to (a variable or a constant).
 */
enum AggOp(val t: Term):
// Each case re-declares `t` as an overriding val and forwards it to the enum constructor.
case SUM(override val t: Term) extends AggOp(t)
case COUNT(override val t: Term) extends AggOp(t)
case MIN(override val t: Term) extends AggOp(t)
case MAX(override val t: Term) extends AggOp(t)

/**
 * Atom standing for an aggregation over the grouping predicate `gp`.
 *
 * Its terms are the grouping variables `gv` followed by the result variable of each
 * aggregation in `ags`; it is always constructed as non-negated.
 *
 * The relation id is deliberately that of the grouping predicate: the 'virtual' relation is
 * computed from it, and other logic relies on it (dep in JoinIndexes, node id in
 * DependencyGraph, etc.).
 */
case class GroupingAtom(gp: Atom, gv: Seq[Variable], ags: Seq[(AggOp, Variable)])
  extends Atom(gp.rId, gv ++ ags.map(_._2), false):
  override val hash: String =
    // Both sequences are concatenated without any separator.
    val groupingPart = gv.mkString
    val aggregationPart = ags.mkString
    s"GB${gp.hash}-$groupingPart-$aggregationPart"

object groupBy:
  /**
   * Validates the pieces of a grouping atom and builds it.
   *
   * @param gp  the grouping predicate; must not be negated
   * @param gv  the grouping variables; no repetitions, no anonymous variable, and each must
   *            occur in `gp`
   * @param ags pairs of (aggregation operator, result variable); result variables must be
   *            distinct, must not be anonymous and must not occur in `gp`; any variable being
   *            aggregated must occur in `gp` and must not be anonymous
   * @return the corresponding [[GroupingAtom]]
   * @throws Exception if any of the constraints above is violated
   */
  def apply(gp: Atom, gv: Seq[Variable], ags: (AggOp, Variable)*): GroupingAtom =
    if (gp.negated)
      throw new Exception("The grouping predicate cannot be negated")
    if (gv.size != gv.distinct.size)
      throw new Exception("The grouping variables cannot be repeated")
    // Result variables of the aggregations, computed once and reused below.
    val resultVars = ags.map(_._2)
    if (resultVars.size != resultVars.distinct.size)
      throw new Exception("The aggregation variables cannot be repeated")
    val gpVars = gp.terms.collect { case v: Variable => v }.toSet
    val gVars = gv.toSet
    val aggVars = resultVars.toSet
    // Variables that are the operand of some aggregation operator (constants are ignored).
    val aggdVars = ags.map(_._1.t).collect { case v: Variable => v }.toSet
    if (gVars.contains(__) || aggVars.contains(__) || aggdVars.contains(__))
      throw new Exception("Anonymous variable ('__') not allowed as a grouping variable, aggregation variable or aggregated variable")
    if (aggVars.intersect(gpVars).nonEmpty)
      throw new Exception("No aggregation variable may occur in the grouping predicate")
    if (!(aggdVars ++ gVars).subsetOf(gpVars))
      throw new Exception("The aggregated variables and the grouping variables must occur in the grouping predicate")
    GroupingAtom(gp, gv, ags)
6 changes: 0 additions & 6 deletions src/main/scala/datalog/execution/BytecodeCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,6 @@ class BytecodeCompiler(val storageManager: StorageManager)(using JITOptions) ext
traverse(xb, children(1))
emitSMCall(xb, "diff", classOf[EDB], classOf[EDB])

case GroupingOp(child, gji) =>
xb.aload(0)
traverse(xb, child)
emitGroupingJoinIndexes(xb, gji)
emitSMCall(xb, "groupingHelper", classOf[EDB], classOf[GroupingJoinIndexes])

case DebugPeek(prefix, msg, children: _*) =>
assert(false, s"Unimplemented node: $irTree")

Expand Down
49 changes: 6 additions & 43 deletions src/main/scala/datalog/execution/BytecodeGenerator.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package datalog.execution

import datalog.dsl.{Atom, Constant, Term, Variable}
import datalog.execution.AggOpIndex
import datalog.storage.{RelationId, StorageAggOp}
import datalog.storage.RelationId

import java.lang.constant.ConstantDescs.*
import java.lang.constant.*
Expand Down Expand Up @@ -193,11 +192,11 @@ object BytecodeGenerator {
case e: Int => emitInteger(xb, e)
case s: String => emitString(xb, s)

def emitStringConstantTuple2(xb: CodeBuilder, t: (String, Constant)): Unit =
emitNew(xb, classOf[(String, Constant)], xxb =>
emitString(xxb, t._1)
emitConstant(xxb, t._2)
)
def emitStringConstantTuple2(xb: CodeBuilder, t: (String, Constant)): Unit =
emitNew(xb, classOf[(String, Constant)], xxb =>
emitString(xxb, t._1)
emitConstant(xxb, t._2)
)

def emitVariable(xb: CodeBuilder, variable: Variable): Unit =
emitNew(xb, classOf[Variable], xxb =>
Expand Down Expand Up @@ -279,42 +278,6 @@ object BytecodeGenerator {
emitCxns(xxb, value.cxns)
emitBool(xxb, value.edb))

/**
 * Emits code that yields the given [[StorageAggOp]] value by calling `fromOrdinal`
 * on the enum's companion with the value's ordinal.
 */
def emitStorageAggOp(xb: CodeBuilder, sao: StorageAggOp): Unit =
  val companion = classOf[StorageAggOp.type]
  // Receiver first, then the ordinal argument, then the call itself.
  emitObject(xb, companion)
  xb.constantInstruction(sao.ordinal)
  emitCall(xb, companion, "fromOrdinal", classOf[Int])

/**
 * Emits code that constructs a copy of the given [[AggOpIndex]].
 *
 * Each case is rebuilt with its own class: `GV` and `LV` from their integer index,
 * `C` from its constant. (Previously the `LV` case was emitted as a `GV`.)
 */
def emitAggOpIndex(xb: CodeBuilder, aoi: AggOpIndex): Unit = aoi match {
  case AggOpIndex.GV(i) =>
    emitNew(xb, classOf[AggOpIndex.GV], xxb => emitInteger(xxb, i))
  case AggOpIndex.LV(i) =>
    // Must instantiate LV, not GV, so the reconstructed value keeps its case.
    emitNew(xb, classOf[AggOpIndex.LV], xxb => emitInteger(xxb, i))
  case AggOpIndex.C(c) =>
    emitNew(xb, classOf[AggOpIndex.C], xxb => emitConstant(xxb, c))
}

/**
 * Emits code that builds a sequence of `(StorageAggOp, AggOpIndex)` pairs.
 *
 * One element emitter is produced per pair; each one writes through the builder it is
 * handed by `emitSeq` rather than the captured outer builder.
 */
def emitAggOpInfos(xb: CodeBuilder, value: Seq[(StorageAggOp, AggOpIndex)]): Unit =
  emitSeq(xb, value.map(v => xxb =>
    emitNew(xxb, classOf[(StorageAggOp, AggOpIndex)], xxxb =>
      emitStorageAggOp(xxxb, v._1)
      emitAggOpIndex(xxxb, v._2)
    )
  ))

/**
 * Emits code that constructs a [[GroupingJoinIndexes]] equal to `value`, emitting its
 * four fields in declaration order.
 */
def emitGroupingJoinIndexes(xb: CodeBuilder, value: GroupingJoinIndexes): Unit =
  emitNew(xb, classOf[GroupingJoinIndexes], { args =>
    emitVarIndexes(args, value.varIndexes)
    emitConstIndexes(args, value.constIndexes)
    emitSeqInt(args, value.groupingIndexes)
    emitAggOpInfos(args, value.aggOpInfos)
  })

val CD_BoxedUnit = clsDesc(classOf[scala.runtime.BoxedUnit])

/** Emit `BoxedUnit.UNIT`. */
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/datalog/execution/ExecutionEngine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trait ExecutionEngine {
* @param rule - Includes the head at idx 0
*/
inline def getOperatorKey(rule: Seq[Atom]): JoinIndexes =
JoinIndexes(rule, None, None)
JoinIndexes(rule, None)

def getOperatorKeys(rId: RelationId): mutable.ArrayBuffer[JoinIndexes] = {
prebuiltOpKeys.getOrElseUpdate(rId, mutable.ArrayBuffer[JoinIndexes]())
Expand Down
69 changes: 11 additions & 58 deletions src/main/scala/datalog/execution/JoinIndexes.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package datalog.execution

import datalog.dsl.{Atom, Constant, Variable, GroupingAtom, AggOp}
import datalog.dsl.{Atom, Constant, Variable}
import datalog.execution.ir.{IROp, ProjectJoinFilterOp, ScanOp}
import datalog.storage.{DB, EDB, NS, RelationId, StorageManager, StorageAggOp}
import datalog.storage.{DB, EDB, NS, RelationId, StorageManager}
import datalog.tools.Debug.debug

import scala.collection.mutable
Expand All @@ -12,19 +12,7 @@ import scala.reflect.ClassTag
type AllIndexes = mutable.Map[String, JoinIndexes]

enum PredicateType:
case POSITIVE, NEGATED, GROUPING


/**
 * Describes the operand of an aggregation operator once variables have been resolved
 * to positions.
 */
enum AggOpIndex:
// Operand is a variable that is not a grouping variable; `i` is its position in the grouping predicate's terms.
case LV(i: Int)
// Operand is one of the grouping variables; `i` is that variable's position in the grouping predicate's terms.
case GV(i: Int)
// Operand is the constant `c`.
case C(c: Constant)

/**
 * Precomputed index information for one grouping atom.
 */
// varIndexes: for each non-anonymous variable that occurs more than once in the grouping predicate, the positions where it occurs.
case class GroupingJoinIndexes(varIndexes: Seq[Seq[Int]],
// constIndexes: position in the grouping predicate => constant found at that position.
constIndexes: mutable.Map[Int, Constant],
// groupingIndexes: for each grouping variable, a position of that variable in the grouping predicate.
groupingIndexes: Seq[Int],
// aggOpInfos: one (storage-level operator, operand descriptor) pair per aggregation.
aggOpInfos: Seq[(StorageAggOp, AggOpIndex)]
)
case POSITIVE, NEGATED

/**
* Wrapper object for join keys for IDB rules
Expand All @@ -43,8 +31,7 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]],
deps: Seq[(PredicateType, RelationId)],
atoms: Seq[Atom],
cxns: mutable.Map[String, mutable.Map[Int, Seq[String]]],
edb: Boolean = false,
groupingIndexes: Map[String, GroupingJoinIndexes] = Map.empty
edb: Boolean = false
) {
override def toString(): String = ""//toStringWithNS(null)

Expand All @@ -70,18 +57,13 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]],
}

object JoinIndexes {
def apply(rule: Seq[Atom], precalculatedCxns: Option[mutable.Map[String, mutable.Map[Int, Seq[String]]]],
precalculatedGroupingIndexes: Option[Map[String, GroupingJoinIndexes]]) = {
def apply(rule: Seq[Atom], precalculatedCxns: Option[mutable.Map[String, mutable.Map[Int, Seq[String]]]]) = {
val constants = mutable.Map[Int, Constant]() // position => constant
val variables = mutable.Map[Variable, Int]() // v.oid => position

val body = rule.drop(1)

val deps = body.map(a => (
a match
case _: GroupingAtom => PredicateType.GROUPING
case _ => if (a.negated) PredicateType.NEGATED else PredicateType.POSITIVE
, a.rId))
val deps = body.map(a => (if (a.negated) PredicateType.NEGATED else PredicateType.POSITIVE, a.rId))

val typeHelper = body.flatMap(a => a.terms.map(* => !a.negated))

Expand Down Expand Up @@ -137,36 +119,7 @@ object JoinIndexes {
)).to(mutable.Map)
)

//groupings
val groupingIndexes = precalculatedGroupingIndexes.getOrElse(
body.collect{ case ga: GroupingAtom => ga }.map(ga =>
val (varsp, ctans) = ga.gp.terms.zipWithIndex.partitionMap{
case (v: Variable, i) => Left((v, i))
case (c: Constant, i) => Right((c, i))
}
val vars = varsp.filterNot(_._1.anon)
val gis = ga.gv.map(v => vars.find(_._1 == v).get).map(_._2)
ga.hash -> GroupingJoinIndexes(
vars.groupBy(_._1).values.filter(_.size > 1).map(_.map(_._2)).toSeq,
ctans.map(_.swap).to(mutable.Map),
gis,
ga.ags.map(_._1).map(ao =>
val aoi = ao.t match
case v: Variable =>
val i = ga.gv.indexOf(v)
if i >= 0 then AggOpIndex.GV(gis(i)) else AggOpIndex.LV(vars.find(_._1 == v).get._2)
case c: Constant => AggOpIndex.C(c)
ao match
case AggOp.SUM(t) => (StorageAggOp.SUM, aoi)
case AggOp.COUNT(t) => (StorageAggOp.COUNT, aoi)
case AggOp.MIN(t) => (StorageAggOp.MIN, aoi)
case AggOp.MAX(t) => (StorageAggOp.MAX, aoi)
)
)
).toMap
)

new JoinIndexes(bodyVars, constants.to(mutable.Map), projects, deps, rule, cxns, edb = false, groupingIndexes = groupingIndexes)
new JoinIndexes(bodyVars, constants.to(mutable.Map), projects, deps, rule, cxns)
}

// used to approximate poor user-defined order
Expand Down Expand Up @@ -261,7 +214,7 @@ object JoinIndexes {
presortSelect(sortBy, originalK, sm, -1)
val newK = sm.allRulesAllIndexes(rId).getOrElseUpdate(
newHash,
JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns), Some(originalK.groupingIndexes))
JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns))
)
(input.map(c => ProjectJoinFilterOp(rId, newK, newBody.map((_, oldP) => c.childrenSO(oldP)): _*)), newK)
}
Expand All @@ -282,15 +235,15 @@ object JoinIndexes {
presortSelect(sortBy, originalK, sm, deltaIdx)
val newK = sm.allRulesAllIndexes(rId).getOrElseUpdate(
newHash,
JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns), Some(originalK.groupingIndexes))
JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns))
)
(newK.atoms.drop(1).map(a => input(originalK.atoms.drop(1).indexOf(a))), newK)
}

def allOrders(rule: Seq[Atom]): AllIndexes = {
val idx = JoinIndexes(rule, None, None)
val idx = JoinIndexes(rule, None)
mutable.Map[String, JoinIndexes](rule.drop(1).permutations.map(r =>
val toRet = JoinIndexes(rule.head +: r, Some(idx.cxns), Some(idx.groupingIndexes))
val toRet = JoinIndexes(rule.head +: r, Some(idx.cxns))
toRet.hash -> toRet
).toSeq:_*)
}
Expand Down
4 changes: 0 additions & 4 deletions src/main/scala/datalog/execution/LambdaCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,4 @@ class LambdaCompiler(val storageManager: StorageManager)(using JITOptions) exten
val clhs = compile(children.head)
val crhs = compile(children(1))
sm => sm.diff(clhs(sm), crhs(sm))

case GroupingOp(child, gji) =>
val clh = compile(child)
sm => sm.groupingHelper(clh(sm), gji)
}
9 changes: 2 additions & 7 deletions src/main/scala/datalog/execution/NaiveExecutionEngine.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package datalog.execution

import datalog.dsl.{Atom, Constant, Term, Variable, GroupingAtom, AggOp}
import datalog.storage.{RelationId, CollectionsStorageManager, StorageManager, EDB, StorageAggOp}
import datalog.dsl.{Atom, Constant, Term, Variable}
import datalog.storage.{RelationId, CollectionsStorageManager, StorageManager, EDB}
import datalog.tools.Debug.debug

import scala.collection.mutable
Expand Down Expand Up @@ -40,11 +40,6 @@ class NaiveExecutionEngine(val storageManager: StorageManager, stratified: Boole
val jIdx = getOperatorKey(rule)
prebuiltOpKeys.getOrElseUpdate(rId, mutable.ArrayBuffer[JoinIndexes]()).addOne(jIdx)
storageManager.addConstantsToDomain(jIdx.constIndexes.values.toSeq)

// We need to add the constants occurring in the grouping predicates of the grouping atoms
rule.collect{ case ga: GroupingAtom => ga}.foreach(ga =>
storageManager.addConstantsToDomain(jIdx.groupingIndexes(ga.hash).constIndexes.values.toSeq)
)
}

def insertEDB(rule: Atom): Unit = {
Expand Down
13 changes: 6 additions & 7 deletions src/main/scala/datalog/execution/PrecedenceGraph.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package datalog.execution

import datalog.dsl.{Atom, GroupingAtom}
import datalog.dsl.Atom
import datalog.tools.Debug.debug
import datalog.storage.{NS, RelationId}

Expand Down Expand Up @@ -34,9 +34,8 @@ class PrecedenceGraph(using ns: NS /* ns used for pretty printing */) {

def addNode(rule: Seq[Atom]): Unit = {
idbs.addOne(rule.head.rId)
def cond(a: Atom) = !a.negated && !a.isInstanceOf[GroupingAtom] // We treat grouping atoms as if they were negated because negation and aggregation require the same stratification
adjacencyList.update(rule.head.rId, adjacencyList.getOrElse(rule.head.rId, Seq.empty) ++ rule.drop(1).filter(cond(_)).map(_.rId))
negAdjacencyList.update(rule.head.rId, negAdjacencyList.getOrElse(rule.head.rId, Seq.empty) ++ rule.drop(1).filter(!cond(_)).map(_.rId))
adjacencyList.update(rule.head.rId, adjacencyList.getOrElse(rule.head.rId, Seq.empty) ++ rule.drop(1).filter(!_.negated).map(_.rId))
negAdjacencyList.update(rule.head.rId, negAdjacencyList.getOrElse(rule.head.rId, Seq.empty) ++ rule.drop(1).filter(_.negated).map(_.rId))
}

def updateNodeAlias(aliases: mutable.Map[RelationId, RelationId]): Unit = {
Expand Down Expand Up @@ -132,11 +131,11 @@ class PrecedenceGraph(using ns: NS /* ns used for pretty printing */) {

val result = sorted.map(_.toSet).toSeq

// check for negative or grouping cycle (since both require the same check we treat grouping atoms as negated atoms in the dependency graph)
// check for negative cycle
result.foreach(strata =>
strata.foreach(p =>
if (graph(p).negEdges.map(n => n.rId).intersect(strata).nonEmpty)
throw new Exception("Negative or grouping cycle detected in input program")
throw new Exception("Negative cycle detected in input program")
)
)

Expand Down Expand Up @@ -164,7 +163,7 @@ class PrecedenceGraph(using ns: NS /* ns used for pretty printing */) {
)
)
if (stratum.nonEmpty && stratum.values.max > stratum.keys.size)
throw new Exception("Negative or grouping cycle detected in input program")
throw new Exception("Negative cycle detected in input program")
setDiff = prevStratum != stratum
prevStratum = stratum.clone
}
Expand Down
39 changes: 0 additions & 39 deletions src/main/scala/datalog/execution/QuoteCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend
x match
case PredicateType.POSITIVE => '{ PredicateType.POSITIVE }
case PredicateType.NEGATED => '{ PredicateType.NEGATED }
case PredicateType.GROUPING => '{ PredicateType.GROUPING }
}
}

Expand All @@ -78,40 +77,6 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend
}
}


/** Lifts a [[StorageAggOp]] value into a quoted expression referring to the same enum case. */
given ToExpr[StorageAggOp] with
  def apply(x: StorageAggOp)(using Quotes): Expr[StorageAggOp] = x match
    case StorageAggOp.SUM   => '{ StorageAggOp.SUM }
    case StorageAggOp.COUNT => '{ StorageAggOp.COUNT }
    case StorageAggOp.MIN   => '{ StorageAggOp.MIN }
    case StorageAggOp.MAX   => '{ StorageAggOp.MAX }

/** Lifts an [[AggOpIndex]] into a quoted expression that rebuilds the same case with the same payload. */
given ToExpr[AggOpIndex] with
  def apply(x: AggOpIndex)(using Quotes): Expr[AggOpIndex] = x match
    case AggOpIndex.LV(i) => '{ AggOpIndex.LV(${ Expr(i) }) }
    case AggOpIndex.GV(i) => '{ AggOpIndex.GV(${ Expr(i) }) }
    case AggOpIndex.C(c)  => '{ AggOpIndex.C(${ Expr(c) }) }

/** Lifts a [[GroupingJoinIndexes]] into a quoted constructor call, lifting each field in turn. */
given ToExpr[GroupingJoinIndexes] with
  def apply(x: GroupingJoinIndexes)(using Quotes): Expr[GroupingJoinIndexes] =
    '{
      GroupingJoinIndexes(
        ${ Expr(x.varIndexes) },
        ${ Expr(x.constIndexes) },
        ${ Expr(x.groupingIndexes) },
        ${ Expr(x.aggOpInfos) }
      )
    }

/**
* Compiles a relational operator into a quote that returns an EDB. Future TODO: merge with compileIR when dotty supports.
*/
Expand Down Expand Up @@ -200,10 +165,6 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend
val crhs = compileIRRelOp(children(1))
'{ $stagedSM.diff($clhs, $crhs) }

case GroupingOp(child, gji) =>
val clh = compileIRRelOp(child)
'{ $stagedSM.groupingHelper($clh, ${ Expr(gji) }) }

case DebugPeek(prefix, msg, children: _*) =>
val res = compileIRRelOp(children.head)
'{ debug(${ Expr(prefix) }, () => s"${${ Expr(msg()) }}") ; $res }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ class SemiNaiveExecutionEngine(override val storageManager: StorageManager, stra
innerSolve(toSolve, relations.toSeq)
if (idx < strata.length - 1)
storageManager.updateDiscovered()
storageManager.clearKnownDelta() // KnownDelta is not always empty when moving to the next stratum,
// it is a problem for aggregates because it leads to the creation of unwanted tuples.
)
storageManager.getNewIDBResult(toSolve)
}
Expand Down
Loading

0 comments on commit 108ea62

Please sign in to comment.