Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Optimize negation (depends on stratified-aggregation) #38

Merged
merged 15 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/main/scala/datalog/execution/BytecodeCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,17 @@ class BytecodeCompiler(val storageManager: StorageManager)(using JITOptions) ext
.constantInstruction(rId)
emitSMCall(xb, meth, classOf[Int])

case ComplementOp(arity) =>
case NegationOp(child, cols) =>
val tmp = cols.map(_.exists(_.isEmpty))
xb.aload(0)
.constantInstruction(arity)
emitSMCall(xb, "getComplement", classOf[Int])
xb.aload(0)
emitCols(xb, cols)
emitSMCall(xb, "getGroundOf", classOf[Seq[?]])
xb.aload(0)
traverse(xb, child)
emitSeq(xb, tmp.map(v => xxb => emitBoolean(xxb, v)))
emitSMCall(xb, "zeroOut", classOf[EDB], classOf[Seq[?]])
emitSMCall(xb, "diff", classOf[EDB], classOf[EDB])

case ScanEDBOp(rId) =>
xb.aload(0)
Expand Down
36 changes: 35 additions & 1 deletion src/main/scala/datalog/execution/BytecodeGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ object BytecodeGenerator {
else
xb.constantInstruction(0)

/** Emit `Boolean.valueOf($value)`. */
def emitBoolean(xb: CodeBuilder, value: Boolean): Unit =
xb.constantInstruction(if value then 1 else 0)
.invokestatic(clsDesc(classOf[java.lang.Boolean]), "valueOf",
MethodTypeDesc.of(clsDesc(classOf[java.lang.Boolean]), clsDesc(classOf[Boolean])))

def emitSeqInt(xb: CodeBuilder, value: Seq[Int]): Unit =
emitSeq(xb, value.map(v => xxb => emitInteger(xxb, v)))

Expand Down Expand Up @@ -248,6 +254,17 @@ object BytecodeGenerator {
}
}

def emitEither[A, B](xb: CodeBuilder, either: Either[A, B], emitA: (CodeBuilder, A) => Unit, emitB: (CodeBuilder, B) => Unit): Unit =
either match
case Left(value) =>
emitNew(xb, classOf[Left[A, B]], { xxb =>
emitA(xxb, value)
})
case Right(value) =>
emitNew(xb, classOf[Right[A, B]], { xxb =>
emitB(xxb, value)
})

def emitProjIndexes(xb: CodeBuilder, value: Seq[(String, Constant)]): Unit =
emitSeq(xb, value.map(v => xxb => emitStringConstantTuple2(xxb, v)))

Expand All @@ -268,6 +285,7 @@ object BytecodeGenerator {
def emitCxns(xb: CodeBuilder, value: collection.mutable.Map[String, collection.mutable.Map[Int, Seq[String]]]): Unit =
emitMap(xb, value.toSeq, emitString, emitCxnElement)

/*
def emitJoinIndexes(xb: CodeBuilder, value: JoinIndexes): Unit =
emitNew(xb, classOf[JoinIndexes], xxb =>
emitVarIndexes(xxb, value.varIndexes)
Expand All @@ -277,7 +295,11 @@ object BytecodeGenerator {
// emitArrayAtoms(xxb, value.atoms)
emitSeq(xb, value.atoms.map(a => xxb => emitAtom(xxb, a)))
emitCxns(xxb, value.cxns)
emitBool(xxb, value.edb))
// TODO: Missing negationInfo!
emitBool(xxb, value.edb),
// TODO: Missing groupingInfos!
)
*/

def emitStorageAggOp(xb: CodeBuilder, sao: StorageAggOp): Unit =
val enumCompanionCls = classOf[StorageAggOp.type]
Expand Down Expand Up @@ -315,6 +337,18 @@ object BytecodeGenerator {
emitSeqInt(xxb, value.groupingIndexes)
emitAggOpInfos(xxb, value.aggOpInfos))

def emitCols(xb: CodeBuilder, value: Seq[Either[Constant, Seq[(RelationId, Int)]]]): Unit =
emitSeq(xb, value.map(v => xxb =>
emitEither(xxb, v, emitConstant, (xxxb, s) =>
emitSeq(xxxb, s.map(vv => xxxxb =>
emitNew(xxxxb, classOf[(Int, Int)], xxxxxb =>
emitInteger(xxxxxb, vv._1)
emitInteger(xxxxxb, vv._2)
)
))
)
))

val CD_BoxedUnit = clsDesc(classOf[scala.runtime.BoxedUnit])

/** Emit `BoxedUnit.UNIT`. */
Expand Down
44 changes: 31 additions & 13 deletions src/main/scala/datalog/execution/JoinIndexes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ case class GroupingJoinIndexes(varIndexes: Seq[Seq[Int]],
* @param edb - for rules that have EDBs defined on the same predicate, just read
* @param atoms - the original atoms from the DSL
* @param cxns - convenience data structure tracking how many variables in common each atom has with every other atom.
* @param negationInfo - information needed to build the complement relation of negated atoms: for each term, either a constant or a list of pairs (relationid, column) of the ocurrences of the variable in the rule (empty for anonynous variable)
*/
case class JoinIndexes(varIndexes: Seq[Seq[Int]],
constIndexes: mutable.Map[Int, Constant],
projIndexes: Seq[(String, Constant)],
deps: Seq[(PredicateType, RelationId)],
atoms: Seq[Atom],
cxns: mutable.Map[String, mutable.Map[Int, Seq[String]]],
negationInfo: Map[String, Seq[Either[Constant, Seq[(RelationId, Int)]]]],
edb: Boolean = false,
groupingIndexes: Map[String, GroupingJoinIndexes] = Map.empty
) {
Expand All @@ -54,6 +56,7 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]],
", deps:" + depsToString(ns) +
", edb:" + edb +
", cxn: " + cxnsToString(ns) +
", negation: " + negationToString(ns) +
" }"

def varToString(): String = varIndexes.map(v => v.mkString("$", "==$", "")).mkString("[", ",", "]")
Expand All @@ -66,6 +69,13 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]],
inCommon.map((count, hashs) =>
count.toString + ": " + hashs.map(h => ns.hashToAtom(h)).mkString("", "|", "")
).mkString("", ", ", "")} }").mkString("[", ",\n", "]")
def negationToString(ns: NS): String =
negationInfo.map((h, infos) =>
s"{ ${ns.hashToAtom(h)} => ${
infos.map{
case Left(value) => value
case Right(value) => s"[ ${value.map((r, c) => s"(${ns(r)}, $c)")} ]"
}} }").mkString("[", ",\n", "]")
val hash: String = atoms.map(a => a.hash).mkString("", "", "")
}

