Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change retains annotation from using term arguments to using type arguments #22909

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2261,9 +2261,9 @@ object desugar {
AppliedTypeTree(ref(defn.SeqType), t),
New(ref(defn.RepeatedAnnot.typeRef), Nil :: Nil))
else if op.name == nme.CC_REACH then
Apply(ref(defn.Caps_reachCapability), t :: Nil)
Annotated(t, New(ref(defn.ReachCapabilityAnnot.typeRef), Nil :: Nil))
else if op.name == nme.CC_READONLY then
Apply(ref(defn.Caps_readOnlyCapability), t :: Nil)
Annotated(t, New(ref(defn.ReadOnlyCapabilityAnnot.typeRef), Nil :: Nil))
else
assert(ctx.mode.isExpr || ctx.reporter.errorsReported || ctx.mode.is(Mode.Interactive), ctx.mode)
Select(t, op.name)
Expand Down
20 changes: 16 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
/** Property key for contextual Apply trees of the form `fn given arg` */
val KindOfApply: Property.StickyKey[ApplyKind] = Property.StickyKey()

val RetainsAnnot: Property.StickyKey[Unit] = Property.StickyKey()

// ------ Creation methods for untyped only -----------------

def Ident(name: Name)(implicit src: SourceFile): Ident = new Ident(name)
Expand Down Expand Up @@ -528,10 +530,17 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
Select(scalaDot(nme.caps), nme.CAPTURE_ROOT)

def makeRetaining(parent: Tree, refs: List[Tree], annotName: TypeName)(using Context): Annotated =
Annotated(parent, New(scalaAnnotationDot(annotName), List(refs)))

def makeCapsOf(tp: RefTree)(using Context): Tree =
TypeApply(capsInternalDot(nme.capsOf), tp :: Nil)
var annot: Tree = scalaAnnotationDot(annotName)
if annotName == tpnme.retainsCap then
annot = New(annot, Nil)
else
val trefs =
if refs.isEmpty then ref(defn.NothingType)
// TODO: choose a reduce direction
else refs.map(SingletonTypeTree).reduce[Tree]((a, b) => makeOrType(a, b))
annot = New(AppliedTypeTree(annot, trefs :: Nil), Nil)
annot.putAttachment(RetainsAnnot, ())
Annotated(parent, annot)

// Capture set variable `[C^]` becomes: `[C >: CapSet <: CapSet^{cap}]`
def makeCapsBound()(using Context): TypeBoundsTree =
Expand Down Expand Up @@ -563,6 +572,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def makeAndType(left: Tree, right: Tree)(using Context): AppliedTypeTree =
AppliedTypeTree(ref(defn.andType.typeRef), left :: right :: Nil)

def makeOrType(left: Tree, right: Tree)(using Context): AppliedTypeTree =
AppliedTypeTree(ref(defn.orType.typeRef), left :: right :: Nil)

def makeParameter(pname: TermName, tpe: Tree, mods: Modifiers, isBackquoted: Boolean = false)(using Context): ValDef = {
val vdef = ValDef(pname, tpe, EmptyTree)
if (isBackquoted) vdef.pushAttachment(Backquoted, ())
Expand Down
17 changes: 8 additions & 9 deletions compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,14 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean)(cls: Symbol) exte

/** Reconstitute annotation tree from capture set */
override def tree(using Context) =
val elems = refs.elems.toList.map {
case cr: TermRef => ref(cr)
case cr: TermParamRef => untpd.Ident(cr.paramName).withType(cr)
case cr: ThisType => This(cr.cls)
case root(_) => ref(root.cap)
// TODO: Will crash if the type is an annotated type, for example `cap.rd`
}
val arg = repeated(elems, TypeTree(defn.AnyType))
New(symbol.typeRef, arg :: Nil)
if symbol == defn.RetainsCapAnnot then
New(symbol.typeRef, Nil)
else
val elems = refs.elems.toList
val trefs =
if elems.isEmpty then defn.NothingType
else elems.reduce[Type]((a, b) => OrType(a, b, soft = false))
New(AppliedType(symbol.typeRef, trefs :: Nil), Nil)

