Skip to content

Commit

Permalink
Make sure symbols in annotation trees are fresh before pickling (#22002)
Browse files Browse the repository at this point in the history
In a nutshell: when mapping annotated types, we can currently end up
with the same symbol being declared in distinct trees, which crashes the
pickler as it expects each symbol to be declared in a single place. See
#19957 (comment) and
#19957 (comment) for
more context.

This PR ensures that all symbols in annotation trees are different by
creating fresh symbols for all symbols in annotation tree during
`PostTyper`.

In my [previous
attempt](ab70f18)
which was discussed on #19957, I did it in `Annotations.mapWith`. Here,
it's only done once in `PostTyper`, so this is more lightweight.

Fixes #17939, fixes #19846 and fixes (partially?) #20272.
  • Loading branch information
odersky authored Dec 17, 2024
2 parents e52aea4 + ca3c797 commit ae80285
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 10 deletions.
42 changes: 32 additions & 10 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package dotty.tools
package dotc
package transform

import dotty.tools.dotc.ast.{Trees, tpd, untpd, desugar}
import dotty.tools.dotc.ast.{Trees, tpd, untpd, desugar, TreeTypeMap}
import scala.collection.mutable
import core.*
import dotty.tools.dotc.typer.Checking
Expand All @@ -16,7 +16,7 @@ import Symbols.*, NameOps.*
import ContextFunctionResults.annotateContextResults
import config.Printers.typr
import config.Feature
import util.SrcPos
import util.{SrcPos, Stats}
import reporting.*
import NameKinds.WildcardParamName
import cc.*
Expand Down Expand Up @@ -154,7 +154,21 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
case _ =>
case _ =>

private def transformAnnot(annot: Tree)(using Context): Tree = {
/** Returns a copy of the given tree with all symbols fresh.
*
* Used to guarantee that no symbols are shared between trees in different
* annotations.
*/
private def copySymbols(tree: Tree)(using Context) =
Stats.trackTime("Annotations copySymbols"):
val ttm =
new TreeTypeMap:
override def withMappedSyms(syms: List[Symbol]) =
withMappedSyms(syms, mapSymbols(syms, this, true))
ttm(tree)

/** Transforms the given annotation tree. */
private def transformAnnotTree(annot: Tree)(using Context): Tree = {
val saved = inJavaAnnot
inJavaAnnot = annot.symbol.is(JavaDefined)
if (inJavaAnnot) checkValidJavaAnnotation(annot)
Expand All @@ -163,7 +177,19 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
}

private def transformAnnot(annot: Annotation)(using Context): Annotation =
annot.derivedAnnotation(transformAnnot(annot.tree))
val tree1 =
annot match
case _: BodyAnnotation => annot.tree
case _ => copySymbols(annot.tree)
annot.derivedAnnotation(transformAnnotTree(tree1))

/** Transforms all annotations in the given type. */
private def transformAnnotsIn(using Context) =
new TypeMap:
def apply(tp: Type) = tp match
case tp @ AnnotatedType(parent, annot) =>
tp.derivedAnnotatedType(mapOver(parent), transformAnnot(annot))
case _ => mapOver(tp)

private def processMemberDef(tree: Tree)(using Context): tree.type = {
val sym = tree.symbol
Expand Down Expand Up @@ -501,7 +527,7 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
Checking.checkRealizable(tree.tpt.tpe, tree.srcPos, "SAM type")
super.transform(tree)
case tree @ Annotated(annotated, annot) =>
cpy.Annotated(tree)(transform(annotated), transformAnnot(annot))
cpy.Annotated(tree)(transform(annotated), transformAnnotTree(annot))
case tree: AppliedTypeTree =>
if (tree.tpt.symbol == defn.andType)
Checking.checkNonCyclicInherited(tree.tpe, tree.args.tpes, EmptyScope, tree.srcPos)
Expand All @@ -524,11 +550,7 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
super.transform(tree)
case tree: TypeTree =>
val tpe = if tree.isInferred then CleanupRetains()(tree.tpe) else tree.tpe
tree.withType:
tpe match
case AnnotatedType(parent, annot) =>
AnnotatedType(parent, transformAnnot(annot)) // TODO: Also map annotations embedded in type?
case _ => tpe
tree.withType(transformAnnotsIn(tpe))
case Typed(Ident(nme.WILDCARD), _) =>
withMode(Mode.Pattern)(super.transform(tree))
// The added mode signals that bounds in a pattern need not
Expand Down
8 changes: 8 additions & 0 deletions tests/pos/annot-17939.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import scala.annotation.Annotation
class myRefined[T](f: T => Boolean) extends Annotation

class Box[T](val x: T)
class Box2(val x: Int)

class A(a: String @myRefined((x: Int) => Box(3).x == 3)) // crash
class A2(a2: String @myRefined((x: Int) => Box2(3).x == 3)) // works
9 changes: 9 additions & 0 deletions tests/pos/annot-19846.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package dependentAnnotation

class lambdaAnnot(g: () => Int) extends annotation.StaticAnnotation

def f(x: Int): Int @lambdaAnnot(() => x + 1) = x

@main def main =
val y: Int = 5
val z = f(y)
8 changes: 8 additions & 0 deletions tests/pos/annot-19846b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class qualified[T](predicate: T => Boolean) extends annotation.StaticAnnotation

class EqualPair(val x: Int, val y: Int @qualified[Int](it => it == x))

@main def main =
val p = EqualPair(42, 42)
val y = p.y
println(42)
15 changes: 15 additions & 0 deletions tests/pos/annot-body.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// This test checks that symbols in `BodyAnnotation` are not copied in
// `transformAnnot` during `PostTyper`.

package json

trait Reads[A] {
def reads(a: Any): A
}

object JsMacroImpl {
inline def reads[A]: Reads[A] =
new Reads[A] { self =>
def reads(a: Any) = ???
}
}
20 changes: 20 additions & 0 deletions tests/pos/annot-i20272a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import language.experimental.captureChecking

trait Iterable[T] { self: Iterable[T]^ =>
def map[U](f: T => U): Iterable[U]^{this, f}
}

object Test {
def assertEquals[A, B](a: A, b: B): Boolean = ???

def foo[T](level: Int, lines: Iterable[T]) =
lines.map(x => x)

def bar(messages: Iterable[String]) =
foo(1, messages)

val it: Iterable[String] = ???
val msgs = bar(it)

assertEquals(msgs, msgs)
}

0 comments on commit ae80285

Please sign in to comment.