Expand All @@ -83,23 +93,19 @@ object JoinIndexes {
case _ => if (a.negated) PredicateType.NEGATED else PredicateType.POSITIVE
, a.rId))

val typeHelper = body.flatMap(a => a.terms.map(* => !a.negated))

val bodyVars = body
.flatMap(a => a.terms) // all terms in one seq
.flatMap(a => a.terms.zipWithIndex.map((t, i) => (t, (a.negated, a.isInstanceOf[GroupingAtom] && i >= a.asInstanceOf[GroupingAtom].gv.length)))) // all terms in one seq
.zipWithIndex // term, position
.groupBy(z => z._1) // group by term
.groupBy(z => z._1._1) // group by term
.filter((term, matches) => // matches = Seq[(var, pos1), (var, pos2), ...]
term match {
case v: Variable =>
matches.map(_._2).find(typeHelper) match
case Some(pos) =>
variables(v) = pos
case None =>
if (v.oid != -1)
throw new Exception(s"Variable with varId ${v.oid} appears only in negated rules")
else
()
val wrong = v.oid != -1 && matches.exists(_._1._2._1) && matches.forall(x => x._1._2._1 || x._1._2._2) // Var occurs negated and all occurrences are either negated or aggregated
if wrong then
throw new Exception(s"Variable with varId ${v.oid} appears only in negated atoms (and possibly in aggregated positions of grouping atoms)")
else
if (v.oid != -1)
variables(v) = matches.find(!_._1._2._1).get._2
!v.anon && matches.length >= 2
case c: Constant =>
matches.foreach((_, idx) => constants(idx) = c)
Expand Down Expand Up @@ -137,6 +143,18 @@ object JoinIndexes {
)).to(mutable.Map)
)


val variables2 = body.filterNot(_.negated).flatMap(a =>
a.terms.zipWithIndex.collect{ case (v: Variable, i) if !v.anon => (v, i) }.map((v, i) => (v, (a.rId, i)))
).groupBy(_._1).view.mapValues(_.map(_._2))

val negationInfo = body.filter(_.negated).map(a =>
a.hash -> a.terms.map{
case c: Constant => Left(c)
case v: Variable => Right(if v.anon then Seq() else variables2(v))
}
).toMap

//groupings
val groupingIndexes = precalculatedGroupingIndexes.getOrElse(
body.collect{ case ga: GroupingAtom => ga }.map(ga =>
Expand Down Expand Up @@ -166,7 +184,7 @@ object JoinIndexes {
).toMap
)

new JoinIndexes(bodyVars, constants.to(mutable.Map), projects, deps, rule, cxns, edb = false, groupingIndexes = groupingIndexes)
new JoinIndexes(bodyVars, constants.to(mutable.Map), projects, deps, rule, cxns, negationInfo, edb = false, groupingIndexes = groupingIndexes)
}

