Skip to content

Commit

Permalink
Add footprint optimization based on separation logic
Browse files Browse the repository at this point in the history
  • Loading branch information
liufengyun committed Aug 23, 2024
1 parent b4e9f8f commit cce828b
Showing 1 changed file with 55 additions and 11 deletions.
66 changes: 55 additions & 11 deletions compiler/src/dotty/tools/dotc/transform/init/Objects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ class Objects(using Context @constructorOnly):
case class Fun(code: Tree, thisV: ThisValue, klass: ClassSymbol, env: Env.Data) extends ValueElement:
def show(using Context) = "Fun(" + code.show + ", " + thisV.show + ", " + klass.show + ")"

def flatten: Iterable[Value | Addr] = env.flatten ++ Vector(thisV)
def flatten: Iterable[Value | Addr] =
// TODO: compute free local variables
Vector(thisV)

/**
* Represents a set of values
Expand Down Expand Up @@ -275,6 +277,8 @@ class Objects(using Context @constructorOnly):

def currentObject(using data: Data): ClassSymbol = data.checkingObjects.last.klass

def currentObjectRef(using data: Data): ObjectRef = data.checkingObjects.last

private def doCheckObject(classSym: ClassSymbol)(using ctx: Context, data: Data) =
val tpl = classSym.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template]

Expand Down Expand Up @@ -512,7 +516,7 @@ class Objects(using Context @constructorOnly):
*
* TODO: speed up equality check for heap.
*/
opaque type Data = Map[Addr, Value]
type Data = Map[Addr, Value]

/** Store the heap as a mutable field to avoid threading it through the program. */
class MutableData(private[Heap] var heap: Data, private[Heap] var changeSet: Set[Addr]):
Expand Down Expand Up @@ -556,6 +560,43 @@ class Objects(using Context @constructorOnly):
mutable.heap = heap
mutable.changeSet = changeSet

/** Compute the footprint of the heap for evaluating an expression
*
* The regions of the heap not in the footprint do not matter for
* evaluating the underlying expression.
*
* The reasoning above is similar to the frame rule in separation logic.
*/
def footprint(heap: Data, thisV: Value, env: Env.Data, currentObj: ObjectRef): Data =
val toVisit = mutable.Queue.empty[Value]
val visited = mutable.Set.empty[Value]
val reachalbeKeys = mutable.Set.empty[Addr]

def visit(item: Value | Addr): Unit =
item match
case addr: Addr =>
reachalbeKeys += addr
val value = heap(addr)
if !visited.contains(value) then
toVisit += value

case value: Value =>
recur(value)

def recur(value: Value): Unit =
if !visited.contains(value) then
visited += value
for item <- value.flatten do visit(item)

toVisit += currentObj
toVisit += thisV
for item <- env.flatten do visit(item)

while toVisit.nonEmpty do
visit(toVisit.dequeue())

heap.filter((k, v) => reachalbeKeys.contains(k))

/** Perform garbage collection on the abstract heap.
*
* A heap address created after evaluating an expression can be reclaimed
Expand Down Expand Up @@ -602,30 +643,33 @@ class Objects(using Context @constructorOnly):
/** Cache used to terminate the check */
object Cache:
case class Config(thisV: Value, env: Env.Data, heap: Heap.Data)
case class Res(value: Value, heap: Heap.Data)
case class Res(value: Value, heap: Heap.Data, changeSet: Set[Addr])

class Data extends Cache[Config, Res]:
def get(thisV: Value, expr: Tree)(using Heap.MutableData, Env.Data): Option[Value] =
val config = Config(thisV, summon[Env.Data], Heap.getHeapData())
super.get(config, expr).map(_.value)

def cachedEval(thisV: ThisValue, expr: Tree, cacheResult: Boolean)(fun: Tree => Value)(using Heap.MutableData, Env.Data): Value =
def cachedEval(thisV: ThisValue, expr: Tree, cacheResult: Boolean)(fun: Tree => Value)(using Heap.MutableData, Env.Data, State.Data): Value =
val env = summon[Env.Data]
val config = Config(thisV, env, Heap.getHeapData())
val footprint = Heap.footprint(Heap.getHeapData(), thisV, env, State.currentObjectRef)
val config = Config(thisV, env, footprint)
val heapBefore = Heap.getHeapData()
val changeSetBefore = Heap.getChangeSet()
val result = super.cachedEval(config, expr, cacheResult, default = Res(Bottom, heapBefore)) { expr =>
Heap.update(heapBefore, changeSet = Set.empty)

Heap.update(footprint, changeSet = Set.empty)
val result = super.cachedEval(config, expr, cacheResult, default = Res(Bottom, footprint, Set.empty)) { expr =>
val value = fun(expr)
val heapAfter = Heap.getHeapData()
val changeSetNew = Heap.getChangeSet()
// Perform garbage collection
val heapGC =
if cacheResult then Heap.gc(value, heapBefore, heapAfter, changeSetNew)
if cacheResult then Heap.gc(value, footprint, heapAfter, changeSetNew)
else heapAfter
Heap.update(heapGC, changeSetNew ++ changeSetBefore)
Res(value, heapAfter)
Res(value, heapAfter, changeSetNew)
}
Heap.update(heapBefore ++ result.heap, changeSetBefore ++ result.changeSet)

result.value
end Cache

Expand Down Expand Up @@ -744,7 +788,7 @@ class Objects(using Context @constructorOnly):
* @param superType The type of the super in a super call. NoType for non-super calls.
* @param needResolve Whether the target of the call needs resolution?
*/
def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = log("call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) {
def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = log("call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show) + ", heap size = " + Heap.getHeapData().size, printer, (_: Value).show) {
value.filterClass(meth.owner) match
case Cold =>
report.warning("Using cold alias. " + Trace.show, Trace.position)
Expand Down

0 comments on commit cce828b

Please sign in to comment.