Skip to content

Commit

Permalink
Add footprint optimization based on separation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
liufengyun committed Aug 23, 2024
1 parent b4e9f8f commit 0db9c0a
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 19 deletions.
102 changes: 88 additions & 14 deletions compiler/src/dotty/tools/dotc/transform/init/Objects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,38 @@ class Objects(using Context @constructorOnly):
* Represents a lambda expression
* @param klass The enclosing class of the anonymous function's creation site
*/
case class Fun(code: Tree, thisV: ThisValue, klass: ClassSymbol, env: Env.Data) extends ValueElement:
case class Fun(code: Tree, thisV: ThisValue, klass: ClassSymbol, env: Env.Data)(using @constructorOnly ctx: Context) extends ValueElement:
def show(using Context) = "Fun(" + code.show + ", " + thisV.show + ", " + klass.show + ")"

def flatten: Iterable[Value | Addr] = env.flatten ++ Vector(thisV)
val freeVals: Set[Symbol] = computeFreeVars()(using ctx)

def computeFreeVars()(using ctx: Context): Set[Symbol] =
// TODO: compute captures transitively
val refs = mutable.Set.empty[Symbol]
val defs = mutable.Set.empty[Symbol]
val traverser = new TreeTraverser:
def traverse(tree: Tree)(using Context) =
tree match
case ident: Ident =>
val sym = ident.symbol
if sym.isTerm && sym.isLocal then refs += sym

case vdef: ValDef =>
val sym = vdef.symbol
if sym.isLocal then defs += sym

case _ =>
traverseChildren(tree)
end traverse
traverser.traverse(code)
refs.diff(defs).toSet

def flatten: Iterable[Value | Addr] =
val captured = freeVals.flatMap: x =>
val resOpt = Env.get(x)(using env)
resOpt.map(_ :: Nil).getOrElse(Nil)

captured ++ Vector(thisV)

/**
* Represents a set of values
Expand Down Expand Up @@ -275,6 +303,8 @@ class Objects(using Context @constructorOnly):

def currentObject(using data: Data): ClassSymbol = data.checkingObjects.last.klass

def currentObjectRef(using data: Data): ObjectRef = data.checkingObjects.last

private def doCheckObject(classSym: ClassSymbol)(using ctx: Context, data: Data) =
val tpl = classSym.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template]

Expand Down Expand Up @@ -427,9 +457,13 @@ class Objects(using Context @constructorOnly):
report.warning("[Internal error] Value not found " + x.show + "\nenv = " + data.show + ". " + Trace.show, Trace.position)
Bottom

def getVal(x: Symbol)(using data: Data, ctx: Context): Option[Value] = data.getVal(x)
def getVal(x: Symbol)(using data: Data): Option[Value] = data.getVal(x)

def getVar(x: Symbol)(using data: Data): Option[Heap.Addr] = data.getVar(x)

def getVar(x: Symbol)(using data: Data, ctx: Context): Option[Heap.Addr] = data.getVar(x)
def get(x: Symbol)(using data: Data): Option[Heap.Addr | Value] =
if x.is(Flags.Mutable) then data.getVar(x)
else data.getVal(x)

def of(ddef: DefDef, args: List[Value], outer: Data)(using Context): Data =
val params = ddef.termParamss.flatten.map(_.symbol)
Expand Down Expand Up @@ -512,7 +546,7 @@ class Objects(using Context @constructorOnly):
*
* TODO: speed up equality check for heap.
*/
opaque type Data = Map[Addr, Value]
type Data = Map[Addr, Value]

/** Store the heap as a mutable field to avoid threading it through the program. */
class MutableData(private[Heap] var heap: Data, private[Heap] var changeSet: Set[Addr]):
Expand Down Expand Up @@ -556,6 +590,43 @@ class Objects(using Context @constructorOnly):
mutable.heap = heap
mutable.changeSet = changeSet

/** Compute the footprint of the heap for evaluating an expression
*
* The regions of the heap not in the footprint do not matter for
* evaluating the underlying expression.
*
* The reasoning above is similar to the frame rule in separation logic.
*/
def footprint(heap: Data, thisV: Value, env: Env.Data, currentObj: ObjectRef): Data =
val toVisit = mutable.Queue.empty[Value]
val visited = mutable.Set.empty[Value]
val reachalbeKeys = mutable.Set.empty[Addr]