// used to approximate poor user-defined order
Expand Down
9 changes: 7 additions & 2 deletions src/main/scala/datalog/execution/LambdaCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,13 @@ class LambdaCompiler(val storageManager: StorageManager)(using JITOptions) exten
}
}

case ComplementOp(arity) =>
_.getComplement(arity)
case NegationOp(child, cols) =>
val tmp = cols.map(_.exists(_.isEmpty))
val clh = compile(child)
sm =>
val compl = sm.getGroundOf(cols)
val nq = sm.zeroOut(clh(sm), tmp)
sm.diff(compl, nq)

case ScanEDBOp(rId) =>
if (storageManager.edbContains(rId))
Expand Down
8 changes: 1 addition & 7 deletions src/main/scala/datalog/execution/NaiveExecutionEngine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,11 @@ class NaiveExecutionEngine(val storageManager: StorageManager, stratified: Boole
idbs.getOrElseUpdate(rId, mutable.ArrayBuffer[IndexedSeq[Atom]]()).addOne(rule.toIndexedSeq)
val jIdx = getOperatorKey(rule)
prebuiltOpKeys.getOrElseUpdate(rId, mutable.ArrayBuffer[JoinIndexes]()).addOne(jIdx)
storageManager.addConstantsToDomain(jIdx.constIndexes.values.toSeq)

// We need to add the constants occurring in the grouping predicates of the grouping atoms
rule.collect{ case ga: GroupingAtom => ga}.foreach(ga =>
storageManager.addConstantsToDomain(jIdx.groupingIndexes(ga.hash).constIndexes.values.toSeq)
)
}

def insertEDB(rule: Atom): Unit = {
if (!storageManager.edbContains(rule.rId))
prebuiltOpKeys.getOrElseUpdate(rule.rId, mutable.ArrayBuffer[JoinIndexes]()).addOne(JoinIndexes(IndexedSeq(), mutable.Map(), IndexedSeq(), Seq((PredicateType.POSITIVE, rule.rId)), Seq(rule), mutable.Map.empty, true))
prebuiltOpKeys.getOrElseUpdate(rule.rId, mutable.ArrayBuffer[JoinIndexes]()).addOne(JoinIndexes(IndexedSeq(), mutable.Map(), IndexedSeq(), Seq((PredicateType.POSITIVE, rule.rId)), Seq(rule), mutable.Map.empty, Map.empty, true))
storageManager.insertEDB(rule)
}

Expand Down
11 changes: 9 additions & 2 deletions src/main/scala/datalog/execution/QuoteCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend
${ Expr(x.deps) },
${ Expr(x.atoms) },
${ Expr(x.cxns) },
${ Expr(x.negationInfo) },
${ Expr(x.edb) }
)
}
Expand Down Expand Up @@ -135,8 +136,14 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend
}
}

case ComplementOp(arity) =>
'{ $stagedSM.getComplement(${ Expr(arity) }) }
case NegationOp(child, cols) =>
val tmp = cols.map(_.exists(_.isEmpty))
val clh = compileIRRelOp(child)
'{
val compl = $stagedSM.getGroundOf(${ Expr(cols) })
val nq = $stagedSM.zeroOut($clh, ${ Expr(tmp) })
$stagedSM.diff(compl, nq)
}

case ScanEDBOp(rId) =>
if (storageManager.edbContains(rId))
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/datalog/execution/StagedExecutionEngine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ class StagedExecutionEngine(val storageManager: StorageManager, val defaultJITOp
case op: ScanEDBOp =>
op.run(storageManager)

case op: ComplementOp =>
op.run(storageManager)
case op: NegationOp =>
op.run_continuation(storageManager, op.children.map(o => (sm: StorageManager) => jit(o)))

case op: ProjectJoinFilterOp =>
op.run_continuation(storageManager, op.children.map(o => (sm: StorageManager) => jit(o)))
Expand Down
10 changes: 8 additions & 2 deletions src/main/scala/datalog/execution/StagedSnippetCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class StagedSnippetCompiler(val storageManager: StorageManager)(using val jitOpt
${ Expr(x.deps) },
${ Expr(x.atoms) },
${ Expr(x.cxns) },
${ Expr(x.negationInfo) },
${ Expr(x.edb) },
) }
}
Expand Down Expand Up @@ -127,8 +128,13 @@ class StagedSnippetCompiler(val storageManager: StorageManager)(using val jitOpt
}
}

