Skip to content

Commit

Permalink
Records local var value in env instead of heap
Browse files Browse the repository at this point in the history
  • Loading branch information
EnzeXing committed Sep 6, 2024
1 parent 67cd3eb commit f97db49
Showing 1 changed file with 51 additions and 37 deletions.
88 changes: 51 additions & 37 deletions compiler/src/dotty/tools/dotc/transform/init/Objects.scala
Original file line number Diff line number Diff line change
Expand Up @@ -327,21 +327,23 @@ class Objects(using Context @constructorOnly):
object Env:
abstract class Data:
private[Env] def getVal(x: Symbol)(using Context): Option[Value]
private[Env] def getVar(x: Symbol)(using Context): Option[Heap.Addr]
private[Env] def getVar(x: Symbol)(using Context): Option[Value]

def widen(height: Int)(using Context): Data

def level: Int

def show(using Context): String

def owner: ClassSymbol

/** Local environments can be deeply nested, therefore we need `outer`.
*
* For local variables in rhs of class field definitions, the `meth` is the primary constructor.
*/
private case class LocalEnv
(private[Env] val params: Map[Symbol, Value], meth: Symbol, outer: Data)
(valsMap: mutable.Map[Symbol, Value], varsMap: mutable.Map[Symbol, Heap.Addr])
(private[Env] val params: Map[Symbol, Value], meth: Symbol, outer: Data, owner: ClassSymbol)
(valsMap: mutable.Map[Symbol, Value], varsMap: mutable.Map[Symbol, Value])
(using Context)
extends Data:
val level = outer.level + 1
Expand All @@ -350,17 +352,24 @@ class Objects(using Context @constructorOnly):
report.warning("[Internal error] Deeply nested environment, level = " + level + ", " + meth.show + " in " + meth.enclosingClass.show, meth.defTree)

private[Env] val vals: mutable.Map[Symbol, Value] = valsMap
private[Env] val vars: mutable.Map[Symbol, Heap.Addr] = varsMap
private[Env] val vars: mutable.Map[Symbol, Value] = varsMap

private[Env] def getVal(x: Symbol)(using Context): Option[Value] =
if x.is(Flags.Param) then params.get(x)
else vals.get(x)

private[Env] def getVar(x: Symbol)(using Context): Option[Heap.Addr] =
private[Env] def getVar(x: Symbol)(using Context): Option[Value] =
vars.get(x)

private[Env] def writeJoin(x: Symbol, value: Value)(using Context): Unit =
assert(vars.contains(x), "Variable not found " + x.show)
val current = vars.get(x).get
val value2 = value.join(current)
if value2 != current then
vars.update(x, value2)

def widen(height: Int)(using Context): Data =
new LocalEnv(params.map(_ -> _.widen(height)), meth, outer.widen(height))(this.vals, this.vars)
new LocalEnv(params.map(_ -> _.widen(height)), meth, outer.widen(height), owner)(this.vals, this.vars)

def show(using Context) =
"owner: " + meth.show + "\n" +
Expand All @@ -377,21 +386,24 @@ class Objects(using Context @constructorOnly):
private[Env] def getVal(x: Symbol)(using Context): Option[Value] =
throw new RuntimeException("Invalid usage of non-existent env")

private[Env] def getVar(x: Symbol)(using Context): Option[Heap.Addr] =
private[Env] def getVar(x: Symbol)(using Context): Option[Value] =
throw new RuntimeException("Invalid usage of non-existent env")

def widen(height: Int)(using Context): Data = this

def show(using Context): String = "NoEnv"

def owner: ClassSymbol =
throw new RuntimeException("Invalid usage of non-existent env")
end NoEnv

/** An empty environment can be used for non-method environments, e.g., field initializers.
*
* The owner for the local environment for field initializers is the primary constructor of the
* enclosing class.
*/
def emptyEnv(meth: Symbol)(using Context): Data =
new LocalEnv(Map.empty, meth, NoEnv)(valsMap = mutable.Map.empty, varsMap = mutable.Map.empty)
def emptyEnv(meth: Symbol)(using Context, State.Data): Data =
new LocalEnv(Map.empty, meth, NoEnv, State.currentObject)(valsMap = mutable.Map.empty, varsMap = mutable.Map.empty)

