Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -253,26 +253,31 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {

def DefDef(sym: TermSymbol, rhs: Tree = EmptyTree)(using Context): DefDef =
ta.assignType(DefDef(sym, Function.const(rhs)), sym)

def DefDef(sym: TermSymbol, rhsFn: List[List[Tree]] => Tree)(using Context): DefDef =
DefDef(sym, null, rhsFn)

/** A DefDef with given method symbol `sym`.
* @annotationsFn A function (possibly null) from parameters to annotations
* @rhsFn A function from parameter references
* to the method's right-hand side.
* Parameter symbols are taken from the `rawParamss` field of `sym`, or
* are freshly generated if `rawParamss` is empty.
* When freshly generated, a non-null annotationsFn is used to add annotations
*/
def DefDef(sym: TermSymbol, rhsFn: List[List[Tree]] => Tree)(using Context): DefDef =

def DefDef(sym: TermSymbol, annotationsFn: (TermName => List[Annotation]) | Null, rhsFn: List[List[Tree]] => Tree)(using Context): DefDef =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would avoid null at an (internal) API boundary.

// Map method type `tp` with remaining parameters stored in rawParamss to
// final result type and all (given or synthesized) parameters
def recur(tp: Type, remaining: List[List[Symbol]]): (Type, List[List[Symbol]]) = tp match
def recur(tp: Type, afn: (TermName => List[Annotation]) | Null,
remaining: List[List[Symbol]]): (Type, List[List[Symbol]]) = tp match
case tp: PolyType =>
val (tparams: List[TypeSymbol], remaining1) = remaining match
val (tparams: List[TypeSymbol], afn1, remaining1) = remaining match
case tparams :: remaining1 =>
assert(tparams.hasSameLengthAs(tp.paramNames) && tparams.head.isType)
(tparams.asInstanceOf[List[TypeSymbol]], remaining1)
(tparams.asInstanceOf[List[TypeSymbol]], null, remaining1)
case nil =>
(newTypeParams(sym, tp.paramNames, EmptyFlags, tp.instantiateParamInfos(_)), Nil)
val (rtp, paramss) = recur(tp.instantiate(tparams.map(_.typeRef)), remaining1)
(newTypeParams(sym, tp.paramNames, EmptyFlags, tp.instantiateParamInfos(_)), null, Nil)
val (rtp, paramss) = recur(tp.instantiate(tparams.map(_.typeRef)), afn1, remaining1)
(rtp, tparams :: paramss)
case tp: MethodType =>
val isParamDependent = tp.isParamDependent
Expand All @@ -297,22 +302,27 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
else makeSym(origInfo)
end valueParam

val (vparams: List[TermSymbol], remaining1) =
if tp.paramNames.isEmpty then (Nil, remaining)
val (vparams: List[TermSymbol], afn1, remaining1) =
if tp.paramNames.isEmpty then (Nil, null, remaining)
else remaining match
case vparams :: remaining1 =>
assert(vparams.hasSameLengthAs(tp.paramNames) && vparams.head.isTerm)
(vparams.asInstanceOf[List[TermSymbol]], remaining1)
(vparams.asInstanceOf[List[TermSymbol]], null, remaining1)
case nil =>
(tp.paramNames.lazyZip(tp.paramInfos).lazyZip(tp.paramErasureStatuses).map(valueParam), Nil)
val (rtp, paramss) = recur(tp.instantiate(vparams.map(_.termRef)), remaining1)
val res = tp.paramNames.lazyZip(tp.paramInfos).lazyZip(tp.paramErasureStatuses).map(valueParam)
if afn != null then
res.lazyZip(tp.paramNames.map(afn)).foreach: (s, annots) =>
s.addAnnotations(annots)
(res, null, Nil)
val (rtp, paramss) = recur(tp.instantiate(vparams.map(_.termRef)), afn1, remaining1)
(rtp, vparams :: paramss)
case _ =>
assert(remaining.isEmpty)
(tp.widenExpr, Nil)
end recur

val (rtp, paramss) = recur(sym.info, sym.rawParamss)
assert(sym.rawParamss.isEmpty || annotationsFn == null)
val (rtp, paramss) = recur(sym.info, annotationsFn, sym.rawParamss)
DefDef(sym, paramss, rtp, rhsFn(paramss.nestedMap(ref)))
end DefDef

Expand Down
12 changes: 11 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/Mixin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import StdNames.*
import Names.*
import NameKinds.*
import NameOps.*
import Annotations.*
import Phases.erasurePhase
import ast.Trees.*

Expand Down Expand Up @@ -315,7 +316,16 @@ class Mixin extends MiniPhase with SymTransformer { thisPhase =>
for (meth <- mixin.info.decls.toList if needsMixinForwarder(meth))
yield {
util.Stats.record("mixin forwarders")
transformFollowing(DefDef(mkMixinForwarderSym(meth.asTerm), forwarderRhsFn(meth)))
//Whereas method annotations are set directly in mkMixForwarderSym from the meth.sym
//in the case of parameter annotations, we need to store these off here to pass into the DefDef
//to use during valueParm creation since the mkMixinForwarderSym result symbol denotation no longer
//have the parameter annotations.
val annotationss: Map[TermName, List[Annotation]] = (for
rawParams <- meth.asTerm.rawParamss
if rawParams.nonEmpty && !rawParams.head.isType
p <- rawParams
yield (p.termRef.name, p.annotations)).toMap.withDefaultValue(Nil)
transformFollowing(DefDef(mkMixinForwarderSym(meth.asTerm), annotationss, forwarderRhsFn(meth)))
}

def mkMixinForwarderSym(target: TermSymbol): TermSymbol =
Expand Down
9 changes: 9 additions & 0 deletions tests/run/i22991/Bar.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class Blah extends scala.annotation.StaticAnnotation

trait Barly {
def bar[T](a: String, @Foo v: Int)(@Foo b: T, @Blah w: Int) = ()
}

class Bar extends Barly{
def bar2(@Foo v: Int) = ()
}
6 changes: 6 additions & 0 deletions tests/run/i22991/Foo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import java.lang.annotation.*;

@Retention(RetentionPolicy.RUNTIME)
public @interface Foo {
}

17 changes: 17 additions & 0 deletions tests/run/i22991/Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@

//Test java runtime reflection access to @Runtime annotations on method parameters.
object Test extends App:
val method: java.lang.reflect.Method = classOf[Bar].getMethod("bar", classOf[String], classOf[Int], classOf[Object], classOf[Int])
val annots: Array[Array[java.lang.annotation.Annotation]] = method.getParameterAnnotations()
assert(annots.length == 4)
assert(annots(0).length == 0)
assert(annots(1).length == 1)
assert(annots(1)(0).isInstanceOf[Foo])
assert(annots(2).length == 1)
assert(annots(2)(0).isInstanceOf[Foo])
assert(annots(3).length == 0)

val method2: java.lang.reflect.Method = classOf[Bar].getMethod("bar2", classOf[Int])
val annots2: Array[Array[java.lang.annotation.Annotation]] = method2.getParameterAnnotations()
assert(annots2.length == 1)
assert(annots2(0)(0).isInstanceOf[Foo])
Loading