Skip to content

Commit

Permalink
Implement basic version of desugaring context bounds for poly functions
Browse files Browse the repository at this point in the history
  • Loading branch information
KacperFKorban committed Sep 24, 2024
1 parent b8c5ecb commit d653f5a
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 4 deletions.
27 changes: 27 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,33 @@ object desugar {
case _ => body
cpy.PolyFunction(tree)(tree.targs, stripped(tree.body)).asInstanceOf[PolyFunction]

/** Desugar [T_1 : B_1, ..., T_N : B_N] => (P_1, ..., P_M) => R
* Into [T_1, ..., T_N] => (P_1, ..., P_M) => (B_1, ..., B_N) ?=> R
*/
def expandPolyFunctionContextBounds(tree: PolyFunction)(using Context): PolyFunction =
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ Function(vparamTypes, res)) = tree: @unchecked
val newTParams = tparams.map {
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) =>
TypeDef(name, ContextBounds(bounds, List.empty))
}
var idx = -1
val collecedContextBounds = tparams.collect {
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) if ctxBounds.nonEmpty =>
// TOOD(kπ) Should we handle non empty normal bounds here?
name -> ctxBounds
}.flatMap { case (name, ctxBounds) =>
ctxBounds.map { ctxBound =>
idx = idx + 1
makeSyntheticParameter(idx, ctxBound).withAddedFlags(Given)
}
}
val contextFunctionResult =
if collecedContextBounds.isEmpty then
fun
else
Function(vparamTypes, Function(collecedContextBounds, res)).withSpan(fun.span)
PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span)

/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }
*/
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ object Parsers {
def acceptsVariance =
this == Class || this == CaseClass || this == Hk
def acceptsCtxBounds =
!(this == Type || this == Hk)
!(this == Hk)
def acceptsWildcard =
this == Type || this == Hk

Expand Down Expand Up @@ -3421,7 +3421,7 @@ object Parsers {
*
* TypTypeParamClause::= ‘[’ TypTypeParam {‘,’ TypTypeParam} ‘]’
* TypTypeParam ::= {Annotation}
* (id | ‘_’) [HkTypeParamClause] TypeBounds
* (id | ‘_’) [HkTypeParamClause] TypeAndCtxBounds
*
* HkTypeParamClause ::= ‘[’ HkTypeParam {‘,’ HkTypeParam} ‘]’
* HkTypeParam ::= {Annotation} [‘+’ | ‘-’]
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1926,8 +1926,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer

def typedPolyFunction(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val tree1 = desugar.normalizePolyFunction(tree)
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree1), pt)
else typedPolyFunctionValue(tree1, pt)
val tree2 = desugar.expandPolyFunctionContextBounds(tree1)
if (ctx.mode is Mode.Type) typed(desugar.makePolyFunctionType(tree2), pt)
else typedPolyFunctionValue(tree2, pt)

def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree =
val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked
Expand Down
15 changes: 15 additions & 0 deletions tests/pos/contextbounds-for-poly-functions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import scala.language.experimental.modularity
import scala.language.future


trait Ord[X]:
def compare(x: X, y: X): Int

val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

// type Comparer = [X: Ord] => (x: X, y: X) => Boolean
// val less2: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0

// type Cmp[X] = (x: X, y: X) => Boolean
// type Comparer2 = [X: Ord] => Cmp[X]
// val less3: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0

0 comments on commit d653f5a

Please sign in to comment.