Skip to content

Commit

Permalink
WIP: Optimize negation (depends on stratified-aggregation) (#38)
Browse files Browse the repository at this point in the history
* Stratified aggregation

* Add some tests and examples, fix broken tests

* Handle anonymous variables in grouping atoms, add a few examples

* Simplify Grouping volcano operator, add constants generated by aggregates to domain

* Remove unused AST parameter

* Move creation of static operations, finish StagedSnippet

* Optimize negation: avoid combinatorial explosion

* Add synthetic test

* Add bugfix comment

* Optimize negation: avoid combinatorial explosion

* Add synthetic test

* Quick fix: Prohibit negated variables from being guarded by aggregated variables

* Simplify negation suboperations
  • Loading branch information
guillembartrina authored Dec 14, 2023
1 parent d655b18 commit 59f7e69
Show file tree
Hide file tree
Showing 20 changed files with 236 additions and 106 deletions.
13 changes: 10 additions & 3 deletions src/main/scala/datalog/execution/BytecodeCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,17 @@ class BytecodeCompiler(val storageManager: StorageManager)(using JITOptions) ext
.constantInstruction(rId)
emitSMCall(xb, meth, classOf[Int])

case ComplementOp(arity) =>
case NegationOp(child, cols) =>
val tmp = cols.map(_.exists(_.isEmpty))
xb.aload(0)
.constantInstruction(arity)
emitSMCall(xb, "getComplement", classOf[Int])
xb.aload(0)
emitCols(xb, cols)
emitSMCall(xb, "getGroundOf", classOf[Seq[?]])
xb.aload(0)
traverse(xb, child)
emitSeq(xb, tmp.map(v => xxb => emitBoolean(xxb, v)))
emitSMCall(xb, "zeroOut", classOf[EDB], classOf[Seq[?]])
emitSMCall(xb, "diff", classOf[EDB], classOf[EDB])

case ScanEDBOp(rId) =>
xb.aload(0)
Expand Down
36 changes: 35 additions & 1 deletion src/main/scala/datalog/execution/BytecodeGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ object BytecodeGenerator {
else
xb.constantInstruction(0)

/**
 * Emit `Boolean.valueOf($value)`.
 *
 * Pushes the constant 1 (for `true`) or 0 (for `false`) onto the builder and then emits a static
 * call to `java.lang.Boolean.valueOf`, whose method descriptor is built from the boxed and the
 * primitive boolean class descriptors returned by `clsDesc`.
 *
 * @param xb    builder that receives the instructions
 * @param value boolean to emit
 */
def emitBoolean(xb: CodeBuilder, value: Boolean): Unit =
  // Push the primitive first, exactly as the chained form does, before building descriptors.
  val afterConstant = xb.constantInstruction(if value then 1 else 0)
  val boxedBoolean = clsDesc(classOf[java.lang.Boolean])
  val valueOfDesc = MethodTypeDesc.of(boxedBoolean, clsDesc(classOf[Boolean]))
  afterConstant.invokestatic(boxedBoolean, "valueOf", valueOfDesc)

def emitSeqInt(xb: CodeBuilder, value: Seq[Int]): Unit =
emitSeq(xb, value.map(v => xxb => emitInteger(xxb, v)))

Expand Down Expand Up @@ -248,6 +254,17 @@ object BytecodeGenerator {
}
}

/**
 * Emit an `Either` value by delegating to `emitNew`.
 *
 * A `Left` is emitted with class `Left[A, B]` and its payload is emitted by `emitA`;
 * a `Right` is emitted with class `Right[A, B]` and its payload is emitted by `emitB`.
 *
 * @param xb     builder that receives the instructions
 * @param either value to emit
 * @param emitA  emits the payload carried by a `Left`
 * @param emitB  emits the payload carried by a `Right`
 */
def emitEither[A, B](xb: CodeBuilder, either: Either[A, B], emitA: (CodeBuilder, A) => Unit, emitB: (CodeBuilder, B) => Unit): Unit =
  either.fold(
    left => emitNew(xb, classOf[Left[A, B]], inner => emitA(inner, left)),
    right => emitNew(xb, classOf[Right[A, B]], inner => emitB(inner, right))
  )

def emitProjIndexes(xb: CodeBuilder, value: Seq[(String, Constant)]): Unit =
emitSeq(xb, value.map(v => xxb => emitStringConstantTuple2(xxb, v)))

Expand All @@ -268,6 +285,7 @@ object BytecodeGenerator {
def emitCxns(xb: CodeBuilder, value: collection.mutable.Map[String, collection.mutable.Map[Int, Seq[String]]]): Unit =
emitMap(xb, value.toSeq, emitString, emitCxnElement)

/*
def emitJoinIndexes(xb: CodeBuilder, value: JoinIndexes): Unit =
emitNew(xb, classOf[JoinIndexes], xxb =>
emitVarIndexes(xxb, value.varIndexes)
Expand All @@ -277,7 +295,11 @@ object BytecodeGenerator {
// emitArrayAtoms(xxb, value.atoms)
emitSeq(xb, value.atoms.map(a => xxb => emitAtom(xxb, a)))
emitCxns(xxb, value.cxns)
emitBool(xxb, value.edb))
// TODO: Missing negationInfo!
emitBool(xxb, value.edb),
// TODO: Missing groupingInfos!
)
*/

def emitStorageAggOp(xb: CodeBuilder, sao: StorageAggOp): Unit =
val enumCompanionCls = classOf[StorageAggOp.type]
Expand Down Expand Up @@ -315,6 +337,18 @@ object BytecodeGenerator {
emitSeqInt(xxb, value.groupingIndexes)
emitAggOpInfos(xxb, value.aggOpInfos))

/**
 * Emit the per-column negation description as a sequence (via `emitSeq`).
 *
 * Each column is an `Either`: a `Left` constant is emitted with `emitConstant`; a `Right`
 * list of `(relation id, column index)` pairs is emitted as a nested sequence in which every
 * pair goes through `emitNew` with the tuple class and both components emitted as integers.
 *
 * @param xb    builder that receives the instructions
 * @param value one entry per column of the negated atom
 */
def emitCols(xb: CodeBuilder, value: Seq[Either[Constant, Seq[(RelationId, Int)]]]): Unit =
  // A single (relation id, column index) occurrence.
  def emitOccurrence(b: CodeBuilder, occurrence: (RelationId, Int)): Unit =
    emitNew(b, classOf[(Int, Int)], { inner =>
      emitInteger(inner, occurrence._1)
      emitInteger(inner, occurrence._2)
    })

  // All occurrences of one variable, as a sequence of pairs.
  def emitOccurrences(b: CodeBuilder, occurrences: Seq[(RelationId, Int)]): Unit =
    emitSeq(b, occurrences.map(occ => inner => emitOccurrence(inner, occ)))

  emitSeq(xb, value.map(col => inner => emitEither(inner, col, emitConstant, emitOccurrences)))

val CD_BoxedUnit = clsDesc(classOf[scala.runtime.BoxedUnit])

/** Emit `BoxedUnit.UNIT`. */
Expand Down
44 changes: 31 additions & 13 deletions src/main/scala/datalog/execution/JoinIndexes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ case class GroupingJoinIndexes(varIndexes: Seq[Seq[Int]],
* @param edb - for rules that have EDBs defined on the same predicate, just read
* @param atoms - the original atoms from the DSL
* @param cxns - convenience data structure tracking how many variables in common each atom has with every other atom.
* @param negationInfo - information needed to build the complement relation of negated atoms: for each term, either a constant or a list of pairs (relationid, column) of the occurrences of the variable in the rule (empty for anonymous variable)
*/
case class JoinIndexes(varIndexes: Seq[Seq[Int]],
constIndexes: mutable.Map[Int, Constant],
projIndexes: Seq[(String, Constant)],
deps: Seq[(PredicateType, RelationId)],
atoms: Seq[Atom],
cxns: mutable.Map[String, mutable.Map[Int, Seq[String]]],
negationInfo: Map[String, Seq[Either[Constant, Seq[(RelationId, Int)]]]],
edb: Boolean = false,
groupingIndexes: Map[String, GroupingJoinIndexes] = Map.empty
) {
Expand All @@ -54,6 +56,7 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]],
", deps:" + depsToString(ns) +
", edb:" + edb +
", cxn: " + cxnsToString(ns) +
", negation: " + negationToString(ns) +
" }"

def varToString(): String = varIndexes.map(v => v.mkString("$", "==$", "")).mkString("[", ",", "]")
Expand All @@ -66,6 +69,13 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]],
inCommon.map((count, hashs) =>
count.toString + ": " + hashs.map(h => ns.hashToAtom(h)).mkString("", "|", "")
).mkString("", ", ", "")} }").mkString("[", ",\n", "]")
/**
 * Render `negationInfo` for debugging.
 *
 * Produces one `{ atom => terms }` entry per negated-atom hash, joined with ",\n" and wrapped
 * in square brackets. Constants are shown as-is; variable occurrences are shown as
 * `(relation, column)` pairs with the relation resolved through `ns`.
 *
 * @param ns used to resolve atom hashes (`hashToAtom`) and relation ids (`ns(r)`)
 */