def valValue(x: Symbol)(using data: Data, ctx: Context, trace: Trace): Value =
data.getVal(x) match
Expand All @@ -403,13 +415,13 @@ class Objects(using Context @constructorOnly):

def getVal(x: Symbol)(using data: Data, ctx: Context): Option[Value] = data.getVal(x)

def getVar(x: Symbol)(using data: Data, ctx: Context): Option[Heap.Addr] = data.getVar(x)
def getVar(x: Symbol)(using data: Data, ctx: Context): Option[Value] = data.getVar(x)

def of(ddef: DefDef, args: List[Value], outer: Data)(using Context): Data =
def of(ddef: DefDef, args: List[Value], outer: Data)(using Context, State.Data): Data =
val params = ddef.termParamss.flatten.map(_.symbol)
assert(args.size == params.size, "arguments = " + args.size + ", params = " + params.size)
assert(ddef.symbol.owner.isClass ^ (outer != NoEnv), "ddef.owner = " + ddef.symbol.owner.show + ", outer = " + outer + ", " + ddef.source)
new LocalEnv(params.zip(args).toMap, ddef.symbol, outer)(valsMap = mutable.Map.empty, varsMap = mutable.Map.empty)
new LocalEnv(params.zip(args).toMap, ddef.symbol, outer, State.currentObject)(valsMap = mutable.Map.empty, varsMap = mutable.Map.empty)

def setLocalVal(x: Symbol, value: Value)(using data: Data, ctx: Context): Unit =
assert(!x.isOneOf(Flags.Param | Flags.Mutable), "Only local immutable variable allowed")
Expand All @@ -420,12 +432,20 @@ class Objects(using Context @constructorOnly):
case _ =>
throw new RuntimeException("Incorrect local environment for initializing " + x.show)

def setLocalVar(x: Symbol, addr: Heap.Addr)(using data: Data, ctx: Context): Unit =
def setLocalVar(x: Symbol, value: Value)(using data: Data, ctx: Context): Unit =
assert(x.is(Flags.Mutable, butNot = Flags.Param), "Only local mutable variable allowed")
data match
case localEnv: LocalEnv =>
assert(!localEnv.vars.contains(x), "Already initialized local " + x.show)
localEnv.vars(x) = addr
localEnv.vars(x) = value
case _ =>
throw new RuntimeException("Incorrect local environment for initializing " + x.show)

def assignLocalVar(x: Symbol, value: Value)(using data: Data, ctx: Context): Unit =
assert(x.is(Flags.Mutable, butNot = Flags.Param), "Only local mutable variable allowed")
data match
case localEnv: LocalEnv =>
localEnv.writeJoin(x, value)
case _ =>
throw new RuntimeException("Incorrect local environment for initializing " + x.show)

