[CORE] Pullout pre/post project for generate (apache#4952)
marin-ma authored Mar 25, 2024
1 parent cb63b38 commit 59ff500
Showing 12 changed files with 191 additions and 150 deletions.
@@ -766,4 +766,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
): GenerateExecTransformerBase = {
CHGenerateExecTransformer(generator, requiredChildOutput, outer, generatorOutput, child)
}

+ override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate

+ override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate
}
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning, RoundRobinPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
- import org.apache.spark.sql.execution.{BroadcastUtils, ColumnarBuildSideRelation, ColumnarShuffleExchangeExec, SparkPlan, VeloxColumnarWriteFilesExec}
+ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.BuildSideRelation
@@ -677,4 +677,12 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
): GenerateExecTransformerBase = {
GenerateExecTransformer(generator, requiredChildOutput, outer, generatorOutput, child)
}

+ override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = {
+ PullOutGenerateProjectHelper.pullOutPreProject(generate)
+ }

+ override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = {
+ PullOutGenerateProjectHelper.pullOutPostProject(generate)
+ }
}
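
Editor's illustration (not part of this commit): the two hooks added above are backend extension points, so the Velox backend can wrap GenerateExec with a pulled-out pre-project and post-project while the ClickHouse backend leaves the plan untouched. The sketch below shows how the hooks are meant to compose; the wrapper object and function are hypothetical, and the real wiring presumably lives in the pull-out rules among the changed files not shown here.

import io.glutenproject.backendsapi.BackendsApiManager
import org.apache.spark.sql.execution.{GenerateExec, SparkPlan}

object GeneratePullOutSketch {
  // Hypothetical helper: apply the backend's pre-project hook, then the
  // post-project hook if the plan is still a GenerateExec.
  def pullOutProjects(generate: GenerateExec): SparkPlan = {
    val api = BackendsApiManager.getSparkPlanExecApiInstance
    // Velox returns a GenerateExec whose child is a ProjectExec feeding the
    // generator a plain attribute; ClickHouse returns `generate` unchanged.
    api.genPreProjectForGenerate(generate) match {
      case g: GenerateExec =>
        // Velox wraps the GenerateExec in a ProjectExec that restores Spark's
        // expected output shape; ClickHouse again returns the plan unchanged.
        api.genPostProjectForGenerate(g)
      case other => other
    }
  }
}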
@@ -17,25 +17,24 @@
package io.glutenproject.execution

import io.glutenproject.backendsapi.BackendsApiManager
- import io.glutenproject.expression.{ConverterUtils, ExpressionConverter, ExpressionNames}
- import io.glutenproject.expression.ConverterUtils.FunctionConfig
+ import io.glutenproject.execution.GenerateExecTransformer.supportsGenerate
import io.glutenproject.extension.ValidationResult
import io.glutenproject.metrics.{GenerateMetricsUpdater, MetricsUpdater}
- import io.glutenproject.substrait.`type`.TypeBuilder
import io.glutenproject.substrait.SubstraitContext
- import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}
+ import io.glutenproject.substrait.expression.ExpressionNode
import io.glutenproject.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder}
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
+ import io.glutenproject.utils.PullOutProjectHelper

import org.apache.spark.sql.catalyst.expressions._
- import org.apache.spark.sql.execution.SparkPlan
+ import org.apache.spark.sql.execution.{GenerateExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
- import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructType}
+ import org.apache.spark.sql.types.IntegerType

- import com.google.common.collect.Lists
import com.google.protobuf.StringValue

import scala.collection.JavaConverters._
+ import scala.collection.mutable

case class GenerateExecTransformer(
generator: Generator,
@@ -62,26 +61,12 @@ case class GenerateExecTransformer(
override protected def doGeneratorValidate(
generator: Generator,
outer: Boolean): ValidationResult = {
- if (outer) {
- return ValidationResult.notOk(s"Velox backend does not support outer")
- }
- generator match {
- case _: JsonTuple =>
- ValidationResult.notOk(s"Velox backend does not support this json_tuple")
- case _: ExplodeBase =>
- ValidationResult.ok
- case Inline(child) =>
- child match {
- case AttributeReference(_, ArrayType(_: StructType, _), _, _) =>
- ValidationResult.ok
- case _ =>
- // TODO: Support Literal/CreateArray.
- ValidationResult.notOk(
- s"Velox backend does not support inline with expression " +
- s"${child.getClass.getSimpleName}.")
- }
- case _ =>
- ValidationResult.ok
+ if (!supportsGenerate(generator, outer)) {
+ ValidationResult.notOk(
+ s"Velox backend does not support this generator: ${generator.getClass.getSimpleName}" +
+ s", outer: $outer")
+ } else {
+ ValidationResult.ok
}
}

@@ -91,30 +76,13 @@ case class GenerateExecTransformer(
generatorNode: ExpressionNode,
validation: Boolean): RelNode = {
val operatorId = context.nextOperatorId(this.nodeName)

- val newInput = if (!validation) {
- applyPreProject(inputRel, context, operatorId)
- } else {
- // No need to validate the pre-projection. The generator output has been validated in
- // doGeneratorValidate.
- inputRel
- }

- val generateRel = RelBuilder.makeGenerateRel(
- newInput,
+ RelBuilder.makeGenerateRel(
+ inputRel,
generatorNode,
requiredChildOutputNodes.asJava,
getExtensionNode(validation),
context,
operatorId)

- if (!validation) {
- applyPostProject(generateRel, context, operatorId)
- } else {
- // No need to validate the post-projection on the native side as
- // it only flattens the generator's output.
- generateRel
- }
}

private def getExtensionNode(validation: Boolean): AdvancedExtensionNode = {
@@ -141,92 +109,95 @@ case class GenerateExecTransformer(
getExtensionNodeForValidation
}
}
+ }

- // Select child outputs and append generator input.
- private def applyPreProject(
- inputRel: RelNode,
- context: SubstraitContext,
- operatorId: Long
- ): RelNode = {
- val projectExpressions: Seq[ExpressionNode] =
- child.output.indices
- .map(ExpressionBuilder.makeSelection(_)) :+
- ExpressionConverter
- .replaceWithExpressionTransformer(
- generator.asInstanceOf[UnaryExpression].child,
- child.output)
- .doTransform(context.registeredFunction)
+ object GenerateExecTransformer {
+ def supportsGenerate(generator: Generator, outer: Boolean): Boolean = {
+ // TODO: support outer and remove this param.
+ if (outer) {
+ false
+ } else {
+ generator match {
+ case _: Inline | _: ExplodeBase =>
+ true
+ case _ =>
+ false
+ }
+ }
+ }
+ }

- RelBuilder.makeProjectRel(
- inputRel,
- projectExpressions.asJava,
- context,
- operatorId,
- child.output.size)
+ object PullOutGenerateProjectHelper extends PullOutProjectHelper {
+ def pullOutPreProject(generate: GenerateExec): SparkPlan = {
+ if (GenerateExecTransformer.supportsGenerate(generate.generator, generate.outer)) {
+ val newGeneratorChildren = generate.generator match {
+ case _: Inline | _: ExplodeBase =>
+ val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
+ // The new child should be either the original Attribute,
+ // or an Alias to other expressions.
+ val generatorAttr = replaceExpressionWithAttribute(
+ generate.generator.asInstanceOf[UnaryExpression].child,
+ expressionMap,
+ replaceBoundReference = true)
+ val newGeneratorChild = if (expressionMap.isEmpty) {
+ // generator.child is an Attribute
+ generatorAttr.asInstanceOf[Attribute]
+ } else {
+ // generator.child is another expression, e.g. Literal/CreateArray/CreateMap
+ expressionMap.values.head
+ }
+ Seq(newGeneratorChild)
+ case _ =>
+ // Unreachable.
+ throw new IllegalStateException(
+ s"Generator ${generate.generator.getClass.getSimpleName} is not supported.")
+ }
+ // Avoid using eliminateProjectList to create the project list
+ // because newGeneratorChild can be a duplicated Attribute in generate.child.output.
+ // The native side identifies the last field of the projection as the generator's input.
+ generate.copy(
+ generator =
+ generate.generator.withNewChildren(newGeneratorChildren).asInstanceOf[Generator],
+ child = ProjectExec(generate.child.output ++ newGeneratorChildren, generate.child)
+ )
+ } else {
+ generate
+ }
+ }

- // There are 3 types of CollectionGenerator in spark: Explode, PosExplode and Inline.
- // Adds postProject for PosExplode and Inline.
- private def applyPostProject(
- generateRel: RelNode,
- context: SubstraitContext,
- operatorId: Long): RelNode = {
- generator match {
- case Inline(_) =>
- val requiredOutput = requiredChildOutputNodes.indices.map {
- ExpressionBuilder.makeSelection(_)
- }
- val flattenStruct: Seq[ExpressionNode] = generatorOutput.indices.map {
- i =>
- val selectionNode = ExpressionBuilder.makeSelection(requiredOutput.size)
- selectionNode.addNestedChildIdx(i)
- }
- RelBuilder.makeProjectRel(
- generateRel,
- (requiredOutput ++ flattenStruct).asJava,
- context,
- operatorId,
- 1 + requiredOutput.size // 1 stands for the inner struct field from array.
- )
- case PosExplode(posExplodeChild) =>
- // Ordinal populated by Velox UnnestNode starts with 1.
- // Need to subtract 1 to align with Spark's output.
- val unnestedSize = posExplodeChild.dataType match {
- case _: MapType => 2
- case _: ArrayType => 1
- }
- val subFunctionName = ConverterUtils.makeFuncName(
- ExpressionNames.SUBTRACT,
- Seq(LongType, LongType),
- FunctionConfig.OPT)
- val functionMap = context.registeredFunction
- val addFunctionId = ExpressionBuilder.newScalarFunction(functionMap, subFunctionName)
- val literalNode = ExpressionBuilder.makeLiteral(1L, LongType, false)
- val ordinalNode = ExpressionBuilder.makeCast(
- TypeBuilder.makeI32(false),
- ExpressionBuilder.makeScalarFunction(
- addFunctionId,
- Lists.newArrayList(
- ExpressionBuilder.makeSelection(requiredChildOutputNodes.size + unnestedSize),
- literalNode),
- ConverterUtils.getTypeNode(LongType, generator.elementSchema.head.nullable)
- ),
- true // Generated ordinal column shouldn't have null.
- )
- val requiredChildNodes =
- requiredChildOutputNodes.indices.map(ExpressionBuilder.makeSelection(_))
- val unnestColumns = (0 until unnestedSize)
- .map(i => ExpressionBuilder.makeSelection(i + requiredChildOutputNodes.size))
- val generatorOutput: Seq[ExpressionNode] =
- (requiredChildNodes :+ ordinalNode) ++ unnestColumns
- RelBuilder.makeProjectRel(
- generateRel,
- generatorOutput.asJava,
- context,
- operatorId,
- generatorOutput.size
- )
- case _ => generateRel
+ def pullOutPostProject(generate: GenerateExec): SparkPlan = {
+ if (GenerateExecTransformer.supportsGenerate(generate.generator, generate.outer)) {
+ generate.generator match {
+ case PosExplode(_) =>
+ val originalOrdinal = generate.generatorOutput.head
+ val ordinal = {
+ val subtract = Subtract(Cast(originalOrdinal, IntegerType), Literal(1))
+ Alias(subtract, generatePostAliasName)(
+ originalOrdinal.exprId,
+ originalOrdinal.qualifier)
+ }
+ val newGenerate =
+ generate.copy(generatorOutput = generate.generatorOutput.tail :+ originalOrdinal)
+ ProjectExec(
+ (generate.requiredChildOutput :+ ordinal) ++ generate.generatorOutput.tail,
+ newGenerate)
+ case Inline(_) =>
+ val unnestOutput = {
+ val struct = CreateStruct(generate.generatorOutput)
+ val alias = Alias(struct, generatePostAliasName)()
+ alias.toAttribute
+ }
+ val newGenerate = generate.copy(generatorOutput = Seq(unnestOutput))
+ val newOutput = generate.generatorOutput.zipWithIndex.map {
+ case (attr, i) =>
+ val getStructField = GetStructField(unnestOutput, i, Some(attr.name))
+ Alias(getStructField, generatePostAliasName)(attr.exprId, attr.qualifier)
+ }
+ ProjectExec(generate.requiredChildOutput ++ newOutput, newGenerate)
+ case _ => generate
+ }
+ } else {
+ generate
+ }
+ }
}
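
Editor's note (illustration, not part of the commit): pullOutPreProject above aliases a non-attribute generator input (e.g. array(a) or an inline literal) into a child ProjectExec, appending it as the last project field because the native side treats the last field as the generator's input; pullOutPostProject then restores Spark's output shape. The removed applyPostProject comment notes that Velox's UnnestNode emits a 1-based ordinal while Spark's posexplode produces a 0-based pos, which is why the new code builds pos as Subtract(Cast(ordinal, IntegerType), Literal(1)). A minimal plain-Scala sketch of that realignment, using made-up values:

// Assumed native output for posexplode(array('a','b','c')): the unnested value
// plus a 1-based ordinal column appended after it.
val nativeRows = Seq(("a", 1L), ("b", 2L), ("c", 3L))
// The pulled-out post-project computes Spark's 0-based `pos` column:
val sparkRows = nativeRows.map { case (value, ordinal) => (ordinal.toInt - 1, value) }
// sparkRows == Seq((0, "a"), (1, "b"), (2, "c")), i.e. (pos, col) as posexplode returns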
@@ -787,8 +787,32 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
}

test("test inline function") {
+ // Literal: func(literal)
+ runQueryAndCompare(s"""
+ |SELECT inline(array(
+ | named_struct('c1', 0, 'c2', 1),
+ | named_struct('c1', 2, 'c2', null)));
+ |""".stripMargin) {
+ checkOperatorMatch[GenerateExecTransformer]
+ }

+ // CreateArray: func(array(col))
withTempView("t1") {
sql("""select * from values
sql("""SELECT * from values
+ | (named_struct('c1', 0, 'c2', 1)),
+ | (named_struct('c1', 2, 'c2', null)),
+ | (null)
+ |as tbl(a)
+ """.stripMargin).createOrReplaceTempView("t1")
+ runQueryAndCompare(s"""
+ |SELECT inline(array(a)) from t1;
+ |""".stripMargin) {
+ checkOperatorMatch[GenerateExecTransformer]
+ }
+ }

+ withTempView("t2") {
+ sql("""SELECT * from values
| array(
| named_struct('c1', 0, 'c2', 1),
| null,
@@ -800,13 +824,21 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
| named_struct('c1', 2, 'c2', 3)
| )
|as tbl(a)
""".stripMargin).createOrReplaceTempView("t1")
""".stripMargin).createOrReplaceTempView("t2")
runQueryAndCompare("""
- |SELECT inline(a) from t1;
+ |SELECT inline(a) from t2;
|""".stripMargin) {
checkOperatorMatch[GenerateExecTransformer]
}
}

+ // Fallback for array(struct(...), null) literal.
+ runQueryAndCompare(s"""
+ |SELECT inline(array(
+ | named_struct('c1', 0, 'c2', 1),
+ | named_struct('c1', 2, 'c2', null),
+ | null));
+ |""".stripMargin)(_)
}

test("test array functions") {
@@ -136,5 +136,6 @@ class VeloxLiteralSuite extends VeloxWholeStageTransformerSuite {

test("Literal Fallback") {
validateFallbackResult("SELECT struct(cast(null as struct<a: string>))")
validateFallbackResult("SELECT array(struct(1, 'a'), null)")
}
}