Skip to content

Commit

Permalink
WIP: copy phase and unroll annot
Browse files Browse the repository at this point in the history
  • Loading branch information
bishabosha committed Aug 21, 2024
1 parent b64afad commit 4c9c78c
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Compiler {
List(new InstrumentCoverage) :: // Perform instrumentation for code coverage (if -coverage-out is set)
List(new CrossVersionChecks, // Check issues related to deprecated and experimental
new FirstTransform, // Some transformations to put trees into a canonical form
new UnrollDefs, // Unroll annotated methods
new CheckReentrant, // Internal use only: Check that compiled program has no data races involving global vars
new ElimPackagePrefixes, // Eliminate references to package prefixes in Select nodes
new CookComments, // Cook the comments: expand variables, doc, etc.
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,7 @@ class Definitions {
@tu lazy val MigrationAnnot: ClassSymbol = requiredClass("scala.annotation.migration")
@tu lazy val NowarnAnnot: ClassSymbol = requiredClass("scala.annotation.nowarn")
@tu lazy val UnusedAnnot: ClassSymbol = requiredClass("scala.annotation.unused")
@tu lazy val UnrollAnnot: ClassSymbol = requiredClass("scala.annotation.unroll") // TODO: probably should be getClassIfDefined
@tu lazy val TransparentTraitAnnot: ClassSymbol = requiredClass("scala.annotation.transparentTrait")
@tu lazy val NativeAnnot: ClassSymbol = requiredClass("scala.native")
@tu lazy val RepeatedAnnot: ClassSymbol = requiredClass("scala.annotation.internal.Repeated")
Expand Down
278 changes: 278 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/UnrollDefs.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
package dotty.tools.dotc.transform

import dotty.tools.dotc.*
import core.*
import MegaPhase.MiniPhase
import Contexts.*
import Symbols.*
import Flags.*
import SymDenotations.*
import Decorators.*
import ast.Trees.*
import ast.tpd
import StdNames.nme
import Names.*
import Constants.Constant
import dotty.tools.dotc.core.NameKinds.DefaultGetterName
import dotty.tools.dotc.core.Types.{MethodType, NamedType, PolyType, Type}
import dotty.tools.dotc.core.Symbols

import scala.language.implicitConversions

class UnrollDefs extends MiniPhase {
import tpd._

val phaseName = "unroll"

override val runsAfter = Set(FirstTransform.name)

def copyParam(p: ValDef, parent: Symbol)(using Context) = {
implicitly[Context].typeAssigner.assignType(
cpy.ValDef(p)(p.name, p.tpt, p.rhs),
Symbols.newSymbol(parent, p.name, p.symbol.flags, p.symbol.info)
)
}

def copyParam2(p: TypeDef, parent: Symbol)(using Context) = {
implicitly[Context].typeAssigner.assignType(
cpy.TypeDef(p)(p.name, p.rhs),
Symbols.newSymbol(parent, p.name, p.symbol.flags, p.symbol.info)
)
}

def findUnrollAnnotations(params: List[Symbol])(using Context): List[Int] = {
params
.zipWithIndex
.collect {
case (v, i) if v.annotations.exists(_.symbol.fullName.toString == "scala.annotation.unroll") =>
i
}
}
def isTypeClause(p: ParamClause) = p.headOption.exists(_.isInstanceOf[TypeDef])
def generateSingleForwarder(defdef: DefDef,
prevMethodType: Type,
paramIndex: Int,
nextParamIndex: Int,
nextSymbol: Symbol,
annotatedParamListIndex: Int,
paramLists: List[ParamClause],
isCaseApply: Boolean)
(using Context) = {

def truncateMethodType0(tpe: Type, n: Int): Type = {
tpe match{
case pt: PolyType => PolyType(pt.paramNames, pt.paramInfos, truncateMethodType0(pt.resType, n + 1))
case mt: MethodType =>
if (n == annotatedParamListIndex) MethodType(mt.paramInfos.take(paramIndex), mt.resType)
else MethodType(mt.paramInfos, truncateMethodType0(mt.resType, n + 1))
}
}

val truncatedMethodType = truncateMethodType0(prevMethodType, 0)
val forwarderDefSymbol = Symbols.newSymbol(
defdef.symbol.owner,
defdef.name,
defdef.symbol.flags &~
HasDefaultParams &~
(if (nextParamIndex == -1) Flags.EmptyFlags else Deferred) |
Invisible,
truncatedMethodType
)

val newParamLists: List[ParamClause] = paramLists.zipWithIndex.map{ case (ps, i) =>
if (i == annotatedParamListIndex) ps.take(paramIndex).map(p => copyParam(p.asInstanceOf[ValDef], forwarderDefSymbol))
else {
if (isTypeClause(ps)) ps.map(p => copyParam2(p.asInstanceOf[TypeDef], forwarderDefSymbol))
else ps.map(p => copyParam(p.asInstanceOf[ValDef], forwarderDefSymbol))
}
}

val defaultOffset = paramLists
.iterator
.take(annotatedParamListIndex)
.filter(!isTypeClause(_))
.map(_.size)
.sum

val defaultCalls = Range(paramIndex, nextParamIndex).map(n =>
val inner = if (defdef.symbol.isConstructor) {
ref(defdef.symbol.owner.companionModule)
.select(DefaultGetterName(defdef.name, n + defaultOffset))
} else if (isCaseApply) {
ref(defdef.symbol.owner.companionModule)
.select(DefaultGetterName(termName("<init>"), n + defaultOffset))
} else {
This(defdef.symbol.owner.asClass)
.select(DefaultGetterName(defdef.name, n + defaultOffset))
}

newParamLists
.take(annotatedParamListIndex)
.map(_.map(p => ref(p.symbol)))
.foldLeft[Tree](inner){
case (lhs: Tree, newParams) =>
if (newParams.headOption.exists(_.isInstanceOf[TypeTree])) TypeApply(lhs, newParams)
else Apply(lhs, newParams)
}
)

val forwarderInner: Tree = This(defdef.symbol.owner.asClass).select(nextSymbol)

val forwarderCallArgs =
newParamLists.zipWithIndex.map{case (ps, i) =>
if (i == annotatedParamListIndex) ps.map(p => ref(p.symbol)).take(nextParamIndex) ++ defaultCalls
else ps.map(p => ref(p.symbol))
}

lazy val forwarderCall0 = forwarderCallArgs.foldLeft[Tree](forwarderInner){
case (lhs: Tree, newParams) =>
if (newParams.headOption.exists(_.isInstanceOf[TypeTree])) TypeApply(lhs, newParams)
else Apply(lhs, newParams)
}

lazy val forwarderCall =
if (!defdef.symbol.isConstructor) forwarderCall0
else Block(List(forwarderCall0), Literal(Constant(())))

val forwarderDef = implicitly[Context].typeAssigner.assignType(
cpy.DefDef(defdef)(
name = forwarderDefSymbol.name,
paramss = newParamLists,
tpt = defdef.tpt,
rhs = if (nextParamIndex == -1) EmptyTree else forwarderCall
),
forwarderDefSymbol
)

forwarderDef
}

def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
cpy.DefDef(defdef)(
name = defdef.name,
paramss = defdef.paramss,
tpt = defdef.tpt,
rhs = Match(
ref(defdef.paramss.head.head.asInstanceOf[ValDef].symbol).select(termName("productArity")),
startParamIndices.map { paramIndex =>
val Apply(select, args) = defdef.rhs: @unchecked
CaseDef(
Literal(Constant(paramIndex)),
EmptyTree,
Apply(
select,
args.take(paramIndex) ++
Range(paramIndex, paramCount).map(n =>
ref(defdef.symbol.owner.companionModule)
.select(DefaultGetterName(defdef.symbol.owner.primaryConstructor.name.toTermName, n))
)
)
)
} ++ Seq(
CaseDef(
EmptyTree,
EmptyTree,
defdef.rhs
)
)
)
).setDefTree
}

def generateSyntheticDefs(tree: Tree)(using Context): (Option[Symbol], Seq[Tree]) = tree match{
case defdef: DefDef if defdef.paramss.nonEmpty =>
import dotty.tools.dotc.core.NameOps.isConstructorName

val isCaseCopy =
defdef.name.toString == "copy" && defdef.symbol.owner.is(CaseClass)

val isCaseApply =
defdef.name.toString == "apply" && defdef.symbol.owner.companionClass.is(CaseClass)

val isCaseFromProduct = defdef.name.toString == "fromProduct" && defdef.symbol.owner.companionClass.is(CaseClass)

val annotated =
if (isCaseCopy) defdef.symbol.owner.primaryConstructor
else if (isCaseApply) defdef.symbol.owner.companionClass.primaryConstructor
else if (isCaseFromProduct) defdef.symbol.owner.companionClass.primaryConstructor
else defdef.symbol


annotated
.paramSymss
.zipWithIndex
.flatMap{case (paramClause, paramClauseIndex) =>
val annotationIndices = findUnrollAnnotations(paramClause)
if (annotationIndices.isEmpty) None
else Some((paramClauseIndex, annotationIndices))
} match{
case Nil => (None, Nil)
case Seq((paramClauseIndex, annotationIndices)) =>
val paramCount = annotated.paramSymss(paramClauseIndex).size
if (isCaseFromProduct) {
(Some(defdef.symbol), Seq(generateFromProduct(annotationIndices, paramCount, defdef)))
} else {
if (defdef.symbol.is(Deferred)){
(
Some(defdef.symbol),
(-1 +: annotationIndices :+ paramCount).sliding(2).toList.foldLeft((Seq.empty[DefDef], defdef.symbol))((m, v) => ((m, v): @unchecked) match {
case ((defdefs, nextSymbol), Seq(paramIndex, nextParamIndex)) =>
val forwarder = generateSingleForwarder(
defdef,
defdef.symbol.info,
nextParamIndex,
paramIndex,
nextSymbol,
paramClauseIndex,
defdef.paramss,
isCaseApply
)
(forwarder +: defdefs, forwarder.symbol)
})._1
)

}else{

(
None,
(annotationIndices :+ paramCount).sliding(2).toList.reverse.foldLeft((Seq.empty[DefDef], defdef.symbol))((m, v) => ((m, v): @unchecked) match {
case ((defdefs, nextSymbol), Seq(paramIndex, nextParamIndex)) =>
val forwarder = generateSingleForwarder(
defdef,
defdef.symbol.info,
paramIndex,
nextParamIndex,
nextSymbol,
paramClauseIndex,
defdef.paramss,
isCaseApply
)
(forwarder +: defdefs, forwarder.symbol)
})._1
)
}
}

case multiple => sys.error("Cannot have multiple parameter lists containing `@unroll` annotation")
}

case _ => (None, Nil)
}

override def transformTemplate(tmpl: tpd.Template)(using Context): tpd.Tree = {

val (removed0, generatedDefs) = tmpl.body.map(generateSyntheticDefs).unzip
val (_, generatedConstr) = generateSyntheticDefs(tmpl.constr)
val removed = removed0.flatten

super.transformTemplate(
cpy.Template(tmpl)(
tmpl.constr,
tmpl.parents,
tmpl.derived,
tmpl.self,
tmpl.body.filter(t => !removed.contains(t.symbol)) ++ generatedDefs.flatten ++ generatedConstr
)
)
}
}
4 changes: 4 additions & 0 deletions library/src/scala/annotation/unroll.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package scala.annotation

@experimental("under review as part of SIP-61")
final class unroll extends scala.annotation.StaticAnnotation

0 comments on commit 4c9c78c

Please sign in to comment.