override def symbol(using Context) = cls

Expand Down
74 changes: 34 additions & 40 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -190,38 +190,25 @@ def ccState(using Context): CCState =

extension (tree: Tree)

/** Map tree with CaptureRef type to its type,
* map CapSet^{refs} to the `refs` references,
* throw IllegalCaptureRef otherwise
*/
def toCaptureRefs(using Context): List[CaptureRef] = tree match
case ReachCapabilityApply(arg) =>
arg.toCaptureRefs.map(_.reach)
case ReadOnlyCapabilityApply(arg) =>
arg.toCaptureRefs.map(_.readOnly)
case CapsOfApply(arg) =>
arg.toCaptureRefs
case _ => tree.tpe.dealiasKeepAnnots match
case ref: CaptureRef if ref.isTrackableRef =>
ref :: Nil
case AnnotatedType(parent, ann)
if ann.symbol.isRetains && parent.derivesFrom(defn.Caps_CapSet) =>
ann.tree.toCaptureSet.elems.toList
case tpe =>
throw IllegalCaptureRef(tpe) // if this was compiled from cc syntax, problem should have been reported at Typer

/** Convert a @retains or @retainsByName annotation tree to the capture set it represents.
* For efficience, the result is cached as an Attachment on the tree.
*/
def toCaptureSet(using Context): CaptureSet =
tree.getAttachment(Captures) match
case Some(refs) => refs
case None =>
val refs = CaptureSet(tree.retainedElems.flatMap(_.toCaptureRefs)*)
//.showing(i"toCaptureSet $tree --> $result", capt)
val refs = CaptureSet(tree.retainedSet.retainedElements*)
tree.putAttachment(Captures, refs)
refs

def retainedSet(using Context): Type =
tree match
case Apply(TypeApply(_, refs :: Nil), _) => refs.tpe
case _ =>
if tree.symbol.maybeOwner == defn.RetainsCapAnnot
then root.cap
else NoType

/** The arguments of a @retains, @retainsCap or @retainsByName annotation */
def retainedElems(using Context): List[Tree] = tree match
case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) =>
Expand All @@ -233,6 +220,21 @@ extension (tree: Tree)

extension (tp: Type)

def retainedElements(using Context): List[CaptureRef] = tp match
case ReachCapability(tp1) =>
tp1.reach :: Nil
case ReadOnlyCapability(tp1) =>
tp1.readOnly :: Nil
case tp: CaptureRef if tp.isTrackableRef =>
tp :: Nil
case tp: TypeRef if tp.symbol.isType && tp.derivesFrom(defn.Caps_CapSet) =>
tp :: Nil
case OrType(tp1, tp2) =>
tp1.retainedElements ++ tp2.retainedElements
case _ =>
if tp.isNothingType then Nil
else throw IllegalCaptureRef(tp)

/** Is this type a CaptureRef that can be tracked?
* This is true for
* - all ThisTypes and all TermParamRef,
Expand Down Expand Up @@ -655,7 +657,7 @@ extension (cls: ClassSymbol)
|| bc.is(CaptureChecked)
&& bc.givenSelfType.dealiasKeepAnnots.match
case CapturingType(_, refs) => refs.isAlwaysEmpty
case RetainingType(_, refs) => refs.isEmpty
case RetainingType(_, refs) => refs.retainedElements.isEmpty
case selfType =>
isCaptureChecking // At Setup we have not processed self types yet, so
// unless a self type is explicitly given, we can't tell
Expand Down Expand Up @@ -773,7 +775,7 @@ class CleanupRetains(using Context) extends TypeMap:
def apply(tp: Type): Type =
tp match
case AnnotatedType(tp, annot) if annot.symbol == defn.RetainsAnnot || annot.symbol == defn.RetainsByNameAnnot =>
RetainingType(tp, Nil, byName = annot.symbol == defn.RetainsByNameAnnot)
RetainingType(tp, defn.NothingType, byName = annot.symbol == defn.RetainsByNameAnnot)
case _ => mapOver(tp)

/** A typemap that follows aliases and keeps their transformed results if
Expand All @@ -792,26 +794,18 @@ trait FollowAliasesMap(using Context) extends TypeMap:
/** An extractor for `caps.reachCapability(ref)`, which is used to express a reach
* capability as a tree in a @retains annotation.
*/
object ReachCapabilityApply:
def unapply(tree: Apply)(using Context): Option[Tree] = tree match
case Apply(reach, arg :: Nil) if reach.symbol == defn.Caps_reachCapability => Some(arg)
case _ => None
// object ReachCapabilityApply:
// def unapply(tree: Apply)(using Context): Option[Tree] = tree match
// case Apply(reach, arg :: Nil) if reach.symbol == defn.Caps_reachCapability => Some(arg)
// case _ => None

