diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index 6fa63c21edaa..5af57c31e05f 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -674,7 +674,29 @@ class CheckCaptures extends Recheck, SymTransformer: i"Sealed type variable $pname", "be instantiated to", i"This is often caused by a local capability$where\nleaking as part of its result.", tree.srcPos) - handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt))) + val res = handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt))) + if meth == defn.Caps_containsImpl then checkContains(tree) + res + end recheckTypeApply + + /** Faced with a tree of form `caps.contansImpl[CS, r.type]`, check that `R` is a tracked + * capability and assert that `{r} <:CS`. + */ + def checkContains(tree: TypeApply)(using Context): Unit = + tree.fun.knownType.widen match + case fntpe: PolyType => + tree.args match + case csArg :: refArg :: Nil => + val cs = csArg.knownType.captureSet + val ref = refArg.knownType + capt.println(i"check contains $cs , $ref") + ref match + case ref: CaptureRef if ref.isTracked => + checkElem(ref, cs, tree.srcPos) + case _ => + report.error(em"$refArg is not a tracked capability", refArg.srcPos) + case _ => + case _ => override def recheckBlock(tree: Block, pt: Type)(using Context): Type = inNestedLevel(super.recheckBlock(tree, pt)) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index fda12a5488ce..1d2f2b05feb4 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -993,15 +993,17 @@ class Definitions { @tu lazy val CapsModule: Symbol = requiredModule("scala.caps") @tu lazy val captureRoot: TermSymbol = CapsModule.requiredValue("cap") @tu lazy val Caps_Capability: TypeSymbol = CapsModule.requiredType("Capability") - @tu lazy val Caps_CapSet = requiredClass("scala.caps.CapSet") + @tu lazy val Caps_CapSet: ClassSymbol = requiredClass("scala.caps.CapSet") @tu lazy val Caps_reachCapability: TermSymbol = CapsModule.requiredMethod("reachCapability") @tu lazy val Caps_capsOf: TermSymbol = CapsModule.requiredMethod("capsOf") - @tu lazy val Caps_Exists = requiredClass("scala.caps.Exists") + @tu lazy val Caps_Exists: ClassSymbol = requiredClass("scala.caps.Exists") @tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe") @tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure") @tu lazy val Caps_unsafeBox: Symbol = CapsUnsafeModule.requiredMethod("unsafeBox") @tu lazy val Caps_unsafeUnbox: Symbol = CapsUnsafeModule.requiredMethod("unsafeUnbox") @tu lazy val Caps_unsafeBoxFunArg: Symbol = CapsUnsafeModule.requiredMethod("unsafeBoxFunArg") + @tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Capability") + @tu lazy val Caps_containsImpl: TermSymbol = CapsModule.requiredMethod("containsImpl") @tu lazy val PureClass: Symbol = requiredClass("scala.Pure") diff --git a/library/src/scala/caps.scala b/library/src/scala/caps.scala index 1416a7b35f83..9700ed62738d 100644 --- a/library/src/scala/caps.scala +++ b/library/src/scala/caps.scala @@ -1,6 +1,6 @@ package scala -import annotation.{experimental, compileTimeOnly} +import annotation.{experimental, compileTimeOnly, retainsCap} @experimental object caps: @@ -19,6 +19,16 @@ import annotation.{experimental, compileTimeOnly} /** Carrier trait for capture set type parameters */ trait CapSet extends Any + /** A type constraint expressing that the capture set `C` needs to contain + * the capability `R` + */ + sealed trait Contains[C <: CapSet @retainsCap, R <: Singleton] + + /** The only implementation of `Contains`. The constraint that `{R} <: C` is + * added separately by the capture checker. + */ + given containsImpl[C <: CapSet @retainsCap, R <: Singleton]: Contains[C, R]() + @compileTimeOnly("Should be be used only internally by the Scala compiler") def capsOf[CS]: Any = ??? diff --git a/tests/neg-custom-args/captures/i21313.check b/tests/neg-custom-args/captures/i21313.check new file mode 100644 index 000000000000..37b944a97d68 --- /dev/null +++ b/tests/neg-custom-args/captures/i21313.check @@ -0,0 +1,11 @@ +-- Error: tests/neg-custom-args/captures/i21313.scala:6:27 ------------------------------------------------------------- +6 |def foo(x: Async) = x.await(???) // error + | ^ + | (x : Async) is not a tracked capability +-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i21313.scala:15:12 --------------------------------------- +15 | ac1.await(src2) // error + | ^^^^ + | Found: (src2 : Source[Int, caps.CapSet^{ac2}]^?) + | Required: Source[Int, caps.CapSet^{ac1}]^ + | + | longer explanation available when compiling with `-explain` diff --git a/tests/neg-custom-args/captures/i21313.scala b/tests/neg-custom-args/captures/i21313.scala new file mode 100644 index 000000000000..01bedb10aefd --- /dev/null +++ b/tests/neg-custom-args/captures/i21313.scala @@ -0,0 +1,15 @@ +import caps.CapSet + +trait Async: + def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T + +def foo(x: Async) = x.await(???) // error + +trait Source[+T, Cap^]: + final def await(using ac: Async^{Cap^}) = ac.await[T, Cap](this) // Contains[Cap, ac] is assured because {ac} <: Cap. + +def test(using ac1: Async^, ac2: Async^, x: String) = + val src1 = new Source[Int, CapSet^{ac1}] {} + ac1.await(src1) // ok + val src2 = new Source[Int, CapSet^{ac2}] {} + ac1.await(src2) // error diff --git a/tests/pos-custom-args/captures/i21313.scala b/tests/pos-custom-args/captures/i21313.scala new file mode 100644 index 000000000000..2fda6c0c0e45 --- /dev/null +++ b/tests/pos-custom-args/captures/i21313.scala @@ -0,0 +1,11 @@ +import caps.CapSet + +trait Async: + def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T + +trait Source[+T, Cap^]: + final def await(using ac: Async^{Cap^}) = ac.await[T, Cap](this) // Contains[Cap, ac] is assured because {ac} <: Cap. + +def test(using ac1: Async^, ac2: Async^, x: String) = + val src1 = new Source[Int, CapSet^{ac1}] {} + ac1.await(src1)