case ComplementOp(arity) =>
'{ $stagedSM.getComplement(${ Expr(arity) }) }
case NegationOp(child, cols) =>
val tmp = cols.map(_.exists(_.isEmpty))
'{
val compl = $stagedSM.getGroundOf(${ Expr(cols) })
val nq = $stagedSM.zeroOut($stagedFns.head($stagedSM), ${ Expr(tmp) })
$stagedSM.diff(compl, nq)
}

case ScanEDBOp(rId) =>
if (storageManager.edbContains(rId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ class StagedSnippetExecutionEngine(override val storageManager: StorageManager,
case op: DebugPeek =>
op.run_continuation(storageManager, op.children.map(o => (sm: StorageManager) => jit(o)))

case op: ComplementOp =>
op.run(storageManager)
case op: NegationOp =>
op.run_continuation(storageManager, op.children.map(o => (sm: StorageManager) => jit(o)))

case _ => throw new Exception(s"Error: interpretRelOp called with unit operation: code=${irTree.code}")
}
Expand Down
16 changes: 11 additions & 5 deletions src/main/scala/datalog/execution/ir/IROp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import scala.util.{Failure, Success}
enum OpCode:
case PROGRAM, SWAP_CLEAR, SEQ,
SCAN, SCANEDB, SCAN_DISCOVERED,
COMPLEMENT,
NEGATION,
SPJ, INSERT, UNION, DIFF,
GROUPING,
DEBUG, DEBUGP, DOWHILE, UPDATE_DISCOVERED,
Expand Down Expand Up @@ -196,14 +196,20 @@ case class InsertOp(rId: RelationId, db: DB, knowledge: KNOWLEDGE, override val
}
}

case class ComplementOp(arity: Int)(using JITOptions) extends IROp[EDB] {
val code: OpCode = OpCode.COMPLEMENT
case class NegationOp(child: IROp[EDB], cols: Seq[Either[Constant, Seq[(RelationId, Int)]]])(using JITOptions) extends IROp[EDB](child) {
val code: OpCode = OpCode.NEGATION

override def run(storageManager: StorageManager): EDB =
storageManager.getComplement(arity)
val tmp = cols.map(_.exists(_.isEmpty))
val compl = storageManager.getGroundOf(cols)
val nq = storageManager.zeroOut(child.run(storageManager), tmp)
storageManager.diff(compl, nq)

override def run_continuation(storageManager: StorageManager, opFns: Seq[CompiledFn[EDB]]): EDB =
run(storageManager) // bc leaf node, no difference for continuation or run
val tmp = cols.map(_.exists(_.isEmpty))
val compl = storageManager.getGroundOf(cols)
val nq = storageManager.zeroOut(opFns(0)(storageManager), tmp)
storageManager.diff(compl, nq)
}

case class ScanOp(rId: RelationId, db: DB, knowledge: KNOWLEDGE)(using JITOptions) extends IROp[EDB] {
Expand Down
12 changes: 6 additions & 6 deletions src/main/scala/datalog/execution/ir/IRTreeGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ class IRTreeGenerator(using val ctx: InterpreterContext)(using JITOptions) {
val q = ScanOp(r, DB.Derived, KNOWLEDGE.Known)
typ match
case PredicateType.NEGATED =>
val arity = k.atoms(i + 1).terms.length
val res = DiffOp(ComplementOp(arity), q)
debug(s"found negated relation, rule=", () => s"${ctx.storageManager.printer.ruleToString(k.atoms)}\n\tarity=$arity")
val cols = k.negationInfo(k.atoms(i + 1).hash)
val res = NegationOp(q, cols)
debug(s"found negated relation, rule=", () => s"${ctx.storageManager.printer.ruleToString(k.atoms)}")
res
case PredicateType.GROUPING =>
val ga = k.atoms(i + 1).asInstanceOf[GroupingAtom]
Expand Down Expand Up @@ -119,9 +119,9 @@ class IRTreeGenerator(using val ctx: InterpreterContext)(using JITOptions) {
ScanOp(r, DB.Derived, KNOWLEDGE.Known)
typ match
case PredicateType.NEGATED =>
val arity = k.atoms(i + 1).terms.length
val res = DiffOp(ComplementOp(arity), q)
debug(s"found negated relation, rule=", () => s"${ctx.storageManager.printer.ruleToString(k.atoms)}\n\tarity=$arity")
val cols = k.negationInfo(k.atoms(i + 1).hash)
val res = NegationOp(q, cols)
debug(s"found negated relation, rule=", () => s"${ctx.storageManager.printer.ruleToString(k.atoms)}")
res
case PredicateType.GROUPING =>
val ga = k.atoms(i + 1).asInstanceOf[GroupingAtom]
Expand Down
Loading
Loading