def negationToString(ns: NS): String =
  val entries = negationInfo.map { (atomHash, infos) =>
    val atom = ns.hashToAtom(atomHash)
    val terms = infos.map {
      case Left(constant) => constant
      case Right(occurrences) =>
        val pairs = occurrences.map((rel, col) => s"(${ns(rel)}, $col)")
        // The collection is interpolated directly, so its default toString is kept.
        s"[ $pairs ]"
    }
    s"{ $atom => $terms }"
  }
  entries.mkString("[", ",\n", "]")
val hash: String = atoms.map(a => a.hash).mkString("", "", "")
}

Expand All @@ -83,23 +93,19 @@ object JoinIndexes {
case _ => if (a.negated) PredicateType.NEGATED else PredicateType.POSITIVE
, a.rId))

val typeHelper = body.flatMap(a => a.terms.map(* => !a.negated))

val bodyVars = body
.flatMap(a => a.terms) // all terms in one seq
.flatMap(a => a.terms.zipWithIndex.map((t, i) => (t, (a.negated, a.isInstanceOf[GroupingAtom] && i >= a.asInstanceOf[GroupingAtom].gv.length)))) // all terms in one seq
.zipWithIndex // term, position
.groupBy(z => z._1) // group by term
.groupBy(z => z._1._1) // group by term
.filter((term, matches) => // matches = Seq[(var, pos1), (var, pos2), ...]
term match {
case v: Variable =>
matches.map(_._2).find(typeHelper) match
case Some(pos) =>
variables(v) = pos
case None =>
if (v.oid != -1)
throw new Exception(s"Variable with varId ${v.oid} appears only in negated rules")
else
()
val wrong = v.oid != -1 && matches.exists(_._1._2._1) && matches.forall(x => x._1._2._1 || x._1._2._2) // Var occurs negated and all occurrences are either negated or aggregated
if wrong then
throw new Exception(s"Variable with varId ${v.oid} appears only in negated atoms (and possibly in aggregated positions of grouping atoms)")
else
if (v.oid != -1)
variables(v) = matches.find(!_._1._2._1).get._2
!v.anon && matches.length >= 2
case c: Constant =>
matches.foreach((_, idx) => constants(idx) = c)
Expand Down Expand Up @@ -137,6 +143,18 @@ object JoinIndexes {
)).to(mutable.Map)
)


