Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement unapply #41

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ project/local-plugins.sbt
# VS Code
.vscode/

# IntelliJ
.idea/

# Dotty IDE
/.dotty-ide-dev-port
/.dotty-ide-artifact
Expand Down
140 changes: 140 additions & 0 deletions src/main/scala/dotty/xml/interpolator/internal/FillPlaceholders.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package dotty.xml.interpolator
package internal

import scala.language.implicitConversions
import scala.quoted.*
import dotty.xml.interpolator.internal.Tree.*

import scala.annotation.tailrec
import scala.xml.Node.unapplySeq
import scala.xml.{NamespaceBinding, SpecialNode}

type RetElem = scala.xml.Node | String

object FillPlaceholders {
// Assumes that placeholder indices are in increasing order left to right
def apply(nodes: Seq[Node], refNodes: Seq[scala.xml.Node]): Option[Seq[Any]] = {
findInNodes(nodes, refNodes)
}

// string equality where (null == "")
def strEq(s1: String, s2: String): Boolean =
(if s1 == null then "" else s1) == (if s2 == null then "" else s2)

// expects comments to be removed
private def findInNode(node: Node, refNode: scala.xml.Node): Option[Seq[RetElem]] = {
(node, refNode) match {
case (group: Group, refGroup: scala.xml.Group) => findInGroup(group, refGroup)
case (elem: Elem, refElem: scala.xml.Elem) => findInElem(elem, refElem)
case (text: Text, refText: scala.xml.Text) => findInText(text, refText)
case (placeholder: Placeholder, data) => fillPlaceholder(placeholder, data)
case (pcData: PCData, refPCData: scala.xml.PCData) => findInPCData(pcData, refPCData)
case (procInstr: ProcInstr, refProcInstr: scala.xml.ProcInstr) => findInProcInstr(procInstr, refProcInstr)
case (entityRef: EntityRef, refEntityRef: scala.xml.EntityRef) => findInEntityRef(entityRef, refEntityRef)
case (unparsed: Unparsed, refUnparsed: scala.xml.Unparsed) => findInUnparsed(unparsed, refUnparsed)

// assertions
case (_: Comment, _) => throw AssertionError("comment passed as first argument to findInNode")
case (_, _: scala.xml.Comment) => throw AssertionError("comment passed as second argument to findInNode")

// non-matching nodes
case _ => None
}
}

private def findInNodes(nodes: Seq[Node], refNodes: Seq[scala.xml.Node]): Option[Seq[RetElem]] = {
// remove comments
val nodes1 = nodes.filter(!_.isInstanceOf[Comment])
val refNodes1 = refNodes.filter(!_.isInstanceOf[scala.xml.Comment])

if nodes1.length != refNodes1.length then
return None

Some(nodes1
.zip(refNodes1)
.foldLeft(Seq[RetElem]()) { case (seq, (node, ref)) =>
findInNode(node, ref) match
case None => return None
case Some(fromNode) => seq ++ fromNode
})
}

private def findInGroup(group: Group, refGroup: scala.xml.Group): Option[Seq[RetElem]] =
findInNodes(group.nodes, refGroup.nodes)

private def findInElem(elem: Elem, refElem: scala.xml.Elem): Option[Seq[RetElem]] = {
if !strEq(elem.prefix, refElem.prefix) then
return None

if !strEq(elem.label, refElem.label) then
return None

for
fromAttributes <- findInAttributes(elem.attributes, refElem)
fromChildren <- findInNodes(elem.children, refElem.child)
yield
fromAttributes ++ fromChildren
}

private def findInAttributes(attributes: Seq[Attribute], refElem: scala.xml.Elem): Option[Seq[RetElem]] = {
Some(attributes
.foldLeft(Seq[RetElem]()) { (seq, attr) =>
refElem.attributes.find(_.prefixedKey == attr.name) match
case Some(refAttr) =>
findInNodes(attr.value, refAttr.value) match
case None => return None
case Some(fromNodes) => seq ++ fromNodes
case None if attr.prefix == "xmlns" =>
findInBindings(attr.key, attr.value, refElem.scope) match
case None => return None
case Some(nb) => seq ++ nb
case None if attr.prefix == "" && attr.key == "xmlns" =>
findInBindings("", attr.value, refElem.scope) match
case None => return None
case Some(nb) => seq ++ nb
case None =>
return None
})
}

private def findInBindings(key: String, value: Seq[Node], refBindings: scala.xml.NamespaceBinding): Option[Seq[String]] =
extension (scope: NamespaceBinding) @tailrec def find(p: NamespaceBinding => Boolean): Option[NamespaceBinding] = {
if scope == scala.xml.TopScope then
None
else if p(scope) then
Some(scope)
else
scope.parent.find(p)
}

refBindings.find(nb => strEq(nb.prefix, key)) match
case None => None
case Some(refNsBinding) =>
findInNsBinding(value, refNsBinding)

private def findInNsBinding(nodes: Seq[Node], refBinding: scala.xml.NamespaceBinding): Option[Seq[String]] =
nodes match
case Seq(Text(text)) if text == refBinding.uri =>
Some(Nil)
case Seq(p@Placeholder(_)) =>
fillPlaceholder(p, refBinding.uri)
case _ => None

private def findInText(text: Text, refText: scala.xml.Text): Option[Seq[Nothing]] =
if text.text == refText.text then Some(Nil) else None

private def fillPlaceholder(placeholder: Placeholder, data: scala.xml.Node | String): Option[Seq[data.type]] =
Some(Seq(data))

private def findInPCData(pcdata: PCData, refPCData: scala.xml.PCData): Option[Seq[Nothing]] =
if pcdata.data == refPCData.data then Some(Nil) else None

private def findInProcInstr(instr: ProcInstr, refInstr: scala.xml.ProcInstr): Option[Seq[Nothing]] =
if instr.target == refInstr.target then Some(Nil) else None

private def findInEntityRef(ref: EntityRef, refRef: scala.xml.EntityRef): Option[Seq[Nothing]] =
if ref.name == refRef.entityName then Some(Nil) else None

private def findInUnparsed(unparsed: Unparsed, refUnparsed: scala.xml.Unparsed): Option[Seq[Nothing]] =
if unparsed.data == refUnparsed.data then Some(Nil) else None
}
117 changes: 97 additions & 20 deletions src/main/scala/dotty/xml/interpolator/internal/Macro.scala
Original file line number Diff line number Diff line change
@@ -1,33 +1,111 @@
package dotty.xml.interpolator
package internal

