diff --git a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala index 691e67e228ef..705ac6307dc8 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala @@ -327,7 +327,7 @@ class Objects(using Context @constructorOnly): object Env: abstract class Data: private[Env] def getVal(x: Symbol)(using Context): Option[Value] - private[Env] def getVar(x: Symbol)(using Context): Option[Heap.Addr] + private[Env] def getVar(x: Symbol)(using Context): Option[Value] def widen(height: Int)(using Context): Data @@ -335,13 +335,15 @@ class Objects(using Context @constructorOnly): def show(using Context): String + def owner: ClassSymbol + /** Local environments can be deeply nested, therefore we need `outer`. * * For local variables in rhs of class field definitions, the `meth` is the primary constructor. */ private case class LocalEnv - (private[Env] val params: Map[Symbol, Value], meth: Symbol, outer: Data) - (valsMap: mutable.Map[Symbol, Value], varsMap: mutable.Map[Symbol, Heap.Addr]) + (private[Env] val params: Map[Symbol, Value], meth: Symbol, outer: Data, owner: ClassSymbol) + (valsMap: mutable.Map[Symbol, Value], varsMap: mutable.Map[Symbol, Value]) (using Context) extends Data: val level = outer.level + 1 @@ -350,17 +352,24 @@ class Objects(using Context @constructorOnly): report.warning("[Internal error] Deeply nested environment, level = " + level + ", " + meth.show + " in " + meth.enclosingClass.show, meth.defTree) private[Env] val vals: mutable.Map[Symbol, Value] = valsMap - private[Env] val vars: mutable.Map[Symbol, Heap.Addr] = varsMap + private[Env] val vars: mutable.Map[Symbol, Value] = varsMap private[Env] def getVal(x: Symbol)(using Context): Option[Value] = if x.is(Flags.Param) then params.get(x) else vals.get(x) - private[Env] def getVar(x: Symbol)(using Context): Option[Heap.Addr] = + private[Env] def getVar(x: Symbol)(using Context): Option[Value] = vars.get(x) + private[Env] def writeJoin(x: Symbol, value: Value)(using Context): Unit = + assert(vars.contains(x), "Variable not found " + x.show) + val current = vars.get(x).get + val value2 = value.join(current) + if value2 != current then + vars.update(x, value2) + def widen(height: Int)(using Context): Data = - new LocalEnv(params.map(_ -> _.widen(height)), meth, outer.widen(height))(this.vals, this.vars) + new LocalEnv(params.map(_ -> _.widen(height)), meth, outer.widen(height), owner)(this.vals, this.vars) def show(using Context) = "owner: " + meth.show + "\n" + @@ -377,12 +386,15 @@ class Objects(using Context @constructorOnly): private[Env] def getVal(x: Symbol)(using Context): Option[Value] = throw new RuntimeException("Invalid usage of non-existent env") - private[Env] def getVar(x: Symbol)(using Context): Option[Heap.Addr] = + private[Env] def getVar(x: Symbol)(using Context): Option[Value] = throw new RuntimeException("Invalid usage of non-existent env") def widen(height: Int)(using Context): Data = this def show(using Context): String = "NoEnv" + + def owner: ClassSymbol = + throw new RuntimeException("Invalid usage of non-existent env") end NoEnv /** An empty environment can be used for non-method environments, e.g., field initializers. @@ -390,8 +402,8 @@ class Objects(using Context @constructorOnly): * The owner for the local environment for field initializers is the primary constructor of the * enclosing class. */ - def emptyEnv(meth: Symbol)(using Context): Data = - new LocalEnv(Map.empty, meth, NoEnv)(valsMap = mutable.Map.empty, varsMap = mutable.Map.empty) + def emptyEnv(meth: Symbol)(using Context, State.Data): Data = + new LocalEnv(Map.empty, meth, NoEnv, State.currentObject)(valsMap = mutable.Map.empty, varsMap = mutable.Map.empty) def valValue(x: Symbol)(using data: Data, ctx: Context, trace: Trace): Value = data.getVal(x) match @@ -403,13 +415,13 @@ class Objects(using Context @constructorOnly): def getVal(x: Symbol)(using data: Data, ctx: Context): Option[Value] = data.getVal(x) - def getVar(x: Symbol)(using data: Data, ctx: Context): Option[Heap.Addr] = data.getVar(x) + def getVar(x: Symbol)(using data: Data, ctx: Context): Option[Value] = data.getVar(x) - def of(ddef: DefDef, args: List[Value], outer: Data)(using Context): Data = + def of(ddef: DefDef, args: List[Value], outer: Data)(using Context, State.Data): Data = val params = ddef.termParamss.flatten.map(_.symbol) assert(args.size == params.size, "arguments = " + args.size + ", params = " + params.size) assert(ddef.symbol.owner.isClass ^ (outer != NoEnv), "ddef.owner = " + ddef.symbol.owner.show + ", outer = " + outer + ", " + ddef.source) - new LocalEnv(params.zip(args).toMap, ddef.symbol, outer)(valsMap = mutable.Map.empty, varsMap = mutable.Map.empty) + new LocalEnv(params.zip(args).toMap, ddef.symbol, outer, State.currentObject)(valsMap = mutable.Map.empty, varsMap = mutable.Map.empty) def setLocalVal(x: Symbol, value: Value)(using data: Data, ctx: Context): Unit = assert(!x.isOneOf(Flags.Param | Flags.Mutable), "Only local immutable variable allowed") @@ -420,12 +432,20 @@ class Objects(using Context @constructorOnly): case _ => throw new RuntimeException("Incorrect local environment for initializing " + x.show) - def setLocalVar(x: Symbol, addr: Heap.Addr)(using data: Data, ctx: Context): Unit = + def setLocalVar(x: Symbol, value: Value)(using data: Data, ctx: Context): Unit = assert(x.is(Flags.Mutable, butNot = Flags.Param), "Only local mutable variable allowed") data match case localEnv: LocalEnv => assert(!localEnv.vars.contains(x), "Already initialized local " + x.show) - localEnv.vars(x) = addr + localEnv.vars(x) = value + case _ => + throw new RuntimeException("Incorrect local environment for initializing " + x.show) + + def assignLocalVar(x: Symbol, value: Value)(using data: Data, ctx: Context): Unit = + assert(x.is(Flags.Mutable, butNot = Flags.Param), "Only local mutable variable allowed") + data match + case localEnv: LocalEnv => + localEnv.writeJoin(x, value) case _ => throw new RuntimeException("Incorrect local environment for initializing " + x.show) @@ -672,12 +692,12 @@ class Objects(using Context @constructorOnly): if arr.addr.owner == State.currentObject then Heap.read(arr.addr) else - errorReadOtherStaticObject(State.currentObject, arr.addr) + errorReadOtherStaticObject(State.currentObject, arr.addr.owner, arr.addr.getTrace) Bottom else if target == defn.Array_update then assert(args.size == 2, "Incorrect number of arguments for Array update, found = " + args.size) if arr.addr.owner != State.currentObject then - errorMutateOtherStaticObject(State.currentObject, arr.addr) + errorMutateOtherStaticObject(State.currentObject, arr.addr.owner, arr.addr.getTrace) else Heap.writeJoin(arr.addr, args.tail.head.value) Bottom @@ -832,7 +852,7 @@ class Objects(using Context @constructorOnly): if addr.owner == State.currentObject then Heap.read(addr) else - errorReadOtherStaticObject(State.currentObject, addr) + errorReadOtherStaticObject(State.currentObject, addr.owner, addr.getTrace) Bottom else if ref.isObjectRef && ref.klass.hasSource then report.warning("Access uninitialized field " + field.show + ". " + Trace.show, Trace.position) @@ -901,7 +921,7 @@ class Objects(using Context @constructorOnly): if ref.hasVar(field) then val addr = ref.varAddr(field) if addr.owner != State.currentObject then - errorMutateOtherStaticObject(State.currentObject, addr) + errorMutateOtherStaticObject(State.currentObject, addr.owner, addr.getTrace) else Heap.writeJoin(addr, rhs) else @@ -968,9 +988,7 @@ class Objects(using Context @constructorOnly): */ def initLocal(sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) { if sym.is(Flags.Mutable) then - val addr = Heap.localVarAddr(summon[Regions.Data], sym, State.currentObject) - Env.setLocalVar(sym, addr) - Heap.writeJoin(addr, value) + Env.setLocalVar(sym, value) else Env.setLocalVal(sym, value) } @@ -988,11 +1006,11 @@ class Objects(using Context @constructorOnly): // Assume forward reference check is doing a good job given Env.Data = env Env.getVar(sym) match - case Some(addr) => - if addr.owner == State.currentObject then - Heap.read(addr) + case Some(value) => + if env.owner == State.currentObject then + value else - errorReadOtherStaticObject(State.currentObject, addr) + errorReadOtherStaticObject(State.currentObject, env.owner, Trace.empty) Bottom end if case _ => @@ -1042,11 +1060,11 @@ class Objects(using Context @constructorOnly): case Some(thisV -> env) => given Env.Data = env Env.getVar(sym) match - case Some(addr) => - if addr.owner != State.currentObject then - errorMutateOtherStaticObject(State.currentObject, addr) + case Some(value) => + if env.owner != State.currentObject then + errorMutateOtherStaticObject(State.currentObject, env.owner, Trace.empty) else - Heap.writeJoin(addr, value) + Env.assignLocalVar(sym, value) case _ => report.warning("[Internal error] Variable not found " + sym.show + "\nenv = " + env.show + ". " + Trace.show, Trace.position) @@ -1802,25 +1820,21 @@ class Objects(using Context @constructorOnly): else "" val mutateErrorSet: mutable.Set[(ClassSymbol, ClassSymbol)] = mutable.Set.empty - def errorMutateOtherStaticObject(currentObj: ClassSymbol, addr: Heap.Addr)(using Trace, Context) = - val otherObj = addr.owner - val addr_trace = addr.getTrace + def errorMutateOtherStaticObject(currentObj: ClassSymbol, otherObj: ClassSymbol, trace: Trace)(using Trace, Context) = if mutateErrorSet.add((currentObj, otherObj)) then val msg = s"Mutating ${otherObj.show} during initialization of ${currentObj.show}.\n" + "Mutating other static objects during the initialization of one static object is forbidden. " + Trace.show + - printTraceWhenMultiple(addr_trace) + printTraceWhenMultiple(trace) report.warning(msg, Trace.position) val readErrorSet: mutable.Set[(ClassSymbol, ClassSymbol)] = mutable.Set.empty - def errorReadOtherStaticObject(currentObj: ClassSymbol, addr: Heap.Addr)(using Trace, Context) = - val otherObj = addr.owner - val addr_trace = addr.getTrace + def errorReadOtherStaticObject(currentObj: ClassSymbol, otherObj: ClassSymbol, trace: Trace)(using Trace, Context) = if readErrorSet.add((currentObj, otherObj)) then val msg = "Reading mutable state of " + otherObj.show + " during initialization of " + currentObj.show + ".\n" + "Reading mutable state of other static objects is forbidden as it breaks initialization-time irrelevance. " + Trace.show + - printTraceWhenMultiple(addr_trace) + printTraceWhenMultiple(trace) report.warning(msg, Trace.position)