val variables2 = body.filterNot(_.negated).flatMap(a =>
a.terms.zipWithIndex.collect{ case (v: Variable, i) if !v.anon => (v, i) }.map((v, i) => (v, (a.rId, i)))
).groupBy(_._1).view.mapValues(_.map(_._2))

val negationInfo = body.filter(_.negated).map(a =>
a.hash -> a.terms.map{
case c: Constant => Left(c)
case v: Variable => Right(if v.anon then Seq() else variables2(v))
}
).toMap

//groupings
val groupingIndexes = precalculatedGroupingIndexes.getOrElse(
body.collect{ case ga: GroupingAtom => ga }.map(ga =>
Expand Down Expand Up @@ -166,7 +184,7 @@ object JoinIndexes {
).toMap
)

new JoinIndexes(bodyVars, constants.to(mutable.Map), projects, deps, rule, cxns, edb = false, groupingIndexes = groupingIndexes)
new JoinIndexes(bodyVars, constants.to(mutable.Map), projects, deps, rule, cxns, negationInfo, edb = false, groupingIndexes = groupingIndexes)
}

// used to approximate poor user-defined order
Expand Down
9 changes: 7 additions & 2 deletions src/main/scala/datalog/execution/LambdaCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,13 @@ class LambdaCompiler(val storageManager: StorageManager)(using JITOptions) exten
}
}

case ComplementOp(arity) =>
_.getComplement(arity)
case NegationOp(child, cols) =>
val tmp = cols.map(_.exists(_.isEmpty))
val clh = compile(child)
sm =>
val compl = sm.getGroundOf(cols)
val nq = sm.zeroOut(clh(sm), tmp)
sm.diff(compl, nq)

