
Commit 7d0894e

Revert "Stratified aggregation (#37)"
This reverts commit d655b18.
1 parent d655b18 commit 7d0894e


44 files changed (+68, -829 lines)

src/main/scala/datalog/dsl/DSL.scala (-32)

@@ -58,35 +58,3 @@ case class Relation[T <: Constant](id: Int, name: String)(using ee: ExecutionEngine
   def solve(): Set[Seq[Term]] = ee.solve(id).map(s => s.toSeq).toSet
   def get(): Set[Seq[Term]] = ee.get(id)
 }
-
-
-enum AggOp(val t: Term):
-  case SUM(override val t: Term) extends AggOp(t)
-  case COUNT(override val t: Term) extends AggOp(t)
-  case MIN(override val t: Term) extends AggOp(t)
-  case MAX(override val t: Term) extends AggOp(t)
-
-case class GroupingAtom(gp: Atom, gv: Seq[Variable], ags: Seq[(AggOp, Variable)])
-  extends Atom(gp.rId, gv ++ ags.map(_._2), false):
-  // We reuse the relation id of the grouping predicate because the 'virtual' relation is computed from it, and because other logic relies on it: dep in JoinIndexes, node id in DependencyGraph, etc.
-  override val hash: String = s"GB${gp.hash}-${gv.mkString("", "", "")}-${ags.mkString("", "", "")}"
-
-object groupBy:
-  def apply(gp: Atom, gv: Seq[Variable], ags: (AggOp, Variable)*): GroupingAtom =
-    if (gp.negated)
-      throw new Exception("The grouping predicate cannot be negated")
-    if (gv.size != gv.distinct.size)
-      throw new Exception("The grouping variables cannot be repeated")
-    if (ags.map(_._2).size != ags.map(_._2).distinct.size)
-      throw new Exception("The aggregation variables cannot be repeated")
-    val gpVars = gp.terms.collect{ case v: Variable => v }.toSet
-    val gVars = gv.toSet
-    val aggVars = ags.map(_._2).toSet
-    val aggdVars = ags.map(_._1.t).collect{ case v: Variable => v }.toSet
-    if (gVars.contains(__) || aggVars.contains(__) || aggdVars.contains(__))
-      throw new Exception("Anonymous variable ('__') not allowed as a grouping variable, aggregation variable or aggregated variable")
-    if (aggVars.intersect(gpVars).nonEmpty)
-      throw new Exception("No aggregation variable may occur in the grouping predicate")
-    if (!(aggdVars ++ gVars).subsetOf(gpVars))
-      throw new Exception("The aggregated variables and the grouping variables must occur in the grouping predicate")
-    GroupingAtom(gp, gv, ags)
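For orientation, the removed AggOp/groupBy surface expressed group-by aggregation over the tuples of a grouping predicate. A minimal sketch of the intended semantics on a plain Scala collection (the rule in the comment is hypothetical; the relation and variable names are not from this repository):

// Hypothetical rule this DSL aimed at (names assumed, not from the diff):
//   totalSales(city, s) :- groupBy(sales(city, amount), Seq(city), AggOp.SUM(amount) -> s)
// Plain-collection reference for the same grouping and SUM:
val sales: Seq[(String, Int)] = Seq(("nyc", 3), ("nyc", 5), ("sf", 2))
val totals: Map[String, Int] = sales.groupMapReduce(_._1)(_._2)(_ + _)
// totals == Map("nyc" -> 8, "sf" -> 2)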

src/main/scala/datalog/execution/BytecodeCompiler.scala (-6)

@@ -164,12 +164,6 @@ class BytecodeCompiler(val storageManager: StorageManager)(using JITOptions) ext
         traverse(xb, children(1))
         emitSMCall(xb, "diff", classOf[EDB], classOf[EDB])
 
-      case GroupingOp(child, gji) =>
-        xb.aload(0)
-        traverse(xb, child)
-        emitGroupingJoinIndexes(xb, gji)
-        emitSMCall(xb, "groupingHelper", classOf[EDB], classOf[GroupingJoinIndexes])
-
       case DebugPeek(prefix, msg, children: _*) =>
         assert(false, s"Unimplemented node: $irTree")
 

src/main/scala/datalog/execution/BytecodeGenerator.scala (+6 -43)

@@ -1,8 +1,7 @@
 package datalog.execution
 
 import datalog.dsl.{Atom, Constant, Term, Variable}
-import datalog.execution.AggOpIndex
-import datalog.storage.{RelationId, StorageAggOp}
+import datalog.storage.RelationId
 
 import java.lang.constant.ConstantDescs.*
 import java.lang.constant.*

@@ -193,11 +192,11 @@ object BytecodeGenerator {
     case e: Int => emitInteger(xb, e)
     case s: String => emitString(xb, s)
 
-  def emitStringConstantTuple2(xb: CodeBuilder, t: (String, Constant)): Unit =
-    emitNew(xb, classOf[(String, Constant)], xxb =>
-      emitString(xxb, t._1)
-      emitConstant(xxb, t._2)
-    )
+  def emitStringConstantTuple2(xb: CodeBuilder, t: (String, Constant)): Unit =
+    emitNew(xb, classOf[(String, Constant)], xxb =>
+      emitString(xxb, t._1)
+      emitConstant(xxb, t._2)
+    )
 
   def emitVariable(xb: CodeBuilder, variable: Variable): Unit =
     emitNew(xb, classOf[Variable], xxb =>

@@ -279,42 +278,6 @@ object BytecodeGenerator {
       emitCxns(xxb, value.cxns)
       emitBool(xxb, value.edb))
 
-  def emitStorageAggOp(xb: CodeBuilder, sao: StorageAggOp): Unit =
-    val enumCompanionCls = classOf[StorageAggOp.type]
-    emitObject(xb, enumCompanionCls)
-    xb.constantInstruction(sao.ordinal)
-    emitCall(xb, enumCompanionCls, "fromOrdinal", classOf[Int])
-
-  def emitAggOpIndex(xb: CodeBuilder, aoi: AggOpIndex): Unit = aoi match {
-    case gv: AggOpIndex.GV =>
-      emitNew(xb, classOf[AggOpIndex.GV], xxb =>
-        emitInteger(xxb, gv.i)
-      )
-    case lv: AggOpIndex.LV =>
-      emitNew(xb, classOf[AggOpIndex.GV], xxb =>
-        emitInteger(xxb, lv.i)
-      )
-    case c: AggOpIndex.C =>
-      emitNew(xb, classOf[AggOpIndex.C], xxb =>
-        emitConstant(xxb, c.c)
-      )
-  }
-
-  def emitAggOpInfos(xb: CodeBuilder, value: Seq[(StorageAggOp, AggOpIndex)]): Unit =
-    emitSeq(xb, value.map(v => xxb =>
-      emitNew(xb, classOf[(StorageAggOp, AggOpIndex)], xxxb =>
-        emitStorageAggOp(xxxb, v._1)
-        emitAggOpIndex(xxxb, v._2)
-      )
-    ))
-
-  def emitGroupingJoinIndexes(xb: CodeBuilder, value: GroupingJoinIndexes): Unit =
-    emitNew(xb, classOf[GroupingJoinIndexes], xxb =>
-      emitVarIndexes(xxb, value.varIndexes)
-      emitConstIndexes(xxb, value.constIndexes)
-      emitSeqInt(xxb, value.groupingIndexes)
-      emitAggOpInfos(xxb, value.aggOpInfos))
-
   val CD_BoxedUnit = clsDesc(classOf[scala.runtime.BoxedUnit])
 
   /** Emit `BoxedUnit.UNIT`. */
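The removed emitStorageAggOp encodes an enum case by pushing its ordinal and calling the companion's fromOrdinal. A minimal sketch of that round trip in plain Scala 3, using a stand-in enum rather than the project's StorageAggOp:

// Stand-in enum; the real code does this at the bytecode level.
enum AggKind:
  case SUM, COUNT, MIN, MAX

@main def ordinalRoundTrip(): Unit =
  val original = AggKind.MIN
  // What the generated bytecode effectively does: write the ordinal, read it back with fromOrdinal.
  val decoded = AggKind.fromOrdinal(original.ordinal)
  println(s"${original.ordinal} -> $decoded") // 2 -> MIN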

src/main/scala/datalog/execution/ExecutionEngine.scala (+1 -1)

@@ -29,7 +29,7 @@ trait ExecutionEngine {
    * @param rule - Includes the head at idx 0
    */
   inline def getOperatorKey(rule: Seq[Atom]): JoinIndexes =
-    JoinIndexes(rule, None, None)
+    JoinIndexes(rule, None)
 
   def getOperatorKeys(rId: RelationId): mutable.ArrayBuffer[JoinIndexes] = {
     prebuiltOpKeys.getOrElseUpdate(rId, mutable.ArrayBuffer[JoinIndexes]())

src/main/scala/datalog/execution/JoinIndexes.scala (+11 -58)

@@ -1,8 +1,8 @@
 package datalog.execution
 
-import datalog.dsl.{Atom, Constant, Variable, GroupingAtom, AggOp}
+import datalog.dsl.{Atom, Constant, Variable}
 import datalog.execution.ir.{IROp, ProjectJoinFilterOp, ScanOp}
-import datalog.storage.{DB, EDB, NS, RelationId, StorageManager, StorageAggOp}
+import datalog.storage.{DB, EDB, NS, RelationId, StorageManager}
 import datalog.tools.Debug.debug
 
 import scala.collection.mutable

@@ -12,19 +12,7 @@ import scala.reflect.ClassTag
 type AllIndexes = mutable.Map[String, JoinIndexes]
 
 enum PredicateType:
-  case POSITIVE, NEGATED, GROUPING
-
-
-enum AggOpIndex:
-  case LV(i: Int)
-  case GV(i: Int)
-  case C(c: Constant)
-
-case class GroupingJoinIndexes(varIndexes: Seq[Seq[Int]],
-                               constIndexes: mutable.Map[Int, Constant],
-                               groupingIndexes: Seq[Int],
-                               aggOpInfos: Seq[(StorageAggOp, AggOpIndex)]
-                              )
+  case POSITIVE, NEGATED
 
 /**
  * Wrapper object for join keys for IDB rules

@@ -43,8 +31,7 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]],
                        deps: Seq[(PredicateType, RelationId)],
                        atoms: Seq[Atom],
                        cxns: mutable.Map[String, mutable.Map[Int, Seq[String]]],
-                       edb: Boolean = false,
-                       groupingIndexes: Map[String, GroupingJoinIndexes] = Map.empty
+                       edb: Boolean = false
                       ) {
   override def toString(): String = ""//toStringWithNS(null)
 

@@ -70,18 +57,13 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]],
 }
 
 object JoinIndexes {
-  def apply(rule: Seq[Atom], precalculatedCxns: Option[mutable.Map[String, mutable.Map[Int, Seq[String]]]],
-            precalculatedGroupingIndexes: Option[Map[String, GroupingJoinIndexes]]) = {
+  def apply(rule: Seq[Atom], precalculatedCxns: Option[mutable.Map[String, mutable.Map[Int, Seq[String]]]]) = {
     val constants = mutable.Map[Int, Constant]() // position => constant
     val variables = mutable.Map[Variable, Int]() // v.oid => position
 
     val body = rule.drop(1)
 
-    val deps = body.map(a => (
-      a match
-        case _: GroupingAtom => PredicateType.GROUPING
-        case _ => if (a.negated) PredicateType.NEGATED else PredicateType.POSITIVE
-      , a.rId))
+    val deps = body.map(a => (if (a.negated) PredicateType.NEGATED else PredicateType.POSITIVE, a.rId))
 
     val typeHelper = body.flatMap(a => a.terms.map(_ => !a.negated))
 

@@ -137,36 +119,7 @@ object JoinIndexes {
       )).to(mutable.Map)
     )
 
-    // groupings
-    val groupingIndexes = precalculatedGroupingIndexes.getOrElse(
-      body.collect{ case ga: GroupingAtom => ga }.map(ga =>
-        val (varsp, ctans) = ga.gp.terms.zipWithIndex.partitionMap{
-          case (v: Variable, i) => Left((v, i))
-          case (c: Constant, i) => Right((c, i))
-        }
-        val vars = varsp.filterNot(_._1.anon)
-        val gis = ga.gv.map(v => vars.find(_._1 == v).get).map(_._2)
-        ga.hash -> GroupingJoinIndexes(
-          vars.groupBy(_._1).values.filter(_.size > 1).map(_.map(_._2)).toSeq,
-          ctans.map(_.swap).to(mutable.Map),
-          gis,
-          ga.ags.map(_._1).map(ao =>
-            val aoi = ao.t match
-              case v: Variable =>
-                val i = ga.gv.indexOf(v)
-                if i >= 0 then AggOpIndex.GV(gis(i)) else AggOpIndex.LV(vars.find(_._1 == v).get._2)
-              case c: Constant => AggOpIndex.C(c)
-            ao match
-              case AggOp.SUM(t) => (StorageAggOp.SUM, aoi)
-              case AggOp.COUNT(t) => (StorageAggOp.COUNT, aoi)
-              case AggOp.MIN(t) => (StorageAggOp.MIN, aoi)
-              case AggOp.MAX(t) => (StorageAggOp.MAX, aoi)
-          )
-        )
-      ).toMap
-    )
-
-    new JoinIndexes(bodyVars, constants.to(mutable.Map), projects, deps, rule, cxns, edb = false, groupingIndexes = groupingIndexes)
+    new JoinIndexes(bodyVars, constants.to(mutable.Map), projects, deps, rule, cxns)
   }
 
   // used to approximate poor user-defined order

@@ -261,7 +214,7 @@ object JoinIndexes {
       presortSelect(sortBy, originalK, sm, -1)
     val newK = sm.allRulesAllIndexes(rId).getOrElseUpdate(
       newHash,
-      JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns), Some(originalK.groupingIndexes))
+      JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns))
     )
     (input.map(c => ProjectJoinFilterOp(rId, newK, newBody.map((_, oldP) => c.childrenSO(oldP)): _*)), newK)
   }

@@ -282,15 +235,15 @@ object JoinIndexes {
       presortSelect(sortBy, originalK, sm, deltaIdx)
     val newK = sm.allRulesAllIndexes(rId).getOrElseUpdate(
       newHash,
-      JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns), Some(originalK.groupingIndexes))
+      JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns))
     )
     (newK.atoms.drop(1).map(a => input(originalK.atoms.drop(1).indexOf(a))), newK)
   }
 
   def allOrders(rule: Seq[Atom]): AllIndexes = {
-    val idx = JoinIndexes(rule, None, None)
+    val idx = JoinIndexes(rule, None)
     mutable.Map[String, JoinIndexes](rule.drop(1).permutations.map(r =>
-      val toRet = JoinIndexes(rule.head +: r, Some(idx.cxns), Some(idx.groupingIndexes))
+      val toRet = JoinIndexes(rule.head +: r, Some(idx.cxns))
       toRet.hash -> toRet
     ).toSeq:_*)
   }
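The deleted groupingIndexes computation above splits the grouping predicate's terms into variable and constant positions with partitionMap, then collects repeated-variable positions and a position-to-constant map. A standalone sketch of just that indexing step, with strings and ints standing in for Variable and Constant:

// Left models a variable (by name), Right models a constant; the real code pattern-matches on Term subtypes.
val terms: Seq[Either[String, Int]] = Seq(Left("x"), Right(42), Left("y"), Left("x"))

val (vars, ctans) = terms.zipWithIndex.partitionMap {
  case (Left(v), i)  => Left((v, i))   // variable -> (name, position)
  case (Right(c), i) => Right((c, i))  // constant -> (value, position)
}
// Positions of repeated variables become self-join/filter indexes:
val repeatedVarIndexes = vars.groupBy(_._1).values.filter(_.size > 1).map(_.map(_._2)).toSeq
// repeatedVarIndexes == Seq(Seq(0, 3))
// Constants keyed by position, as in GroupingJoinIndexes.constIndexes:
val constIndexes = ctans.map(_.swap).toMap // Map(1 -> 42)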

src/main/scala/datalog/execution/LambdaCompiler.scala (-4)

@@ -201,8 +201,4 @@ class LambdaCompiler(val storageManager: StorageManager)(using JITOptions) exten
       val clhs = compile(children.head)
       val crhs = compile(children(1))
       sm => sm.diff(clhs(sm), crhs(sm))
-
-    case GroupingOp(child, gji) =>
-      val clh = compile(child)
-      sm => sm.groupingHelper(clh(sm), gji)
   }

src/main/scala/datalog/execution/NaiveExecutionEngine.scala (+2 -7)

@@ -1,7 +1,7 @@
 package datalog.execution
 
-import datalog.dsl.{Atom, Constant, Term, Variable, GroupingAtom, AggOp}
-import datalog.storage.{RelationId, CollectionsStorageManager, StorageManager, EDB, StorageAggOp}
+import datalog.dsl.{Atom, Constant, Term, Variable}
+import datalog.storage.{RelationId, CollectionsStorageManager, StorageManager, EDB}
 import datalog.tools.Debug.debug
 
 import scala.collection.mutable

@@ -40,11 +40,6 @@ class NaiveExecutionEngine(val storageManager: StorageManager, stratified: Boole
     val jIdx = getOperatorKey(rule)
     prebuiltOpKeys.getOrElseUpdate(rId, mutable.ArrayBuffer[JoinIndexes]()).addOne(jIdx)
     storageManager.addConstantsToDomain(jIdx.constIndexes.values.toSeq)
-
-    // We need to add the constants occurring in the grouping predicates of the grouping atoms
-    rule.collect{ case ga: GroupingAtom => ga }.foreach(ga =>
-      storageManager.addConstantsToDomain(jIdx.groupingIndexes(ga.hash).constIndexes.values.toSeq)
-    )
   }
 
   def insertEDB(rule: Atom): Unit = {

src/main/scala/datalog/execution/PrecedenceGraph.scala (+6 -7)

@@ -1,6 +1,6 @@
 package datalog.execution
 
-import datalog.dsl.{Atom, GroupingAtom}
+import datalog.dsl.Atom
 import datalog.tools.Debug.debug
 import datalog.storage.{NS, RelationId}
 

@@ -34,9 +34,8 @@ class PrecedenceGraph(using ns: NS /* ns used for pretty printing */) {
 
   def addNode(rule: Seq[Atom]): Unit = {
     idbs.addOne(rule.head.rId)
-    def cond(a: Atom) = !a.negated && !a.isInstanceOf[GroupingAtom] // We treat grouping atoms as if they were negated because negation and aggregation require the same stratification
-    adjacencyList.update(rule.head.rId, adjacencyList.getOrElse(rule.head.rId, Seq.empty) ++ rule.drop(1).filter(cond(_)).map(_.rId))
-    negAdjacencyList.update(rule.head.rId, negAdjacencyList.getOrElse(rule.head.rId, Seq.empty) ++ rule.drop(1).filter(!cond(_)).map(_.rId))
+    adjacencyList.update(rule.head.rId, adjacencyList.getOrElse(rule.head.rId, Seq.empty) ++ rule.drop(1).filter(!_.negated).map(_.rId))
+    negAdjacencyList.update(rule.head.rId, negAdjacencyList.getOrElse(rule.head.rId, Seq.empty) ++ rule.drop(1).filter(_.negated).map(_.rId))
   }
 
   def updateNodeAlias(aliases: mutable.Map[RelationId, RelationId]): Unit = {

@@ -132,11 +131,11 @@ class PrecedenceGraph(using ns: NS /* ns used for pretty printing */) {
 
     val result = sorted.map(_.toSet).toSeq
 
-    // check for negative or grouping cycle (since both require the same check we treat grouping atoms as negated atoms in the dependency graph)
+    // check for negative cycle
     result.foreach(strata =>
       strata.foreach(p =>
         if (graph(p).negEdges.map(n => n.rId).intersect(strata).nonEmpty)
-          throw new Exception("Negative or grouping cycle detected in input program")
+          throw new Exception("Negative cycle detected in input program")
       )
     )
 

@@ -164,7 +163,7 @@ class PrecedenceGraph(using ns: NS /* ns used for pretty printing */) {
       )
     )
     if (stratum.nonEmpty && stratum.values.max > stratum.keys.size)
-      throw new Exception("Negative or grouping cycle detected in input program")
+      throw new Exception("Negative cycle detected in input program")
     setDiff = prevStratum != stratum
     prevStratum = stratum.clone
   }
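Before the revert, addNode routed grouping atoms onto the negative adjacency list so that stratification treats aggregation like negation. A self-contained toy sketch of that classification (MiniAtom is an illustrative stand-in, not the project's Atom):

// Illustrative stand-in: just a relation id plus the two flags the classification looks at.
final case class MiniAtom(rId: Int, negated: Boolean = false, grouping: Boolean = false)

// Pre-revert rule: an atom contributes a "positive" edge only if it is neither negated nor a grouping atom.
def splitEdges(body: Seq[MiniAtom]): (Seq[Int], Seq[Int]) =
  val positive = body.filter(a => !a.negated && !a.grouping).map(_.rId)
  val negativeLike = body.filter(a => a.negated || a.grouping).map(_.rId)
  (positive, negativeLike)

// Body with a positive edge to relation 0 and an aggregation over relation 1:
val (pos, negLike) = splitEdges(Seq(MiniAtom(0), MiniAtom(1, grouping = true)))
// pos == Seq(0), negLike == Seq(1): relation 1 must be fully computed in an earlier stratum.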

src/main/scala/datalog/execution/QuoteCompiler.scala (-39)

@@ -58,7 +58,6 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend
       x match
         case PredicateType.POSITIVE => '{ PredicateType.POSITIVE }
        case PredicateType.NEGATED => '{ PredicateType.NEGATED }
-        case PredicateType.GROUPING => '{ PredicateType.GROUPING }
     }
   }
 

@@ -78,40 +77,6 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend
     }
   }
 
-
-  given ToExpr[StorageAggOp] with {
-    def apply(x: StorageAggOp)(using Quotes) = {
-      x match
-        case StorageAggOp.SUM => '{ StorageAggOp.SUM }
-        case StorageAggOp.COUNT => '{ StorageAggOp.COUNT }
-        case StorageAggOp.MIN => '{ StorageAggOp.MIN }
-        case StorageAggOp.MAX => '{ StorageAggOp.MAX }
-    }
-  }
-
-  given ToExpr[AggOpIndex] with {
-    def apply(x: AggOpIndex)(using Quotes) = {
-      x match
-        case AggOpIndex.LV(i) => '{ AggOpIndex.LV(${ Expr(i) }) }
-        case AggOpIndex.GV(i) => '{ AggOpIndex.GV(${ Expr(i) }) }
-        case AggOpIndex.C(c) => '{ AggOpIndex.C(${ Expr(c) }) }
-
-    }
-  }
-
-  given ToExpr[GroupingJoinIndexes] with {
-    def apply(x: GroupingJoinIndexes)(using Quotes) = {
-      '{
-        GroupingJoinIndexes(
-          ${ Expr(x.varIndexes) },
-          ${ Expr(x.constIndexes) },
-          ${ Expr(x.groupingIndexes) },
-          ${ Expr(x.aggOpInfos) }
-        )
-      }
-    }
-  }
-
   /**
    * Compiles a relational operator into a quote that returns an EDB. Future TODO: merge with compileIR when dotty supports.
    */

@@ -200,10 +165,6 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend
       val crhs = compileIRRelOp(children(1))
       '{ $stagedSM.diff($clhs, $crhs) }
 
-    case GroupingOp(child, gji) =>
-      val clh = compileIRRelOp(child)
-      '{ $stagedSM.groupingHelper($clh, ${ Expr(gji) }) }
-
     case DebugPeek(prefix, msg, children: _*) =>
       val res = compileIRRelOp(children.head)
       '{ debug(${ Expr(prefix) }, () => s"${${ Expr(msg()) }}") ; $res }
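The deleted givens follow the standard Scala 3 ToExpr pattern for lifting runtime values into quoted code. A minimal, self-contained sketch of that pattern for an illustrative enum (not a type from this project):

import scala.quoted.*

enum Mode:
  case Fast, Safe

given ToExpr[Mode] with
  def apply(x: Mode)(using Quotes): Expr[Mode] =
    x match
      case Mode.Fast => '{ Mode.Fast } // re-create the case in generated code
      case Mode.Safe => '{ Mode.Safe }

// With this given in scope, a macro implementation can write Expr(mode) to splice the
// value into staged code, which is what the removed QuoteCompiler givens did for StorageAggOp.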

src/main/scala/datalog/execution/SemiNaiveExecutionEngine.scala (-2)

@@ -98,8 +98,6 @@ class SemiNaiveExecutionEngine(override val storageManager: StorageManager, stra
       innerSolve(toSolve, relations.toSeq)
       if (idx < strata.length - 1)
         storageManager.updateDiscovered()
-        storageManager.clearKnownDelta() // KnownDelta is not always empty when moving to the next stratum;
-                                         // that is a problem for aggregates because it leads to the creation of unwanted tuples.
     )
     storageManager.getNewIDBResult(toSolve)
   }
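The removed call documented that a stale delta leaking into the next stratum can feed unwanted tuples to an aggregate. A toy, runnable sketch of that control flow; the class below is a stand-in modeled only on the two method names visible in the diff, not the real StorageManager:

// Toy stand-in for the delta store.
final class ToyStorage:
  var knownDelta: Set[String] = Set.empty
  def updateDiscovered(): Unit = ()                    // promote results to the next stratum
  def clearKnownDelta(): Unit = knownDelta = Set.empty // the call this revert removes

@main def strataSketch(): Unit =
  val storage = ToyStorage()
  val strata = Seq(Set("edge"), Set("sumPerNode"))
  strata.zipWithIndex.foreach { case (stratum, idx) =>
    storage.knownDelta = stratum                       // pretend the stratum's fixpoint left these tuples behind
    if idx < strata.length - 1 then
      storage.updateDiscovered()
      storage.clearKnownDelta()                        // keeps stale delta tuples away from the aggregate stratum
  }
  println(storage.knownDelta)                          // Set(sumPerNode): only the last stratum's delta remains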
