From 1c36605e5cb245a2f7a7698f26c23c5ea8dc555a Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Tue, 24 Sep 2024 16:12:28 +0200 Subject: [PATCH] Correctly-ish desugar poly function context bounds in function types --- .../src/dotty/tools/dotc/ast/Desugar.scala | 23 +++++++++++-------- .../src/dotty/tools/dotc/typer/Typer.scala | 19 ++++++++------- .../contextbounds-for-poly-functions.scala | 9 ++++---- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index f34a552ecd37..ab7e9870e7db 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1214,31 +1214,36 @@ object desugar { */ def expandPolyFunctionContextBounds(tree: PolyFunction)(using Context): PolyFunction = val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ Function(vparamTypes, res)) = tree: @unchecked - val newTParams = tparams.map { + val newTParams = tparams.mapConserve { case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) => TypeDef(name, ContextBounds(bounds, List.empty)) + case t => t } var idx = 0 - val collecedContextBounds = tparams.collect { + val collectedContextBounds = tparams.collect { case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) if ctxBounds.nonEmpty => - // TOOD(kπ) Should we handle non empty normal bounds here? name -> ctxBounds }.flatMap { case (name, ctxBounds) => ctxBounds.map { ctxBound => idx = idx + 1 ctxBound match - case ContextBoundTypeTree(_, _, ownName) => - ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given) + case ctxBound @ ContextBoundTypeTree(tycon, paramName, ownName) => + if tree.isTerm then + ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given) + else + ContextBoundTypeTree(tycon, paramName, EmptyTermName) // this has to be handled in Typer#typedFunctionType case _ => makeSyntheticParameter(idx, ctxBound).withAddedFlags(Given) } } val contextFunctionResult = - if collecedContextBounds.isEmpty then - fun + if collectedContextBounds.isEmpty then fun else - Function(vparamTypes, Function(collecedContextBounds, res)).withSpan(fun.span) - PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span) + val mods = EmptyModifiers.withFlags(Given) + val erasedParams = collectedContextBounds.map(_ => false) + Function(vparamTypes, FunctionWithMods(collectedContextBounds, res, mods, erasedParams)).withSpan(fun.span) + if collectedContextBounds.isEmpty then tree + else PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span) /** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R * Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 521fd962c0c7..bf55103aa7a1 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -40,7 +40,7 @@ import annotation.tailrec import Implicits.* import util.Stats.record import config.Printers.{gadts, typr} -import config.Feature, Feature.{migrateTo3, modularity, sourceVersion, warnOnMigration} +import config.Feature, Feature.{migrateTo3, sourceVersion, warnOnMigration} import config.SourceVersion.* import rewrites.Rewrites, Rewrites.patch import staging.StagingLevel @@ -1142,7 +1142,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if templ1.parents.isEmpty && isFullyDefined(pt, ForceDegree.flipBottom) && isSkolemFree(pt) - && isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(modularity))) + && isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity))) then templ1 = cpy.Template(templ)(parents = untpd.TypeTree(pt) :: Nil) for case parent: RefTree <- templ1.parents do @@ -1717,7 +1717,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typedFunctionType(desugar.makeFunctionWithValDefs(tree, pt), pt) else val funSym = defn.FunctionSymbol(numArgs, isContextual, isImpure) - val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt) + val args1 = args.mapConserve { + case cb: untpd.ContextBoundTypeTree => typed(cb) + case t => t + } + val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args1 :+ body), pt) // if there are any erased classes, we need to re-do the typecheck. result match case r: AppliedTypeTree if r.args.exists(_.tpe.isErasedClass) => @@ -2458,12 +2462,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer if tycon.tpe.typeParams.nonEmpty then val tycon0 = tycon.withType(tycon.tpe.etaCollapse) typed(untpd.AppliedTypeTree(spliced(tycon0), tparam :: Nil)) - else if Feature.enabled(modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then + else if Feature.enabled(Feature.modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then val tparamSplice = untpd.TypedSplice(typedExpr(tparam)) typed(untpd.RefinedTypeTree(spliced(tycon), List(untpd.TypeDef(tpnme.Self, tparamSplice)))) else def selfNote = - if Feature.enabled(modularity) then + if Feature.enabled(Feature.modularity) then " and\ndoes not have an abstract type member named `Self` either" else "" errorTree(tree, @@ -2482,7 +2486,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val TypeDef(_, impl: Template) = typed(refineClsDef): @unchecked val refinements1 = impl.body val seen = mutable.Set[Symbol]() - for (refinement <- refinements1) { // TODO: get clarity whether we want to enforce these conditions + for refinement <- refinements1 do // TODO: get clarity whether we want to enforce these conditions typr.println(s"adding refinement $refinement") checkRefinementNonCyclic(refinement, refineCls, seen) val rsym = refinement.symbol @@ -2496,7 +2500,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer val member = refineCls.info.member(rsym.name) if (member.isOverloaded) report.error(OverloadInRefinement(rsym), refinement.srcPos) - } assignType(cpy.RefinedTypeTree(tree)(tpt1, refinements1), tpt1, refinements1, refineCls) } @@ -4693,7 +4696,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer cpy.Ident(qual)(qual.symbol.name.sourceModuleName.toTypeName) case _ => errorTree(tree, em"cannot convert from $tree to an instance creation expression") - val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(modularity)) + val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity)) typed( untpd.Select( untpd.New(untpd.TypedSplice(tpt.withType(tycon))), diff --git a/tests/pos/contextbounds-for-poly-functions.scala b/tests/pos/contextbounds-for-poly-functions.scala index 00feedd66d71..90bd01ce6b6d 100644 --- a/tests/pos/contextbounds-for-poly-functions.scala +++ b/tests/pos/contextbounds-for-poly-functions.scala @@ -5,12 +5,13 @@ import scala.language.future trait Ord[X]: def compare(x: X, y: X): Int -val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 +// val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0 -val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 +// val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 -// type Comparer = [X: Ord] => (x: X, y: X) => Boolean -// val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 +type ComparerRef = [X] => (x: X, y: X) => Ord[X] ?=> Boolean +type Comparer = [X: Ord] => (x: X, y: X) => Boolean +val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0 // type Cmp[X] = (x: X, y: X) => Boolean // type Comparer2 = [X: Ord] => Cmp[X]