case ScanEDBOp(rId) =>
if (storageManager.edbContains(rId))
Expand Down
8 changes: 1 addition & 7 deletions src/main/scala/datalog/execution/NaiveExecutionEngine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,11 @@ class NaiveExecutionEngine(val storageManager: StorageManager, stratified: Boole
idbs.getOrElseUpdate(rId, mutable.ArrayBuffer[IndexedSeq[Atom]]()).addOne(rule.toIndexedSeq)
val jIdx = getOperatorKey(rule)
prebuiltOpKeys.getOrElseUpdate(rId, mutable.ArrayBuffer[JoinIndexes]()).addOne(jIdx)
storageManager.addConstantsToDomain(jIdx.constIndexes.values.toSeq)

// We need to add the constants occurring in the grouping predicates of the grouping atoms
rule.collect{ case ga: GroupingAtom => ga}.foreach(ga =>
storageManager.addConstantsToDomain(jIdx.groupingIndexes(ga.hash).constIndexes.values.toSeq)
)
}

def insertEDB(rule: Atom): Unit = {
if (!storageManager.edbContains(rule.rId))
prebuiltOpKeys.getOrElseUpdate(rule.rId, mutable.ArrayBuffer[JoinIndexes]()).addOne(JoinIndexes(IndexedSeq(), mutable.Map(), IndexedSeq(), Seq((PredicateType.POSITIVE, rule.rId)), Seq(rule), mutable.Map.empty, true))
prebuiltOpKeys.getOrElseUpdate(rule.rId, mutable.ArrayBuffer[JoinIndexes]()).addOne(JoinIndexes(IndexedSeq(), mutable.Map(), IndexedSeq(), Seq((PredicateType.POSITIVE, rule.rId)), Seq(rule), mutable.Map.empty, Map.empty, true))
storageManager.insertEDB(rule)
}

Expand Down
11 changes: 9 additions & 2 deletions src/main/scala/datalog/execution/QuoteCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend
${ Expr(x.deps) },
${ Expr(x.atoms) },
${ Expr(x.cxns) },
${ Expr(x.negationInfo) },
${ Expr(x.edb) }
)
}
Expand Down Expand Up @@ -135,8 +136,14 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend
}
}

case ComplementOp(arity) =>
'{ $stagedSM.getComplement(${ Expr(arity) }) }
case NegationOp(child, cols) =>
val tmp = cols.map(_.exists(_.isEmpty))
val clh = compileIRRelOp(child)
'{
val compl = $stagedSM.getGroundOf(${ Expr(cols) })
val nq = $stagedSM.zeroOut($clh, ${ Expr(tmp) })
$stagedSM.diff(compl, nq)
}

case ScanEDBOp(rId) =>
if (storageManager.edbContains(rId))
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/datalog/execution/StagedExecutionEngine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ class StagedExecutionEngine(val storageManager: StorageManager, val defaultJITOp
case op: ScanEDBOp =>
op.run(storageManager)

case op: ComplementOp =>
op.run(storageManager)
case op: NegationOp =>
op.run_continuation(storageManager, op.children.map(o => (sm: StorageManager) => jit(o)))

case op: ProjectJoinFilterOp =>
op.run_continuation(storageManager, op.children.map(o => (sm: StorageManager) => jit(o)))
Expand Down
10 changes: 8 additions & 2 deletions src/main/scala/datalog/execution/StagedSnippetCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class StagedSnippetCompiler(val storageManager: StorageManager)(using val jitOpt
${ Expr(x.deps) },
${ Expr(x.atoms) },
${ Expr(x.cxns) },
${ Expr(x.negationInfo) },
${ Expr(x.edb) },
) }
}
Expand Down Expand Up @@ -127,8 +128,13 @@ class StagedSnippetCompiler(val storageManager: StorageManager)(using val jitOpt
}
}

case ComplementOp(arity) =>
'{ $stagedSM.getComplement(${ Expr(arity) }) }
case NegationOp(child, cols) =>
val tmp = cols.map(_.exists(_.isEmpty))
'{
val compl = $stagedSM.getGroundOf(${ Expr(cols) })
val nq = $stagedSM.zeroOut($stagedFns.head($stagedSM), ${ Expr(tmp) })
$stagedSM.diff(compl, nq)
}

