Skip to content

Commit

Permalink
Check for parameter references in type bounds when infering tracked
Browse files Browse the repository at this point in the history
  • Loading branch information
KacperFKorban committed Sep 23, 2024
1 parent 7348c02 commit 4bc5d24
Show file tree
Hide file tree
Showing 6 changed files with 281 additions and 10 deletions.
27 changes: 19 additions & 8 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import config.Feature.{sourceVersion, modularity}
import config.SourceVersion.*

import scala.compiletime.uninitialized
import dotty.tools.dotc.transform.init.Util.tree

/** This class creates symbols from definitions and imports and gives them
* lazy types.
Expand Down Expand Up @@ -1648,7 +1649,6 @@ class Namer { typer: Typer =>
* as an attachment on the ClassDef tree.
*/
def enterParentRefinementSyms(refinements: List[(Name, Type)]) =
println(s"For class $cls, entering parent refinements: $refinements")
val refinedSyms = mutable.ListBuffer[Symbol]()
for (name, tp) <- refinements do
if decls.lookupEntry(name) == null then
Expand All @@ -1658,7 +1658,6 @@ class Namer { typer: Typer =>
case _ => Synthetic | Deferred
val s = newSymbol(cls, name, flags, tp, coord = original.rhs.span.startPos).entered
refinedSyms += s
println(s" entered $s")
if refinedSyms.nonEmpty then
typr.println(i"parent refinement symbols: ${refinedSyms.toList}")
original.pushAttachment(ParentRefinements, refinedSyms.toList)
Expand Down Expand Up @@ -1996,10 +1995,11 @@ class Namer { typer: Typer =>
*/
def needsTracked(sym: Symbol, param: ValDef)(using Context) =
!sym.is(Tracked)
&& sym.maybeOwner.isConstructor
&& (
isContextBoundWitnessWithAbstractMembers(sym, param)
|| isReferencedInPublicSignatures(sym)
// || isPassedToTrackedParentParameter(sym, param)
|| isPassedToTrackedParentParameter(sym, param)
)

/** Under x.modularity, we add `tracked` to context bound witnesses
Expand All @@ -2018,11 +2018,14 @@ class Namer { typer: Typer =>
def checkOwnerMemberSignatures(owner: Symbol): Boolean =
owner.infoOrCompleter match
case info: ClassInfo =>
info.decls.filter(d => !d.isConstructor).exists(d => tpeContainsSymbolRef(d.info, accessorSyms))
info.decls.filter(_.isTerm)
.filter(_ != sym.maybeOwner)
.exists(d => tpeContainsSymbolRef(d.info, accessorSyms))
case _ => false
checkOwnerMemberSignatures(owner)

def isPassedToTrackedParentParameter(sym: Symbol, param: ValDef)(using Context): Boolean =
// TODO(kπ) Add tracked if the param is passed as a tracked arg in parent. Can we touch the inheritance terms?
val owner = sym.maybeOwner.maybeOwner
val accessorSyms = maybeParamAccessors(owner, sym)
owner.infoOrCompleter match
Expand All @@ -2035,10 +2038,18 @@ class Namer { typer: Typer =>
case tpe: NamedType => tpe.prefix.exists && tpeContainsSymbolRef(tpe.prefix, syms)
case _ => false

private def tpeContainsSymbolRef(tpe: Type, syms: List[Symbol])(using Context): Boolean =
tpe.termSymbol.exists && syms.contains(tpe.termSymbol)
|| tpe.argInfos.exists(tpeContainsSymbolRef(_, syms))
|| namedTypeWithPrefixContainsSymbolRef(tpe, syms)
private def tpeContainsSymbolRef(tpe0: Type, syms: List[Symbol])(using Context): Boolean =
val tpe = tpe0.dropAlias.widenExpr.dealias
tpe match
case m : MethodOrPoly =>
m.paramInfos.exists(tpeContainsSymbolRef(_, syms))
|| tpeContainsSymbolRef(m.resultType, syms)
case r @ RefinedType(parent, _, refinedInfo) => tpeContainsSymbolRef(parent, syms) || tpeContainsSymbolRef(refinedInfo, syms)
case TypeBounds(lo, hi) => tpeContainsSymbolRef(lo, syms) || tpeContainsSymbolRef(hi, syms)
case t: Type =>
tpe.termSymbol.exists && syms.contains(tpe.termSymbol)
|| tpe.argInfos.exists(tpeContainsSymbolRef(_, syms))
|| namedTypeWithPrefixContainsSymbolRef(tpe, syms)

private def maybeParamAccessors(owner: Symbol, sym: Symbol)(using Context): List[Symbol] =
owner.infoOrCompleter match
Expand Down
34 changes: 34 additions & 0 deletions tests/pos/infer-tracked-1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import scala.language.experimental.modularity
import scala.language.future

trait Ordering {
type T
def compare(t1:T, t2: T): Int
}

class SetFunctor(val ord: Ordering) {
type Set = List[ord.T]
def empty: Set = Nil

implicit class helper(s: Set) {
def add(x: ord.T): Set = x :: remove(x)
def remove(x: ord.T): Set = s.filter(e => ord.compare(x, e) != 0)
def member(x: ord.T): Boolean = s.exists(e => ord.compare(x, e) == 0)
}
}

object Test {
val orderInt = new Ordering {
type T = Int
def compare(t1: T, t2: T): Int = t1 - t2
}

val IntSet = new SetFunctor(orderInt)
import IntSet.*

def main(args: Array[String]) = {
val set = IntSet.empty.add(6).add(8).add(23)
assert(!set.member(7))
assert(set.member(8))
}
}
65 changes: 65 additions & 0 deletions tests/pos/infer-tracked-parsercombinators-expanded.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import scala.language.experimental.modularity
import scala.language.future

import collection.mutable

/// A parser combinator.
trait Combinator[T]:

/// The context from which elements are being parsed, typically a stream of tokens.
type Context
/// The element being parsed.
type Element

extension (self: T)
/// Parses and returns an element from `context`.
def parse(context: Context): Option[Element]
end Combinator

final case class Apply[C, E](action: C => Option[E])
final case class Combine[A, B](first: A, second: B)

object test:

class apply[C, E] extends Combinator[Apply[C, E]]:
type Context = C
type Element = E
extension(self: Apply[C, E])
def parse(context: C): Option[E] = self.action(context)

def apply[C, E]: apply[C, E] = new apply[C, E]

class combine[A, B](
val f: Combinator[A],
val s: Combinator[B] { type Context = f.Context}
) extends Combinator[Combine[A, B]]:
type Context = f.Context
type Element = (f.Element, s.Element)
extension(self: Combine[A, B])
def parse(context: Context): Option[Element] = ???

def combine[A, B](
_f: Combinator[A],
_s: Combinator[B] { type Context = _f.Context}
) = new combine[A, B](_f, _s)
// cast is needed since the type of new combine[A, B](_f, _s)
// drops the required refinement.

extension [A] (buf: mutable.ListBuffer[A]) def popFirst() =
if buf.isEmpty then None
else try Some(buf.head) finally buf.remove(0)

@main def hello: Unit = {
val source = (0 to 10).toList
val stream = source.to(mutable.ListBuffer)

val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst())
val m = Combine(n, n)

val c = combine(
apply[mutable.ListBuffer[Int], Int],
apply[mutable.ListBuffer[Int], Int]
)
val r = c.parse(m)(stream) // was type mismatch, now OK
val rc: Option[(Int, Int)] = r
}
55 changes: 55 additions & 0 deletions tests/pos/infer-tracked-parsercombinators-givens.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import scala.language.experimental.modularity
import scala.language.future

import collection.mutable

/// A parser combinator.
trait Combinator[T]:

/// The context from which elements are being parsed, typically a stream of tokens.
type Context
/// The element being parsed.
type Element

extension (self: T)
/// Parses and returns an element from `context`.
def parse(context: Context): Option[Element]
end Combinator

final case class Apply[C, E](action: C => Option[E])
final case class Combine[A, B](first: A, second: B)

given apply[C, E]: Combinator[Apply[C, E]] with {
type Context = C
type Element = E
extension(self: Apply[C, E]) {
def parse(context: C): Option[E] = self.action(context)
}
}

given combine[A, B](using
val f: Combinator[A],
val s: Combinator[B] { type Context = f.Context }
): Combinator[Combine[A, B]] with {
type Context = f.Context
type Element = (f.Element, s.Element)
extension(self: Combine[A, B]) {
def parse(context: Context): Option[Element] = ???
}
}

extension [A] (buf: mutable.ListBuffer[A]) def popFirst() =
if buf.isEmpty then None
else try Some(buf.head) finally buf.remove(0)

@main def hello: Unit = {
val source = (0 to 10).toList
val stream = source.to(mutable.ListBuffer)

val n = Apply[mutable.ListBuffer[Int], Int](s => s.popFirst())
val m = Combine(n, n)

val r = m.parse(stream) // error: type mismatch, found `mutable.ListBuffer[Int]`, required `?1.Context`
val rc: Option[(Int, Int)] = r
// it would be great if this worked
}
82 changes: 82 additions & 0 deletions tests/pos/infer-tracked-vector.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import scala.language.experimental.modularity
import scala.language.future

object typeparams:
sealed trait Nat
object Z extends Nat
final case class S[N <: Nat]() extends Nat

type Zero = Z.type
type Succ[N <: Nat] = S[N]

sealed trait Fin[N <: Nat]
case class FZero[N <: Nat]() extends Fin[Succ[N]]
case class FSucc[N <: Nat](pred: Fin[N]) extends Fin[Succ[N]]

object Fin:
def zero[N <: Nat]: Fin[Succ[N]] = FZero()
def succ[N <: Nat](i: Fin[N]): Fin[Succ[N]] = FSucc(i)

sealed trait Vec[A, N <: Nat]
case class VNil[A]() extends Vec[A, Zero]
case class VCons[A, N <: Nat](head: A, tail: Vec[A, N]) extends Vec[A, Succ[N]]

object Vec:
def empty[A]: Vec[A, Zero] = VNil()
def cons[A, N <: Nat](head: A, tail: Vec[A, N]): Vec[A, Succ[N]] = VCons(head, tail)

def get[A, N <: Nat](v: Vec[A, N], index: Fin[N]): A = (v, index) match
case (VCons(h, _), FZero()) => h
case (VCons(_, t), FSucc(pred)) => get(t, pred)

def runVec(): Unit =
val v: Vec[Int, Succ[Succ[Succ[Zero]]]] = Vec.cons(1, Vec.cons(2, Vec.cons(3, Vec.empty)))

println(s"Element at index 0: ${Vec.get(v, Fin.zero)}")
println(s"Element at index 1: ${Vec.get(v, Fin.succ(Fin.zero))}")
println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.zero)))}")
// println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.succ(Fin.zero))))}") // error

// TODO(kπ) check if I can get it to work
// object typemembers:
// sealed trait Nat
// object Z extends Nat
// case class S() extends Nat:
// type N <: Nat

// type Zero = Z.type
// type Succ[N1 <: Nat] = S { type N = N1 }

// sealed trait Fin:
// type N <: Nat
// case class FZero[N1 <: Nat]() extends Fin:
// type N = Succ[N1]
// case class FSucc(tracked val pred: Fin) extends Fin:
// type N = Succ[pred.N]

// object Fin:
// def zero[N1 <: Nat]: Fin { type N = Succ[N1] } = FZero[N1]()
// def succ[N1 <: Nat](i: Fin { type N = N1 }): Fin { type N = Succ[N1] } = FSucc(i)

// sealed trait Vec[A]:
// type N <: Nat
// case class VNil[A]() extends Vec[A]:
// type N = Zero
// case class VCons[A](head: A, tracked val tail: Vec[A]) extends Vec[A]:
// type N = Succ[tail.N]

// object Vec:
// def empty[A]: Vec[A] = VNil()
// def cons[A](head: A, tail: Vec[A]): Vec[A] = VCons(head, tail)

// def get[A](v: Vec[A], index: Fin { type N = v.N }): A = (v, index) match
// case (VCons(h, _), FZero()) => h
// case (VCons(_, t), FSucc(pred)) => get(t, pred)

// // def runVec(): Unit =
// val v: Vec[Int] = Vec.cons(1, Vec.cons(2, Vec.cons(3, Vec.empty)))

// println(s"Element at index 0: ${Vec.get(v, Fin.zero)}")
// println(s"Element at index 1: ${Vec.get(v, Fin.succ(Fin.zero))}")
// println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.zero)))}")
// // println(s"Element at index 2: ${Vec.get(v, Fin.succ(Fin.succ(Fin.succ(Fin.zero))))}")
28 changes: 26 additions & 2 deletions tests/pos/infer-tracked.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,37 @@ class F(val x: C):

class G(override val x: C) extends F(x)

class H(val x: C):
type T1 = x.T
val result: T1 = x.foo

class I(val c: C, val t: c.T)

case class J(c: C):
val result: c.T = c.foo

case class K(c: C):
def result[B >: c.T]: B = c.foo

def Test =
val c = new C:
type T = Int
def foo = 42

val f = new F(c)
val i: Int = f.result
val _: Int = f.result

// val g = new G(c)
// val j: Int = g.result
// val _: Int = g.result

val h = new H(c)
val _: Int = h.result

val i = new I(c, c.foo)
val _: Int = i.t

val j = J(c)
val _: Int = j.result

val k = K(c)
val _: Int = k.result

0 comments on commit 4bc5d24

Please sign in to comment.