def visit(item: Value | Addr): Unit =
item match
case addr: Addr =>
reachalbeKeys += addr
val value = heap(addr)
if !visited.contains(value) then
toVisit += value

case value: Value =>
recur(value)

def recur(value: Value): Unit =
if !visited.contains(value) then
visited += value
for item <- value.flatten do visit(item)

toVisit += currentObj
toVisit += thisV
for item <- env.flatten do visit(item)

while toVisit.nonEmpty do
visit(toVisit.dequeue())

heap.filter((k, v) => reachalbeKeys.contains(k))

/** Perform garbage collection on the abstract heap.
*
* A heap address created after evaluating an expression can be reclaimed
Expand Down Expand Up @@ -602,30 +673,33 @@ class Objects(using Context @constructorOnly):
/** Cache used to terminate the check */
object Cache:
case class Config(thisV: Value, env: Env.Data, heap: Heap.Data)
case class Res(value: Value, heap: Heap.Data)
case class Res(value: Value, heap: Heap.Data, changeSet: Set[Addr])

class Data extends Cache[Config, Res]:
def get(thisV: Value, expr: Tree)(using Heap.MutableData, Env.Data): Option[Value] =
val config = Config(thisV, summon[Env.Data], Heap.getHeapData())
super.get(config, expr).map(_.value)

def cachedEval(thisV: ThisValue, expr: Tree, cacheResult: Boolean)(fun: Tree => Value)(using Heap.MutableData, Env.Data): Value =
def cachedEval(thisV: ThisValue, expr: Tree, cacheResult: Boolean)(fun: Tree => Value)(using Heap.MutableData, Env.Data, State.Data): Value =
val env = summon[Env.Data]
val config = Config(thisV, env, Heap.getHeapData())
val footprint = Heap.footprint(Heap.getHeapData(), thisV, env, State.currentObjectRef)
val config = Config(thisV, env, footprint)
val heapBefore = Heap.getHeapData()
val changeSetBefore = Heap.getChangeSet()
val result = super.cachedEval(config, expr, cacheResult, default = Res(Bottom, heapBefore)) { expr =>
Heap.update(heapBefore, changeSet = Set.empty)

Heap.update(footprint, changeSet = Set.empty)
val result = super.cachedEval(config, expr, cacheResult, default = Res(Bottom, footprint, Set.empty)) { expr =>
val value = fun(expr)
val heapAfter = Heap.getHeapData()
val changeSetNew = Heap.getChangeSet()
// Perform garbage collection
val heapGC =
if cacheResult then Heap.gc(value, heapBefore, heapAfter, changeSetNew)
if cacheResult then Heap.gc(value, footprint, heapAfter, changeSetNew)
else heapAfter
Heap.update(heapGC, changeSetNew ++ changeSetBefore)
Res(value, heapAfter)
Res(value, heapAfter, changeSetNew)
}
Heap.update(heapBefore ++ result.heap, changeSetBefore ++ result.changeSet)

result.value
end Cache

Expand Down Expand Up @@ -744,7 +818,7 @@ class Objects(using Context @constructorOnly):
* @param superType The type of the super in a super call. NoType for non-super calls.
* @param needResolve Whether the target of the call needs resolution?
*/
def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = log("call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) {
def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = log("call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show) + ", heap size = " + Heap.getHeapData().size, printer, (_: Value).show) {
value.filterClass(meth.owner) match
case Cold =>
report.warning("Using cold alias. " + Trace.show, Trace.position)
Expand Down
5 changes: 0 additions & 5 deletions compiler/src/dotty/tools/dotc/transform/init/Util.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,3 @@ object Util:

// A concrete class may not be instantiated if the self type is not satisfied
instantiable && cls.enclosingPackageClass != defn.StdLibPatchesPackage.moduleClass

/** Whether the class or its super class/trait contains any mutable fields? */
def isMutable(cls: ClassSymbol)(using Context): Boolean =
cls.classInfo.decls.exists(_.is(Flags.Mutable)) ||
cls.parentSyms.exists(parentCls => isMutable(parentCls.asClass))

0 comments on commit 0db9c0a

Please sign in to comment.