import scala.quoted._

import scala.annotation.tailrec
import scala.quoted.*
import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
import scala.xml.Node.unapplySeq

object Macro {

def impl(strCtxExpr: Expr[StringContext], argsExpr: Expr[Seq[Scope ?=> Any]], scope: Expr[Scope])(using qctx: Quotes): Expr[scala.xml.Node | scala.xml.NodeBuffer] = {
((strCtxExpr, argsExpr): @unchecked) match {
case ('{ StringContext(${Varargs(parts)}: _*) }, Varargs(args)) =>
val (xmlStr, offsets) = encode(parts)
implicit val ctx: XmlContext = new XmlContext(args, scope)
implicit val reporter: Reporter = new Reporter {
import quotes.reflect._
def impl(xmlStrCtxExpr: Expr[XML.StringContext], argsExpr: Expr[Seq[Scope ?=> Any]], scope: Expr[Scope])(using qctx: Quotes): Expr[scala.xml.Node | scala.xml.NodeBuffer] = {
val '{ xml(StringContext(${Varargs(parts)}: _*)) } = xmlStrCtxExpr
val Varargs(args) = argsExpr
val (xmlStr, offsets) = encode(parts)

given XmlContext = new XmlContext(args, scope)
given Reporter = new Reporter {
import quotes.reflect._

def error(msg: String, idx: Int): Unit = {
val (part, offset) = Reporter.from(idx, offsets, parts)
val pos = part.asTerm.pos
val (srcF, start) = (pos.sourceFile, pos.start)
report.error(msg, Position(srcF, start + offset, start + offset + 1))
}

def error(msg: String, expr: Expr[Any]): Unit = {
report.error(msg, expr)
}
}
implCore(xmlStr)
}

def error(msg: String, idx: Int): Unit = {
val (part, offset) = Reporter.from(idx, offsets, parts)
val pos = part.asTerm.pos
val (srcF, start) = (pos.sourceFile, pos.start)
report.error(msg, Position(srcF, start + offset, start + offset + 1))
}
def implUnapply(xmlStrCtxExpr: Expr[XML.StringContext], elemExpr: Expr[scala.xml.Node | scala.xml.NodeBuffer], scope: Expr[Scope])(using Quotes): Expr[Option[Seq[Any]]] = {
val '{ xml(StringContext(${Varargs(parts)}: _*)) } = xmlStrCtxExpr
val (xmlStr, offsets) = encode(parts)

def error(msg: String, expr: Expr[Any]): Unit = {
report.error(msg, expr)
}
given Reporter = new Reporter {
import quotes.reflect._

def error(msg: String, idx: Int): Unit = {
val (part, offset) = Reporter.from(idx, offsets, parts)
val pos = part.asTerm.pos
val (srcF, start) = (pos.sourceFile, pos.start)
report.error(msg, Position(srcF, start + offset, start + offset + 1))
}

def error(msg: String, expr: Expr[Any]): Unit = {
report.error(msg, expr)
}
}

val parsed = {
import Parse.{apply => parse}
import Transform.{apply => transform}
import Validate.{apply => validate}

val process = (
parse
andThen transform
andThen validate
)

process(xmlStr)
}

import scala.quoted.ToExpr.SeqToExpr

given ToExpr[Tree.Attribute] with
def apply(attr: Tree.Attribute)(using Quotes): Expr[Tree.Attribute] =
val Tree.Attribute(name, value) = attr
val valueExpr = Expr[Seq[Tree.Node]](value)
'{ Tree.Attribute(${ Expr(name) }, ${ valueExpr }) }

given ToExpr[Tree.Node] with
def apply(node: Tree.Node)(using Quotes): Expr[Tree.Node] =
node match {
case Tree.Group(nodes) =>
'{ Tree.Group(${ Expr(nodes) }) }
case Tree.Elem(name, attrs, children, end) =>
'{ Tree.Elem(${ Expr(name) }, ${ Expr(attrs) }, ${ Expr(children) }, ${ Expr(end) }) }
case Tree.Text(text) =>
'{ Tree.Text(${ Expr(text) }) }
case Tree.Comment(text) =>
'{ Tree.Comment(${ Expr(text) }) }
case Tree.Placeholder(id) =>
'{ Tree.Placeholder(${ Expr(id) }) }
case Tree.PCData(data) =>
'{ Tree.PCData(${ Expr(data) }) }
case Tree.ProcInstr(target, proctext) =>
'{ Tree.ProcInstr(${ Expr(target) }, ${ Expr(proctext) }) }
case Tree.EntityRef(name) =>
'{ Tree.EntityRef(${ Expr(name) }) }
case Tree.Unparsed(data) =>
'{ Tree.Unparsed(${ Expr(data) }) }
}
implCore(xmlStr)

import FillPlaceholders.{apply => fill_placeholders}

'{
val nodes = ${elemExpr} match {
case e: scala.xml.Node => Seq(e)
case e: scala.xml.NodeBuffer => e
}

fill_placeholders(${ Expr(parsed) }, nodes)
}
}

Expand Down Expand Up @@ -57,7 +135,6 @@ object Macro {
}

private def implCore(xmlStr: String)(using XmlContext, Reporter, Quotes): Expr[scala.xml.Node | scala.xml.NodeBuffer] = {

import Parse.{apply => parse}
import Transform.{apply => transform}
import Validate.{apply => validate}
Expand All @@ -81,7 +158,7 @@ object Macro {

def appendPart(part: Expr[String]) = {
bf += sb.length
sb ++= part.valueOrError
sb ++= part.valueOrAbort
bf += sb.length
}

Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/dotty/xml/interpolator/internal/Parse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -186,5 +186,5 @@ object Parse extends JavaTokenParsers with TokenTests {

private def Eq = S.? ~ "=" ~ S.?

private def Placeholder = positioned(HoleStart ~ HoleChar.* ^^ { case char ~ chars => Tree.Placeholder((char :: chars).length -1) })
private def Placeholder = positioned(HoleStart ~ HoleChar.* ^^ { case char ~ chars => Tree.Placeholder((char :: chars).length - 1) })
}
13 changes: 11 additions & 2 deletions src/main/scala/dotty/xml/interpolator/package.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
package dotty.xml.interpolator

import scala.quoted._
import scala.quoted.*

type Scope = scala.xml.NamespaceBinding
implicit val top: Scope = scala.xml.TopScope

extension (inline ctx: StringContext) transparent inline def xml (inline args: (Scope ?=> Any)*)(using scope: Scope): Any =
object XML:
opaque type StringContext = scala.StringContext
def apply(ctx: scala.StringContext): StringContext = ctx

extension (ctx: StringContext) def xml: XML.StringContext = XML(ctx)

extension (inline ctx: XML.StringContext) transparent inline def apply(inline args: (Scope ?=> Any)*)(using scope: Scope): Any =
${ dotty.xml.interpolator.internal.Macro.impl('ctx, 'args, 'scope) }

extension (inline ctx: XML.StringContext) transparent inline def unapplySeq(inline elem: scala.xml.Node | scala.xml.NodeBuffer)(using scope: Scope): Option[Seq[Any]] =
${ dotty.xml.interpolator.internal.Macro.implUnapply('ctx, 'elem, 'scope) }
Loading