Expand Down Expand Up @@ -672,12 +692,12 @@ class Objects(using Context @constructorOnly):
if arr.addr.owner == State.currentObject then
Heap.read(arr.addr)
else
errorReadOtherStaticObject(State.currentObject, arr.addr)
errorReadOtherStaticObject(State.currentObject, arr.addr.owner, arr.addr.getTrace)
Bottom
else if target == defn.Array_update then
assert(args.size == 2, "Incorrect number of arguments for Array update, found = " + args.size)
if arr.addr.owner != State.currentObject then
errorMutateOtherStaticObject(State.currentObject, arr.addr)
errorMutateOtherStaticObject(State.currentObject, arr.addr.owner, arr.addr.getTrace)
else
Heap.writeJoin(arr.addr, args.tail.head.value)
Bottom
Expand Down Expand Up @@ -832,7 +852,7 @@ class Objects(using Context @constructorOnly):
if addr.owner == State.currentObject then
Heap.read(addr)
else
errorReadOtherStaticObject(State.currentObject, addr)
errorReadOtherStaticObject(State.currentObject, addr.owner, addr.getTrace)
Bottom
else if ref.isObjectRef && ref.klass.hasSource then
report.warning("Access uninitialized field " + field.show + ". " + Trace.show, Trace.position)
Expand Down Expand Up @@ -901,7 +921,7 @@ class Objects(using Context @constructorOnly):
if ref.hasVar(field) then
val addr = ref.varAddr(field)
if addr.owner != State.currentObject then
errorMutateOtherStaticObject(State.currentObject, addr)
errorMutateOtherStaticObject(State.currentObject, addr.owner, addr.getTrace)
else
Heap.writeJoin(addr, rhs)
else
Expand Down Expand Up @@ -968,9 +988,7 @@ class Objects(using Context @constructorOnly):
*/
def initLocal(sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) {
if sym.is(Flags.Mutable) then
val addr = Heap.localVarAddr(summon[Regions.Data], sym, State.currentObject)
Env.setLocalVar(sym, addr)
Heap.writeJoin(addr, value)
Env.setLocalVar(sym, value)
else
Env.setLocalVal(sym, value)
}
Expand All @@ -988,11 +1006,11 @@ class Objects(using Context @constructorOnly):
// Assume forward reference check is doing a good job
given Env.Data = env
Env.getVar(sym) match
case Some(addr) =>
if addr.owner == State.currentObject then
Heap.read(addr)
case Some(value) =>
if env.owner == State.currentObject then
value
else
errorReadOtherStaticObject(State.currentObject, addr)
errorReadOtherStaticObject(State.currentObject, env.owner, Trace.empty)
Bottom
end if
case _ =>
Expand Down Expand Up @@ -1042,11 +1060,11 @@ class Objects(using Context @constructorOnly):
case Some(thisV -> env) =>
given Env.Data = env
Env.getVar(sym) match
case Some(addr) =>
if addr.owner != State.currentObject then
errorMutateOtherStaticObject(State.currentObject, addr)
case Some(value) =>
if env.owner != State.currentObject then
errorMutateOtherStaticObject(State.currentObject, env.owner, Trace.empty)
else
Heap.writeJoin(addr, value)
Env.assignLocalVar(sym, value)
case _ =>
report.warning("[Internal error] Variable not found " + sym.show + "\nenv = " + env.show + ". " + Trace.show, Trace.position)

Expand Down Expand Up @@ -1802,25 +1820,21 @@ class Objects(using Context @constructorOnly):
else ""

val mutateErrorSet: mutable.Set[(ClassSymbol, ClassSymbol)] = mutable.Set.empty
def errorMutateOtherStaticObject(currentObj: ClassSymbol, addr: Heap.Addr)(using Trace, Context) =
val otherObj = addr.owner
val addr_trace = addr.getTrace
def errorMutateOtherStaticObject(currentObj: ClassSymbol, otherObj: ClassSymbol, trace: Trace)(using Trace, Context) =
if mutateErrorSet.add((currentObj, otherObj)) then
val msg =
s"Mutating ${otherObj.show} during initialization of ${currentObj.show}.\n" +
"Mutating other static objects during the initialization of one static object is forbidden. " + Trace.show +
printTraceWhenMultiple(addr_trace)
printTraceWhenMultiple(trace)

report.warning(msg, Trace.position)

val readErrorSet: mutable.Set[(ClassSymbol, ClassSymbol)] = mutable.Set.empty
def errorReadOtherStaticObject(currentObj: ClassSymbol, addr: Heap.Addr)(using Trace, Context) =
val otherObj = addr.owner
val addr_trace = addr.getTrace
def errorReadOtherStaticObject(currentObj: ClassSymbol, otherObj: ClassSymbol, trace: Trace)(using Trace, Context) =
if readErrorSet.add((currentObj, otherObj)) then
val msg =
"Reading mutable state of " + otherObj.show + " during initialization of " + currentObj.show + ".\n" +
"Reading mutable state of other static objects is forbidden as it breaks initialization-time irrelevance. " + Trace.show +
printTraceWhenMultiple(addr_trace)
printTraceWhenMultiple(trace)

report.warning(msg, Trace.position)

0 comments on commit f97db49

Please sign in to comment.