Skip to content

Commit

Permalink
Correctly-ish desugar poly function context bounds in function types
Browse files Browse the repository at this point in the history
  • Loading branch information
KacperFKorban committed Sep 24, 2024
1 parent 4060cf2 commit 1c36605
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 21 deletions.
23 changes: 14 additions & 9 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1214,31 +1214,36 @@ object desugar {
*/
def expandPolyFunctionContextBounds(tree: PolyFunction)(using Context): PolyFunction =
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ Function(vparamTypes, res)) = tree: @unchecked
val newTParams = tparams.map {
val newTParams = tparams.mapConserve {
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) =>
TypeDef(name, ContextBounds(bounds, List.empty))
case t => t
}
var idx = 0
val collecedContextBounds = tparams.collect {
val collectedContextBounds = tparams.collect {
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) if ctxBounds.nonEmpty =>
// TOOD(kπ) Should we handle non empty normal bounds here?
name -> ctxBounds
}.flatMap { case (name, ctxBounds) =>
ctxBounds.map { ctxBound =>
idx = idx + 1
ctxBound match
case ContextBoundTypeTree(_, _, ownName) =>
ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given)
case ctxBound @ ContextBoundTypeTree(tycon, paramName, ownName) =>
if tree.isTerm then
ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given)
else
ContextBoundTypeTree(tycon, paramName, EmptyTermName) // this has to be handled in Typer#typedFunctionType
case _ =>
makeSyntheticParameter(idx, ctxBound).withAddedFlags(Given)
}
}
val contextFunctionResult =
if collecedContextBounds.isEmpty then
fun
if collectedContextBounds.isEmpty then fun
else
Function(vparamTypes, Function(collecedContextBounds, res)).withSpan(fun.span)
PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span)
val mods = EmptyModifiers.withFlags(Given)
val erasedParams = collectedContextBounds.map(_ => false)
Function(vparamTypes, FunctionWithMods(collectedContextBounds, res, mods, erasedParams)).withSpan(fun.span)
if collectedContextBounds.isEmpty then tree
else PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span)

/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
Expand Down
19 changes: 11 additions & 8 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import annotation.tailrec
import Implicits.*
import util.Stats.record
import config.Printers.{gadts, typr}
import config.Feature, Feature.{migrateTo3, modularity, sourceVersion, warnOnMigration}
import config.Feature, Feature.{migrateTo3, sourceVersion, warnOnMigration}
import config.SourceVersion.*
import rewrites.Rewrites, Rewrites.patch
import staging.StagingLevel
Expand Down Expand Up @@ -1142,7 +1142,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
if templ1.parents.isEmpty
&& isFullyDefined(pt, ForceDegree.flipBottom)
&& isSkolemFree(pt)
&& isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(modularity)))
&& isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity)))
then
templ1 = cpy.Template(templ)(parents = untpd.TypeTree(pt) :: Nil)
for case parent: RefTree <- templ1.parents do
Expand Down Expand Up @@ -1717,7 +1717,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
typedFunctionType(desugar.makeFunctionWithValDefs(tree, pt), pt)
else
val funSym = defn.FunctionSymbol(numArgs, isContextual, isImpure)
val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt)
val args1 = args.mapConserve {
case cb: untpd.ContextBoundTypeTree => typed(cb)
case t => t
}
val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args1 :+ body), pt)
// if there are any erased classes, we need to re-do the typecheck.
result match
case r: AppliedTypeTree if r.args.exists(_.tpe.isErasedClass) =>
Expand Down Expand Up @@ -2458,12 +2462,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
if tycon.tpe.typeParams.nonEmpty then
val tycon0 = tycon.withType(tycon.tpe.etaCollapse)
typed(untpd.AppliedTypeTree(spliced(tycon0), tparam :: Nil))
else if Feature.enabled(modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then
else if Feature.enabled(Feature.modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then
val tparamSplice = untpd.TypedSplice(typedExpr(tparam))
typed(untpd.RefinedTypeTree(spliced(tycon), List(untpd.TypeDef(tpnme.Self, tparamSplice))))
else
def selfNote =
if Feature.enabled(modularity) then
if Feature.enabled(Feature.modularity) then
" and\ndoes not have an abstract type member named `Self` either"
else ""
errorTree(tree,
Expand All @@ -2482,7 +2486,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val TypeDef(_, impl: Template) = typed(refineClsDef): @unchecked
val refinements1 = impl.body
val seen = mutable.Set[Symbol]()
for (refinement <- refinements1) { // TODO: get clarity whether we want to enforce these conditions
for refinement <- refinements1 do // TODO: get clarity whether we want to enforce these conditions
typr.println(s"adding refinement $refinement")
checkRefinementNonCyclic(refinement, refineCls, seen)
val rsym = refinement.symbol
Expand All @@ -2496,7 +2500,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
val member = refineCls.info.member(rsym.name)
if (member.isOverloaded)
report.error(OverloadInRefinement(rsym), refinement.srcPos)
}
assignType(cpy.RefinedTypeTree(tree)(tpt1, refinements1), tpt1, refinements1, refineCls)
}

Expand Down Expand Up @@ -4693,7 +4696,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
cpy.Ident(qual)(qual.symbol.name.sourceModuleName.toTypeName)
case _ =>
errorTree(tree, em"cannot convert from $tree to an instance creation expression")
val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(modularity))
val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity))
typed(
untpd.Select(
untpd.New(untpd.TypedSplice(tpt.withType(tycon))),
Expand Down
9 changes: 5 additions & 4 deletions tests/pos/contextbounds-for-poly-functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import scala.language.future
trait Ord[X]:
def compare(x: X, y: X): Int

val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
// val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
// val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0

// type Comparer = [X: Ord] => (x: X, y: X) => Boolean
// val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
type ComparerRef = [X] => (x: X, y: X) => Ord[X] ?=> Boolean
type Comparer = [X: Ord] => (x: X, y: X) => Boolean
val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0

// type Cmp[X] = (x: X, y: X) => Boolean
// type Comparer2 = [X: Ord] => Cmp[X]
Expand Down

0 comments on commit 1c36605

Please sign in to comment.