/** An extractor for `caps.readOnlyCapability(ref)`, which is used to express a read-only
* capability as a tree in a @retains annotation.
*/
object ReadOnlyCapabilityApply:
def unapply(tree: Apply)(using Context): Option[Tree] = tree match
case Apply(ro, arg :: Nil) if ro.symbol == defn.Caps_readOnlyCapability => Some(arg)
case _ => None

/** An extractor for `caps.capsOf[X]`, which is used to express a generic capture set
* as a tree in a @retains annotation.
*/
object CapsOfApply:
def unapply(tree: TypeApply)(using Context): Option[Tree] = tree match
case TypeApply(capsOf, arg :: Nil) if capsOf.symbol == defn.Caps_capsOf => Some(arg)
case _ => None
// object ReadOnlyCapabilityApply:
// def unapply(tree: Apply)(using Context): Option[Tree] = tree match
// case Apply(ro, arg :: Nil) if ro.symbol == defn.Caps_readOnlyCapability => Some(arg)
// case _ => None

abstract class AnnotatedCapability(annotCls: Context ?=> ClassSymbol):
def apply(tp: Type)(using Context): AnnotatedType =
Expand Down
25 changes: 12 additions & 13 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,25 +137,24 @@ object CheckCaptures:
* This check is performed at Typer.
*/
def checkWellformed(parent: Tree, ann: Tree)(using Context): Unit =
def check(elem: Tree, pos: SrcPos): Unit = elem.tpe match
def check(elem: Type, pos: SrcPos): Unit = elem match
case ref: CaptureRef =>
if !ref.isTrackableRef then
report.error(em"$elem cannot be tracked since it is not a parameter or local value", pos)
case tpe =>
report.error(em"$elem: $tpe is not a legal element of a capture set", pos)
for elem <- ann.retainedElems do
for elem <- ann.retainedSet.retainedElements do
elem match
case CapsOfApply(arg) =>
def isLegalCapsOfArg =
arg.symbol.isType && arg.symbol.info.derivesFrom(defn.Caps_CapSet)
if !isLegalCapsOfArg then
report.error(
em"""$arg is not a legal prefix for `^` here,
|is must be a type parameter or abstract type with a caps.CapSet upper bound.""",
elem.srcPos)
case ReachCapabilityApply(arg) => check(arg, elem.srcPos)
case ReadOnlyCapabilityApply(arg) => check(arg, elem.srcPos)
case _ => check(elem, elem.srcPos)
case ref: TypeRef =>
val refSym = ref.symbol
if refSym.isType && !refSym.info.derivesFrom(defn.Caps_CapSet) then
report.error(em"$elem is not a legal element of a capture set", ann.srcPos)
case ReachCapability(ref) =>
check(ref, ann.srcPos)
case ReadOnlyCapability(ref) =>
check(ref, ann.srcPos)
case _ =>
check(elem, ann.srcPos)

/** Under the sealed policy, report an error if some part of `tp` contains the
* root capability in its capture set or if it refers to a type parameter that
Expand Down
12 changes: 4 additions & 8 deletions compiler/src/dotty/tools/dotc/cc/RetainingType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,20 @@ import Decorators.i
*/
object RetainingType:

def apply(tp: Type, refs: List[Tree], byName: Boolean = false)(using Context): Type =
def apply(tp: Type, typeElems: Type, byName: Boolean = false)(using Context): Type =
val annotCls = if byName then defn.RetainsByNameAnnot else defn.RetainsAnnot
val annotTree =
New(annotCls.typeRef,
Typed(
SeqLiteral(refs, TypeTree(defn.AnyType)),
TypeTree(defn.RepeatedParamClass.typeRef.appliedTo(defn.AnyType))) :: Nil)
val annotTree = New(AppliedType(annotCls.typeRef, typeElems :: Nil), Nil)
AnnotatedType(tp, Annotation(annotTree))

def unapply(tp: AnnotatedType)(using Context): Option[(Type, List[Tree])] =
def unapply(tp: AnnotatedType)(using Context): Option[(Type, Type)] =
val sym = tp.annot.symbol
if sym.isRetainsLike then
tp.annot match
case _: CaptureAnnotation =>
assert(ctx.mode.is(Mode.IgnoreCaptures), s"bad retains $tp at ${ctx.phase}")
None
case ann =>
Some((tp.parent, ann.tree.retainedElems))
Some((tp.parent, ann.tree.retainedSet))
else
None
end RetainingType
29 changes: 12 additions & 17 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
case CapturingType(_, refs) =>
!refs.isAlwaysEmpty
case RetainingType(parent, refs) =>
!refs.isEmpty
!refs.retainedElements.isEmpty
case tp: (TypeRef | AppliedType) =>
val sym = tp.typeSymbol
if sym.isClass
Expand Down Expand Up @@ -856,7 +856,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
&& !refs.isUniversal // if refs is {cap}, an added variable would not change anything
case RetainingType(parent, refs) =>
needsVariable(parent)
&& !refs.tpes.exists:
&& !refs.retainedElements.exists:
case ref: TermRef => ref.isCap
case _ => false
case AnnotatedType(parent, _) =>
Expand Down Expand Up @@ -951,19 +951,13 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
* @param tpt the tree for which an error or warning should be reported
*/
private def checkWellformed(parent: Type, ann: Tree, tpt: Tree)(using Context): Unit =
capt.println(i"checkWF post $parent ${ann.retainedElems} in $tpt")
var retained = ann.retainedElems.toArray
for i <- 0 until retained.length do
val refTree = retained(i)
val refs =
try refTree.toCaptureRefs
catch case ex: IllegalCaptureRef =>
report.error(em"Illegal capture reference: ${ex.getMessage.nn}", refTree.srcPos)
Nil
for ref <- refs do
capt.println(i"checkWF post $parent ${ann.retainedSet} in $tpt")
try
val retainedRefs = ann.retainedSet.retainedElements.toArray
for i <- 0 until retainedRefs.length do
val ref = retainedRefs(i)
def pos =
if refTree.span.exists then refTree.srcPos
else if ann.span.exists then ann.srcPos
if ann.span.exists then ann.srcPos
else tpt.srcPos

def check(others: CaptureSet, dom: Type | CaptureSet): Unit =
Expand All @@ -979,14 +973,15 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:

val others =
for
j <- 0 until retained.length if j != i
r <- retained(j).toCaptureRefs
j <- 0 until retainedRefs.length if j != i
r = retainedRefs(j)
if !r.isRootCapability
yield r
val remaining = CaptureSet(others*)
check(remaining, remaining)
end for
end for
catch case ex: IllegalCaptureRef =>
report.error(em"Illegal capture reference: ${ex.getMessage.nn}", tpt.srcPos)
end checkWellformed

