diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index e3387208e7c7..0504bce2f894 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -30,6 +30,7 @@ import config.Feature.{sourceVersion, modularity} import config.SourceVersion.* import scala.compiletime.uninitialized +import dotty.tools.dotc.transform.init.Util.tree /** This class creates symbols from definitions and imports and gives them * lazy types. @@ -1648,7 +1649,6 @@ class Namer { typer: Typer => * as an attachment on the ClassDef tree. */ def enterParentRefinementSyms(refinements: List[(Name, Type)]) = - println(s"For class $cls, entering parent refinements: $refinements") val refinedSyms = mutable.ListBuffer[Symbol]() for (name, tp) <- refinements do if decls.lookupEntry(name) == null then @@ -1658,7 +1658,6 @@ class Namer { typer: Typer => case _ => Synthetic | Deferred val s = newSymbol(cls, name, flags, tp, coord = original.rhs.span.startPos).entered refinedSyms += s - println(s" entered $s") if refinedSyms.nonEmpty then typr.println(i"parent refinement symbols: ${refinedSyms.toList}") original.pushAttachment(ParentRefinements, refinedSyms.toList) @@ -1996,10 +1995,11 @@ class Namer { typer: Typer => */ def needsTracked(sym: Symbol, param: ValDef)(using Context) = !sym.is(Tracked) + && sym.maybeOwner.isConstructor && ( isContextBoundWitnessWithAbstractMembers(sym, param) || isReferencedInPublicSignatures(sym) - // || isPassedToTrackedParentParameter(sym, param) + || isPassedToTrackedParentParameter(sym, param) ) /** Under x.modularity, we add `tracked` to context bound witnesses @@ -2018,11 +2018,14 @@ class Namer { typer: Typer => def checkOwnerMemberSignatures(owner: Symbol): Boolean = owner.infoOrCompleter match case info: ClassInfo => - info.decls.filter(d => !d.isConstructor).exists(d => tpeContainsSymbolRef(d.info, accessorSyms)) + info.decls.filter(_.isTerm) + .filter(_ != sym.maybeOwner) + .exists(d => tpeContainsSymbolRef(d.info, accessorSyms)) case _ => false checkOwnerMemberSignatures(owner) def isPassedToTrackedParentParameter(sym: Symbol, param: ValDef)(using Context): Boolean = + // TODO(kπ) Add tracked if the param is passed as a tracked arg in parent. Can we touch the inheritance terms? val owner = sym.maybeOwner.maybeOwner val accessorSyms = maybeParamAccessors(owner, sym) owner.infoOrCompleter match @@ -2035,10 +2038,18 @@ class Namer { typer: Typer => case tpe: NamedType => tpe.prefix.exists && tpeContainsSymbolRef(tpe.prefix, syms) case _ => false - private def tpeContainsSymbolRef(tpe: Type, syms: List[Symbol])(using Context): Boolean = - tpe.termSymbol.exists && syms.contains(tpe.termSymbol) - || tpe.argInfos.exists(tpeContainsSymbolRef(_, syms)) - || namedTypeWithPrefixContainsSymbolRef(tpe, syms) + private def tpeContainsSymbolRef(tpe0: Type, syms: List[Symbol])(using Context): Boolean = + val tpe = tpe0.dropAlias.widenExpr.dealias + tpe match + case m : MethodOrPoly => + m.paramInfos.exists(tpeContainsSymbolRef(_, syms)) + || tpeContainsSymbolRef(m.resultType, syms) + case r @ RefinedType(parent, _, refinedInfo) => tpeContainsSymbolRef(parent, syms) || tpeContainsSymbolRef(refinedInfo, syms) + case TypeBounds(lo, hi) => tpeContainsSymbolRef(lo, syms) || tpeContainsSymbolRef(hi, syms) + case t: Type => + tpe.termSymbol.exists && syms.contains(tpe.termSymbol) + || tpe.argInfos.exists(tpeContainsSymbolRef(_, syms)) + || namedTypeWithPrefixContainsSymbolRef(tpe, syms) private def maybeParamAccessors(owner: Symbol, sym: Symbol)(using Context): List[Symbol] = owner.infoOrCompleter match diff --git a/tests/pos/infer-tracked-1.scala b/tests/pos/infer-tracked-1.scala new file mode 100644 index 000000000000..b4976a963074 --- /dev/null +++ b/tests/pos/infer-tracked-1.scala @@ -0,0 +1,34 @@ +import scala.language.experimental.modularity +import scala.language.future + +trait Ordering { + type T + def compare(t1:T, t2: T): Int +} + +class SetFunctor(val ord: Ordering) { + type Set = List[ord.T] + def empty: Set = Nil + + implicit class helper(s: Set) { + def add(x: ord.T): Set = x :: remove(x) + def remove(x: ord.T): Set = s.filter(e => ord.compare(x, e) != 0) + def member(x: ord.T): Boolean = s.exists(e => ord.compare(x, e) == 0) + } +} + +object Test { + val orderInt = new Ordering { + type T = Int + def compare(t1: T, t2: T): Int = t1 - t2 + } + + val IntSet = new SetFunctor(orderInt) + import IntSet.* + + def main(args: Array[String]) = { + val set = IntSet.empty.add(6).add(8).add(23) + assert(!set.member(7)) + assert(set.member(8)) + } +} diff --git a/tests/pos/infer-tracked-parsercombinators-expanded.scala b/tests/pos/infer-tracked-parsercombinators-expanded.scala new file mode 100644 index 000000000000..63c6aec9e84a --- /dev/null +++ b/tests/pos/infer-tracked-parsercombinators-expanded.scala @@ -0,0 +1,65 @@ +import scala.language.experimental.modularity +import scala.language.future + +import collection.mutable + +/// A parser combinator. +trait Combinator[T]: + + /// The context from which elements are being parsed, typically a stream of tokens. + type Context + /// The element being parsed. + type Element + + extension (self: T) + /// Parses and returns an element from `context`. + def parse(context: Context): Option[Element] +end Combinator + +final case class Apply[C, E](action: C => Option[E]) +final case class Combine[A, B](first: A, second: B) + +object test: + + class apply[C, E] extends Combinator[Apply[C, E]]: + type Context = C + type Element = E + extension(self: Apply[C, E]) + def parse(context: C): Option[E] = self.action(context) + + def apply[C, E]: apply[C, E] = new apply[C, E] + + class combine[A, B]( + val f: Combinator[A], + val s: Combinator[B] { type Context = f.Context} + ) extends Combinator[Combine[A, B]]: + type Context = f.Context + type Element = (f.Element, s.Element) + extension(self: Combine[A, B]) + def parse(context: Context): Option[Element] = ??? + + def combine[A, B]( + _f: Combinator[A], + _s: Combinator[B] { type Context = _f.Context} + ) = new combine[A, B](_f, _s) + // cast is needed since the type of new combine[A, B](_f, _s) + // drops the required refinement. + + extension [A] (buf: mutable.ListBuffer[A]) def popFirst() = + if buf.isEmpty then None + else try Some(buf.head) finally buf.remove(0) + + @main def hello: Unit = { + val source = (0 to 10).toList + val stream = source.to(mutable.ListBuffer) + + val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst()) + val m = Combine(n, n) + + val c = combine( + apply[mutable.ListBuffer[Int], Int], + apply[mutable.ListBuffer[Int], Int] + ) + val r = c.parse(m)(stream) // was type mismatch, now OK + val rc: Option[(Int, Int)] = r + } diff --git a/tests/pos/infer-tracked-parsercombinators-givens.scala b/tests/pos/infer-tracked-parsercombinators-givens.scala new file mode 100644 index 000000000000..eee522ed7285 --- /dev/null +++ b/tests/pos/infer-tracked-parsercombinators-givens.scala @@ -0,0 +1,55 @@ +import scala.language.experimental.modularity +import scala.language.future + +import collection.mutable + +/// A parser combinator. +trait Combinator[T]: + + /// The context from which elements are being parsed, typically a stream of tokens. + type Context + /// The element being parsed. + type Element + + extension (self: T) + /// Parses and returns an element from `context`. + def parse(context: Context): Option[Element] +end Combinator + +final case class Apply[C, E](action: C => Option[E]) +final case class Combine[A, B](first: A, second: B) + +given apply[C, E]: Combinator[Apply[C, E]] with { + type Context = C + type Element = E + extension(self: Apply[C, E]) { + def parse(context: C): Option[E] = self.action(context) + } +} + +given combine[A, B](using + val f: Combinator[A], + val s: Combinator[B] { type Context = f.Context } +): Combinator[Combine[A, B]] with { + type Context = f.Context + type Element = (f.Element, s.Element) + extension(self: Combine[A, B]) { + def parse(context: Context): Option[Element] = ??? + } +} + +extension [A] (buf: mutable.ListBuffer[A]) def popFirst() = + if buf.isEmpty then None + else try Some(buf.head) finally buf.remove(0) + +@main def hello: Unit = { + val source = (0 to 10).toList + val stream = source.to(mutable.ListBuffer) + + val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst()) + val m = Combine(n, n) + + val r = m.parse(stream) // error: type mismatch, found `mutable.ListBuffer[Int]`, required `?1.Context` + val rc: Option[(Int, Int)] = r + // it would be great if this worked +} diff --git a/tests/pos/infer-tracked-vector.scala b/tests/pos/infer-tracked-vector.scala new file mode 100644 index 000000000000..e748dc9cbe8e --- /dev/null +++ b/tests/pos/infer-tracked-vector.scala @@ -0,0 +1,82 @@ +import scala.language.experimental.modularity +import scala.language.future + +object typeparams: + sealed trait Nat + object Z extends Nat + final case class S[N <: Nat]() extends Nat + + type Zero = Z.type + type Succ[N <: Nat] = S[N] + + sealed trait Fin[N <: Nat] + case class FZero[N <: Nat]() extends Fin[Succ[N]] + case class FSucc[N <: Nat](pred: Fin[N]) extends Fin[Succ[N]] + + object Fin: + def zero[N <: Nat]: Fin[Succ[N]] = FZero() + def succ[N <: Nat](i: Fin[N]): Fin[Succ[N]] = FSucc(i) + + sealed trait Vec[A, N <: Nat] + case class VNil[A]() extends Vec[A, Zero] + case class VCons[A, N <: Nat](head: A, tail: Vec[A, N]) extends Vec[A, Succ[N]] + + object Vec: + def empty[A]: Vec[A, Zero] = VNil() + def cons[A, N <: Nat](head: A, tail: Vec[A, N]): Vec[A, Succ[N]] = VCons(head, tail) + + def get[A, N <: Nat](v: Vec[A, N], index: Fin[N]): A = (v, index) match + case (VCons(h, _), FZero()) => h + case (VCons(_, t), FSucc(pred)) => get(t, pred) + + def runVec(): Unit = + val v: Vec[Int, Succ[Succ[Succ[Zero]]]] = Vec.cons(1, Vec.cons(2, Vec.cons(3, Vec.empty))) + + println(s"Element at index 0: ${Vec.get(v, Fin.zero)}") + println(s"Element at index 1: ${Vec.get(v, Fin.succ(Fin.zero))}") + println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.zero)))}") + // println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.succ(Fin.zero))))}") // error + +// TODO(kπ) check if I can get it to work +// object typemembers: +// sealed trait Nat +// object Z extends Nat +// case class S() extends Nat: +// type N <: Nat + +// type Zero = Z.type +// type Succ[N1 <: Nat] = S { type N = N1 } + +// sealed trait Fin: +// type N <: Nat +// case class FZero[N1 <: Nat]() extends Fin: +// type N = Succ[N1] +// case class FSucc(tracked val pred: Fin) extends Fin: +// type N = Succ[pred.N] + +// object Fin: +// def zero[N1 <: Nat]: Fin { type N = Succ[N1] } = FZero[N1]() +// def succ[N1 <: Nat](i: Fin { type N = N1 }): Fin { type N = Succ[N1] } = FSucc(i) + +// sealed trait Vec[A]: +// type N <: Nat +// case class VNil[A]() extends Vec[A]: +// type N = Zero +// case class VCons[A](head: A, tracked val tail: Vec[A]) extends Vec[A]: +// type N = Succ[tail.N] + +// object Vec: +// def empty[A]: Vec[A] = VNil() +// def cons[A](head: A, tail: Vec[A]): Vec[A] = VCons(head, tail) + +// def get[A](v: Vec[A], index: Fin { type N = v.N }): A = (v, index) match +// case (VCons(h, _), FZero()) => h +// case (VCons(_, t), FSucc(pred)) => get(t, pred) + +// // def runVec(): Unit = +// val v: Vec[Int] = Vec.cons(1, Vec.cons(2, Vec.cons(3, Vec.empty))) + +// println(s"Element at index 0: ${Vec.get(v, Fin.zero)}") +// println(s"Element at index 1: ${Vec.get(v, Fin.succ(Fin.zero))}") +// println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.zero)))}") +// // println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.succ(Fin.zero))))}") diff --git a/tests/pos/infer-tracked.scala b/tests/pos/infer-tracked.scala index 161c3b981a78..496508ffdc6c 100644 --- a/tests/pos/infer-tracked.scala +++ b/tests/pos/infer-tracked.scala @@ -10,13 +10,37 @@ class F(val x: C): class G(override val x: C) extends F(x) +class H(val x: C): + type T1 = x.T + val result: T1 = x.foo + +class I(val c: C, val t: c.T) + +case class J(c: C): + val result: c.T = c.foo + +case class K(c: C): + def result[B >: c.T]: B = c.foo + def Test = val c = new C: type T = Int def foo = 42 val f = new F(c) - val i: Int = f.result + val _: Int = f.result // val g = new G(c) - // val j: Int = g.result + // val _: Int = g.result + + val h = new H(c) + val _: Int = h.result + + val i = new I(c, c.foo) + val _: Int = i.t + + val j = J(c) + val _: Int = j.result + + val k = K(c) + val _: Int = k.result