Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cspa parallel #41

Merged
merged 5 commits into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/main/scala/datalog/execution/LambdaCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@ import org.glavo.classfile.CodeBuilder
import java.lang.invoke.MethodType
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.{immutable, mutable}
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.quoted.*
import scala.reflect.{classTag, ClassTag}

/**
* Separate out compile logic from StagedExecutionEngine
*/
class LambdaCompiler(val storageManager: StorageManager)(using JITOptions) extends StagedCompiler(storageManager) {
given staging.Compiler = jitOptions.dotty
/** Convert a Seq of lambdas into a lambda returning a Seq. */
def seqToLambda[T](seq: Seq[StorageManager => T]): StorageManager => Seq[T] =
def seqToLambda[T](seq: Seq[StorageManager => T], inParallel: Boolean = false): StorageManager => Seq[T] =
if inParallel then
return sm => IROp.runFns(sm, seq, inParallel)(using classTag[AnyRef].asInstanceOf[ClassTag[T]])
seq match
case seq: immutable.ArraySeq.ofRef[_] =>
val arr = unsafeArrayToLambda(seq.unsafeArray)
Expand Down Expand Up @@ -90,7 +95,11 @@ class LambdaCompiler(val storageManager: StorageManager)(using JITOptions) exten

case SequenceOp(label, children:_*) =>
val cOps: Array[CompiledFn[Any]] = children.map(compile).toArray
cOps.length match
assert(false, "This is never triggered")
if irTree.runInParallel then
// TODO: optimize by directly using the underlying Java stuff.
sm => IROp.runFns(sm, immutable.ArraySeq.unsafeWrapArray(cOps), inParallel = true)
else cOps.length match
case 1 =>
cOps(0)
case 2 =>
Expand Down Expand Up @@ -172,11 +181,11 @@ class LambdaCompiler(val storageManager: StorageManager)(using JITOptions) exten
else
(children, k)

val compiledOps = seqToLambda(sortedChildren.map(compile))
val compiledOps = seqToLambda(sortedChildren.map(compile), inParallel = irTree.runInParallel)
sm => sm.union(compiledOps(sm))

case UnionOp(label, children: _*) =>
val compiledOps = seqToLambda(children.map(compile))
val compiledOps = seqToLambda(children.map(compile), inParallel = irTree.runInParallel)
sm => sm.union(compiledOps(sm))

case DiffOp(children: _*) =>
Expand Down
47 changes: 38 additions & 9 deletions src/main/scala/datalog/execution/ir/IROp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import datalog.tools.Debug
import datalog.tools.Debug.debug

import java.util.concurrent.atomic.AtomicReference
import scala.collection.mutable
import scala.collection.{immutable, mutable}
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future}
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.reflect.ClassTag
import scala.quoted.*
import scala.util.{Failure, Success}
Expand Down Expand Up @@ -40,6 +40,9 @@ abstract class IROp[T](val children: IROp[T]*)(using val jitOptions: JITOptions,
var blockingCompiledFn: CompiledFn[T] = null // for when we're blocking and not ahead-of-time, so might as well skip the future
var compiledSnippetContinuationFn: (StorageManager, Seq[StorageManager => T]) => T = null

/** Should the children of this op be run in parallel? */
val runInParallel: Boolean = false

/**
* Add continuation to revert control flow to the interpret method, which checks for optimizations/deoptimizations
*/
Expand All @@ -51,7 +54,26 @@ abstract class IROp[T](val children: IROp[T]*)(using val jitOptions: JITOptions,
*/
def run(storageManager: StorageManager): T =
throw new Exception(s"Error: calling run on likely rel op: $code")

}
object IROp {
given ExecutionContext = ExecutionContext.global
def runFns[T: ClassTag](storageManager: StorageManager, seq: Seq[StorageManager => T], inParallel: Boolean = false): Seq[T] =
if seq.length == 1 then
return immutable.ArraySeq.unsafeWrapArray(Array(seq.head(storageManager)))
if inParallel == false then
return seq.map(_(storageManager))
val futures = immutable.ArraySeq.newBuilder[Future[T]]
futures.sizeHint(seq.length)
// Spawn threads for the N - 1 first children
seq.view.init.foreach: op =>
futures += Future(op(storageManager))
// Run the last child on the current thread.
val last = seq.last(storageManager)
futures += Future(last)
Await.result(Future.sequence(futures.result()), Duration.Inf)
}
import IROp.*

/**
* @param children: SequenceOp[SequenceOp.NaiveEval, DoWhileOp]
Expand Down Expand Up @@ -129,10 +151,12 @@ case class DoWhileOp(toCmp: DB, override val children:IROp[Any]*)(using JITOptio
* @param children: [Any*]
*/
case class SequenceOp(override val code: OpCode, override val children:IROp[Any]*)(using JITOptions) extends IROp[Any](children:_*) {
override val runInParallel: Boolean = code == OpCode.EVAL_SN

override def run_continuation(storageManager: StorageManager, opFns: Seq[CompiledFn[Any]]): Any =
opFns.map(o => o(storageManager))
runFns(storageManager, opFns, inParallel = runInParallel)
override def run(storageManager: StorageManager): Any =
children.map(o => o.run(storageManager))
runFns(storageManager, children.map(_.run), inParallel = runInParallel)
}

case class UpdateDiscoveredOp()(using JITOptions) extends IROp[Any] {
Expand Down Expand Up @@ -255,10 +279,12 @@ case class UnionOp(override val code: OpCode, override val children:IROp[EDB]*)(
var compiledFnIndexed: Future[CompiledFnIndexed[EDB]] = null
var blockingCompiledFnIndexed: CompiledFnIndexed[EDB] = null

override val runInParallel = true

override def run_continuation(storageManager: StorageManager, opFns: Seq[CompiledFn[EDB]]): EDB =
storageManager.union(opFns.map(o => o(storageManager)))
storageManager.union(runFns(storageManager, opFns, inParallel = runInParallel))
override def run(storageManager: StorageManager): EDB =
storageManager.union(children.map(o => o.run(storageManager)))
storageManager.union(runFns(storageManager, children.map(_.run), inParallel = runInParallel))
}

/**
Expand All @@ -271,8 +297,11 @@ case class UnionSPJOp(rId: RelationId, var k: JoinIndexes, override val children
var compiledFnIndexed: Future[CompiledFnIndexed[EDB]] = null
// var compiledFnIndexed: java.util.concurrent.Future[CompiledFnIndexed[EDB]] = null
// for now not filled out bc not planning on compiling higher than this

override val runInParallel = true

override def run_continuation(storageManager: StorageManager, opFns: Seq[CompiledFn[EDB]]): EDB =
storageManager.union(opFns.map(o => o(storageManager)))
storageManager.union(runFns(storageManager, opFns, inParallel = runInParallel))

override def run(storageManager: StorageManager): EDB =

Expand All @@ -289,7 +318,7 @@ case class UnionSPJOp(rId: RelationId, var k: JoinIndexes, override val children
// )
// TODO: change children.length from 3
if (jitOptions.sortOrder == SortOrder.Unordered || jitOptions.sortOrder == SortOrder.Badluck || children.length < 3 || jitOptions.granularity.flag != OpCode.OTHER) // If not only interpreting, then don't optimize since we are waiting for the optimized version to compile
storageManager.union(children.map((s: ProjectJoinFilterOp) => s.run(storageManager)))
storageManager.union(runFns(storageManager, children.map(_.run), inParallel = runInParallel))
else
val (sortedChildren, newK) = JoinIndexes.getPresort(
children,
Expand All @@ -298,7 +327,7 @@ case class UnionSPJOp(rId: RelationId, var k: JoinIndexes, override val children
k,
storageManager
)
storageManager.union(sortedChildren.map((s: ProjectJoinFilterOp) => s.run(storageManager)))
storageManager.union(runFns(storageManager, sortedChildren.map(_.run), inParallel = runInParallel))
}
/**
* @param children: [Union|Scan, Scan]
Expand Down
10 changes: 7 additions & 3 deletions src/main/scala/datalog/execution/ir/IRTreeGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ class IRTreeGenerator(using val ctx: InterpreterContext)(using JITOptions) {
.map(r =>
val res = semiNaiveEvalRule(ruleMap(r))
ResetDeltaOp(r, res.asInstanceOf[IROp[Any]])
) :+ InsertDeltaNewIntoDerived()
)

SequenceOp(
OpCode.EVAL_SN,
queries:_*,
OpCode.SEQ,
SequenceOp(
OpCode.EVAL_SN,
queries:_*,
),
InsertDeltaNewIntoDerived()
)
}

Expand Down
Loading