/** Check well formed at post check time. We need to wait until after
Expand Down
6 changes: 1 addition & 5 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class Definitions {
HKTypeLambda(argParamNames :+ "R".toTypeName, argVariances :+ Covariant)(
tl => List.fill(arity + 1)(TypeBounds.empty),
tl => RetainingType(underlyingClass.typeRef.appliedTo(tl.paramRefs),
ref(captureRoot.termRef) :: Nil)
captureRoot.termRef)
))
else
val cls = denot.asClass.classSymbol
Expand Down Expand Up @@ -998,9 +998,6 @@ class Definitions {
@tu lazy val Caps_Capability: ClassSymbol = requiredClass("scala.caps.Capability")
@tu lazy val Caps_CapSet: ClassSymbol = requiredClass("scala.caps.CapSet")
@tu lazy val CapsInternalModule: Symbol = requiredModule("scala.caps.internal")
@tu lazy val Caps_reachCapability: TermSymbol = CapsInternalModule.requiredMethod("reachCapability")
@tu lazy val Caps_readOnlyCapability: TermSymbol = CapsInternalModule.requiredMethod("readOnlyCapability")
@tu lazy val Caps_capsOf: TermSymbol = CapsInternalModule.requiredMethod("capsOf")
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
@tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure")
@tu lazy val Caps_unsafeAssumeSeparate: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumeSeparate")
Expand Down Expand Up @@ -1093,7 +1090,6 @@ class Definitions {
@tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.annotation.retains")
@tu lazy val RetainsCapAnnot: ClassSymbol = requiredClass("scala.annotation.retainsCap")
@tu lazy val RetainsByNameAnnot: ClassSymbol = requiredClass("scala.annotation.retainsByName")
@tu lazy val RetainsArgAnnot: ClassSymbol = requiredClass("scala.annotation.retainsArg")
@tu lazy val PublicInBinaryAnnot: ClassSymbol = requiredClass("scala.annotation.publicInBinary")
@tu lazy val WitnessNamesAnnot: ClassSymbol = requiredClass("scala.annotation.internal.WitnessNames")

Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Flags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ object Flags {
/** Tracked modifier for class parameter / a class with some tracked parameters */
val (Tracked @ _, _, Dependent @ _) = newFlags(46, "tracked")

val (CaptureParam @ _, _, _) = newFlags(47, "capture-param")

// ------------ Flags following this one are not pickled ----------------------------------

/** Symbol is not a member of its owner */
Expand Down Expand Up @@ -449,7 +451,7 @@ object Flags {

/** Flags representing source modifiers */
private val CommonSourceModifierFlags: FlagSet =
commonFlags(Private, Protected, Final, Case, Implicit, Given, Override, JavaStatic, Transparent, Erased)
commonFlags(Private, Protected, Final, Case, Implicit, Given, Override, JavaStatic, Transparent, Erased, CaptureParam)

val TypeSourceModifierFlags: FlagSet =
CommonSourceModifierFlags.toTypeFlags | Abstract | Sealed | Opaque | Open
Expand All @@ -469,7 +471,7 @@ object Flags {
val FromStartFlags: FlagSet = commonFlags(
Module, Package, Deferred, Method, Case, Enum, Param, ParamAccessor,
Scala2SpecialFlags, MutableOrOpen, Opaque, Touched, JavaStatic,
OuterOrCovariant, LabelOrContravariant, CaseAccessor, Tracked,
OuterOrCovariant, LabelOrContravariant, CaseAccessor, Tracked, CaptureParam,
Extension, NonMember, Implicit, Given, Permanent, Synthetic, Exported,
SuperParamAliasOrScala2x, Inline, Macro, ConstructorProxy, Invisible)

Expand Down
8 changes: 8 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Mode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ object Mode {
*/
val ImplicitExploration: Mode = newMode(12, "ImplicitExploration")

/** We are currently inside a capture set.
* A term name could be a capture variable, so we need to
* check that it is valid to use as type name.
* Since this mode is only used during annotation typing,
* we can reuse the value of `ImplicitExploration` to save bits.
*/
val InCaptureSet: Mode = ImplicitExploration

/** We are currently unpickling Scala2 info */
val Scala2Unpickling: Mode = newMode(13, "Scala2Unpickling")

Expand Down
Loading
Loading