case ScanEDBOp(rId) =>
if (storageManager.edbContains(rId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ class StagedSnippetExecutionEngine(override val storageManager: StorageManager,
case op: DebugPeek =>
op.run_continuation(storageManager, op.children.map(o => (sm: StorageManager) => jit(o)))

case op: ComplementOp =>
op.run(storageManager)
case op: NegationOp =>
op.run_continuation(storageManager, op.children.map(o => (sm: StorageManager) => jit(o)))

case _ => throw new Exception(s"Error: interpretRelOp called with unit operation: code=${irTree.code}")
}
Expand Down
16 changes: 11 additions & 5 deletions src/main/scala/datalog/execution/ir/IROp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import scala.util.{Failure, Success}
enum OpCode:
case PROGRAM, SWAP_CLEAR, SEQ,
SCAN, SCANEDB, SCAN_DISCOVERED,
COMPLEMENT,
NEGATION,
SPJ, INSERT, UNION, DIFF,
GROUPING,
DEBUG, DEBUGP, DOWHILE, UPDATE_DISCOVERED,
Expand Down Expand Up @@ -196,14 +196,20 @@ case class InsertOp(rId: RelationId, db: DB, knowledge: KNOWLEDGE, override val
}
}

case class ComplementOp(arity: Int)(using JITOptions) extends IROp[EDB] {
val code: OpCode = OpCode.COMPLEMENT
case class NegationOp(child: IROp[EDB], cols: Seq[Either[Constant, Seq[(RelationId, Int)]]])(using JITOptions) extends IROp[EDB](child) {
val code: OpCode = OpCode.NEGATION

override def run(storageManager: StorageManager): EDB =
storageManager.getComplement(arity)
val tmp = cols.map(_.exists(_.isEmpty))
val compl = storageManager.getGroundOf(cols)
val nq = storageManager.zeroOut(child.run(storageManager), tmp)
storageManager.diff(compl, nq)

override def run_continuation(storageManager: StorageManager, opFns: Seq[CompiledFn[EDB]]): EDB =
run(storageManager) // bc leaf node, no difference for continuation or run
val tmp = cols.map(_.exists(_.isEmpty))
val compl = storageManager.getGroundOf(cols)
val nq = storageManager.zeroOut(opFns(0)(storageManager), tmp)
storageManager.diff(compl, nq)
}

case class ScanOp(rId: RelationId, db: DB, knowledge: KNOWLEDGE)(using JITOptions) extends IROp[EDB] {
Expand Down
12 changes: 6 additions & 6 deletions src/main/scala/datalog/execution/ir/IRTreeGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ class IRTreeGenerator(using val ctx: InterpreterContext)(using JITOptions) {
val q = ScanOp(r, DB.Derived, KNOWLEDGE.Known)
typ match
case PredicateType.NEGATED =>
val arity = k.atoms(i + 1).terms.length
val res = DiffOp(ComplementOp(arity), q)
debug(s"found negated relation, rule=", () => s"${ctx.storageManager.printer.ruleToString(k.atoms)}\n\tarity=$arity")
val cols = k.negationInfo(k.atoms(i + 1).hash)
val res = NegationOp(q, cols)
debug(s"found negated relation, rule=", () => s"${ctx.storageManager.printer.ruleToString(k.atoms)}")
res
case PredicateType.GROUPING =>
val ga = k.atoms(i + 1).asInstanceOf[GroupingAtom]
Expand Down Expand Up @@ -119,9 +119,9 @@ class IRTreeGenerator(using val ctx: InterpreterContext)(using JITOptions) {
ScanOp(r, DB.Derived, KNOWLEDGE.Known)
typ match
case PredicateType.NEGATED =>
val arity = k.atoms(i + 1).terms.length
val res = DiffOp(ComplementOp(arity), q)
debug(s"found negated relation, rule=", () => s"${ctx.storageManager.printer.ruleToString(k.atoms)}\n\tarity=$arity")
val cols = k.negationInfo(k.atoms(i + 1).hash)
val res = NegationOp(q, cols)
debug(s"found negated relation, rule=", () => s"${ctx.storageManager.printer.ruleToString(k.atoms)}")
res
case PredicateType.GROUPING =>
val ga = k.atoms(i + 1).asInstanceOf[GroupingAtom]
Expand Down
Loading

0 comments on commit 59f7e69

Please sign in to comment.