From 59ff500b572ddccf5e563f599fc2d7b0dd1fa8d0 Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Mon, 25 Mar 2024 17:07:03 +0800 Subject: [PATCH] [CORE] Pullout pre/post project for generate (#4952) --- .../clickhouse/CHSparkPlanExecApi.scala | 4 + .../velox/SparkPlanExecApiImpl.scala | 10 +- .../execution/GenerateExecTransformer.scala | 227 ++++++++---------- .../execution/TestOperator.scala | 38 ++- .../execution/VeloxLiteralSuite.scala | 1 + cpp/velox/substrait/SubstraitToVeloxExpr.cc | 30 ++- .../backendsapi/SparkPlanExecApi.scala | 6 +- .../GenerateExecTransformerBase.scala | 1 + .../columnar/PullOutPostProject.scala | 5 +- .../columnar/PullOutPreProject.scala | 7 +- .../RewriteSparkPlanRulesManager.scala | 3 +- .../utils/PullOutProjectHelper.scala | 9 +- 12 files changed, 191 insertions(+), 150 deletions(-) diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 4b6ee1909313..781884ad57d8 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -766,4 +766,8 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { ): GenerateExecTransformerBase = { CHGenerateExecTransformer(generator, requiredChildOutput, outer, generatorOutput, child) } + + override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate + + override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate } diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala index 61ea50695db3..d7045c4e5f07 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, HashPartitioning, Partitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{BroadcastUtils, ColumnarBuildSideRelation, ColumnarShuffleExchangeExec, SparkPlan, VeloxColumnarWriteFilesExec} +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.BuildSideRelation @@ -677,4 +677,12 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi { ): GenerateExecTransformerBase = { GenerateExecTransformer(generator, requiredChildOutput, outer, generatorOutput, child) } + + override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = { + PullOutGenerateProjectHelper.pullOutPreProject(generate) + } + + override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = { + PullOutGenerateProjectHelper.pullOutPostProject(generate) + } } diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala index 5bdfba200f16..b865f31041ff 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala @@ -17,25 +17,24 @@ package io.glutenproject.execution import io.glutenproject.backendsapi.BackendsApiManager -import io.glutenproject.expression.{ConverterUtils, ExpressionConverter, ExpressionNames} -import io.glutenproject.expression.ConverterUtils.FunctionConfig +import io.glutenproject.execution.GenerateExecTransformer.supportsGenerate import io.glutenproject.extension.ValidationResult import io.glutenproject.metrics.{GenerateMetricsUpdater, MetricsUpdater} -import io.glutenproject.substrait.`type`.TypeBuilder import io.glutenproject.substrait.SubstraitContext -import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode} +import io.glutenproject.substrait.expression.ExpressionNode import io.glutenproject.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder} import io.glutenproject.substrait.rel.{RelBuilder, RelNode} +import io.glutenproject.utils.PullOutProjectHelper import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{GenerateExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructType} +import org.apache.spark.sql.types.IntegerType -import com.google.common.collect.Lists import com.google.protobuf.StringValue import scala.collection.JavaConverters._ +import scala.collection.mutable case class GenerateExecTransformer( generator: Generator, @@ -62,26 +61,12 @@ case class GenerateExecTransformer( override protected def doGeneratorValidate( generator: Generator, outer: Boolean): ValidationResult = { - if (outer) { - return ValidationResult.notOk(s"Velox backend does not support outer") - } - generator match { - case _: JsonTuple => - ValidationResult.notOk(s"Velox backend does not support this json_tuple") - case _: ExplodeBase => - ValidationResult.ok - case Inline(child) => - child match { - case AttributeReference(_, ArrayType(_: StructType, _), _, _) => - ValidationResult.ok - case _ => - // TODO: Support Literal/CreateArray. - ValidationResult.notOk( - s"Velox backend does not support inline with expression " + - s"${child.getClass.getSimpleName}.") - } - case _ => - ValidationResult.ok + if (!supportsGenerate(generator, outer)) { + ValidationResult.notOk( + s"Velox backend does not support this generator: ${generator.getClass.getSimpleName}" + + s", outer: $outer") + } else { + ValidationResult.ok } } @@ -91,30 +76,13 @@ case class GenerateExecTransformer( generatorNode: ExpressionNode, validation: Boolean): RelNode = { val operatorId = context.nextOperatorId(this.nodeName) - - val newInput = if (!validation) { - applyPreProject(inputRel, context, operatorId) - } else { - // No need to validate the pre-projection. The generator output has been validated in - // doGeneratorValidate. - inputRel - } - - val generateRel = RelBuilder.makeGenerateRel( - newInput, + RelBuilder.makeGenerateRel( + inputRel, generatorNode, requiredChildOutputNodes.asJava, getExtensionNode(validation), context, operatorId) - - if (!validation) { - applyPostProject(generateRel, context, operatorId) - } else { - // No need to validate the post-projection on the native side as - // it only flattens the generator's output. - generateRel - } } private def getExtensionNode(validation: Boolean): AdvancedExtensionNode = { @@ -141,92 +109,95 @@ case class GenerateExecTransformer( getExtensionNodeForValidation } } +} - // Select child outputs and append generator input. - private def applyPreProject( - inputRel: RelNode, - context: SubstraitContext, - operatorId: Long - ): RelNode = { - val projectExpressions: Seq[ExpressionNode] = - child.output.indices - .map(ExpressionBuilder.makeSelection(_)) :+ - ExpressionConverter - .replaceWithExpressionTransformer( - generator.asInstanceOf[UnaryExpression].child, - child.output) - .doTransform(context.registeredFunction) +object GenerateExecTransformer { + def supportsGenerate(generator: Generator, outer: Boolean): Boolean = { + // TODO: supports outer and remove this param. + if (outer) { + false + } else { + generator match { + case _: Inline | _: ExplodeBase => + true + case _ => + false + } + } + } +} - RelBuilder.makeProjectRel( - inputRel, - projectExpressions.asJava, - context, - operatorId, - child.output.size) +object PullOutGenerateProjectHelper extends PullOutProjectHelper { + def pullOutPreProject(generate: GenerateExec): SparkPlan = { + if (GenerateExecTransformer.supportsGenerate(generate.generator, generate.outer)) { + val newGeneratorChildren = generate.generator match { + case _: Inline | _: ExplodeBase => + val expressionMap = new mutable.HashMap[Expression, NamedExpression]() + // The new child should be either the original Attribute, + // or an Alias to other expressions. + val generatorAttr = replaceExpressionWithAttribute( + generate.generator.asInstanceOf[UnaryExpression].child, + expressionMap, + replaceBoundReference = true) + val newGeneratorChild = if (expressionMap.isEmpty) { + // generator.child is Attribute + generatorAttr.asInstanceOf[Attribute] + } else { + // generator.child is other expression, e.g Literal/CreateArray/CreateMap + expressionMap.values.head + } + Seq(newGeneratorChild) + case _ => + // Unreachable. + throw new IllegalStateException( + s"Generator ${generate.generator.getClass.getSimpleName} is not supported.") + } + // Avoid using elimainateProjectList to create the project list + // because newGeneratorChild can be a duplicated Attribute in generate.child.output. + // The native side identifies the last field of projection as generator's input. + generate.copy( + generator = + generate.generator.withNewChildren(newGeneratorChildren).asInstanceOf[Generator], + child = ProjectExec(generate.child.output ++ newGeneratorChildren, generate.child) + ) + } else { + generate + } } - // There are 3 types of CollectionGenerator in spark: Explode, PosExplode and Inline. - // Adds postProject for PosExplode and Inline. - private def applyPostProject( - generateRel: RelNode, - context: SubstraitContext, - operatorId: Long): RelNode = { - generator match { - case Inline(_) => - val requiredOutput = requiredChildOutputNodes.indices.map { - ExpressionBuilder.makeSelection(_) - } - val flattenStruct: Seq[ExpressionNode] = generatorOutput.indices.map { - i => - val selectionNode = ExpressionBuilder.makeSelection(requiredOutput.size) - selectionNode.addNestedChildIdx(i) - } - RelBuilder.makeProjectRel( - generateRel, - (requiredOutput ++ flattenStruct).asJava, - context, - operatorId, - 1 + requiredOutput.size // 1 stands for the inner struct field from array. - ) - case PosExplode(posExplodeChild) => - // Ordinal populated by Velox UnnestNode starts with 1. - // Need to substract 1 to align with Spark's output. - val unnestedSize = posExplodeChild.dataType match { - case _: MapType => 2 - case _: ArrayType => 1 - } - val subFunctionName = ConverterUtils.makeFuncName( - ExpressionNames.SUBTRACT, - Seq(LongType, LongType), - FunctionConfig.OPT) - val functionMap = context.registeredFunction - val addFunctionId = ExpressionBuilder.newScalarFunction(functionMap, subFunctionName) - val literalNode = ExpressionBuilder.makeLiteral(1L, LongType, false) - val ordinalNode = ExpressionBuilder.makeCast( - TypeBuilder.makeI32(false), - ExpressionBuilder.makeScalarFunction( - addFunctionId, - Lists.newArrayList( - ExpressionBuilder.makeSelection(requiredChildOutputNodes.size + unnestedSize), - literalNode), - ConverterUtils.getTypeNode(LongType, generator.elementSchema.head.nullable) - ), - true // Generated ordinal column shouldn't have null. - ) - val requiredChildNodes = - requiredChildOutputNodes.indices.map(ExpressionBuilder.makeSelection(_)) - val unnestColumns = (0 until unnestedSize) - .map(i => ExpressionBuilder.makeSelection(i + requiredChildOutputNodes.size)) - val generatorOutput: Seq[ExpressionNode] = - (requiredChildNodes :+ ordinalNode) ++ unnestColumns - RelBuilder.makeProjectRel( - generateRel, - generatorOutput.asJava, - context, - operatorId, - generatorOutput.size - ) - case _ => generateRel + def pullOutPostProject(generate: GenerateExec): SparkPlan = { + if (GenerateExecTransformer.supportsGenerate(generate.generator, generate.outer)) { + generate.generator match { + case PosExplode(_) => + val originalOrdinal = generate.generatorOutput.head + val ordinal = { + val subtract = Subtract(Cast(originalOrdinal, IntegerType), Literal(1)) + Alias(subtract, generatePostAliasName)( + originalOrdinal.exprId, + originalOrdinal.qualifier) + } + val newGenerate = + generate.copy(generatorOutput = generate.generatorOutput.tail :+ originalOrdinal) + ProjectExec( + (generate.requiredChildOutput :+ ordinal) ++ generate.generatorOutput.tail, + newGenerate) + case Inline(_) => + val unnestOutput = { + val struct = CreateStruct(generate.generatorOutput) + val alias = Alias(struct, generatePostAliasName)() + alias.toAttribute + } + val newGenerate = generate.copy(generatorOutput = Seq(unnestOutput)) + val newOutput = generate.generatorOutput.zipWithIndex.map { + case (attr, i) => + val getStructField = GetStructField(unnestOutput, i, Some(attr.name)) + Alias(getStructField, generatePostAliasName)(attr.exprId, attr.qualifier) + } + ProjectExec(generate.requiredChildOutput ++ newOutput, newGenerate) + case _ => generate + } + } else { + generate } } } diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala index c81a6043094a..55f8ee6f5f53 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala @@ -787,8 +787,32 @@ class TestOperator extends VeloxWholeStageTransformerSuite { } test("test inline function") { + // Literal: func(literal) + runQueryAndCompare(s""" + |SELECT inline(array( + | named_struct('c1', 0, 'c2', 1), + | named_struct('c1', 2, 'c2', null))); + |""".stripMargin) { + checkOperatorMatch[GenerateExecTransformer] + } + + // CreateArray: func(array(col)) withTempView("t1") { - sql("""select * from values + sql("""SELECT * from values + | (named_struct('c1', 0, 'c2', 1)), + | (named_struct('c1', 2, 'c2', null)), + | (null) + |as tbl(a) + """.stripMargin).createOrReplaceTempView("t1") + runQueryAndCompare(s""" + |SELECT inline(array(a)) from t1; + |""".stripMargin) { + checkOperatorMatch[GenerateExecTransformer] + } + } + + withTempView("t2") { + sql("""SELECT * from values | array( | named_struct('c1', 0, 'c2', 1), | null, @@ -800,13 +824,21 @@ class TestOperator extends VeloxWholeStageTransformerSuite { | named_struct('c1', 2, 'c2', 3) | ) |as tbl(a) - """.stripMargin).createOrReplaceTempView("t1") + """.stripMargin).createOrReplaceTempView("t2") runQueryAndCompare(""" - |SELECT inline(a) from t1; + |SELECT inline(a) from t2; |""".stripMargin) { checkOperatorMatch[GenerateExecTransformer] } } + + // Fallback for array(struct(...), null) literal. + runQueryAndCompare(s""" + |SELECT inline(array( + | named_struct('c1', 0, 'c2', 1), + | named_struct('c1', 2, 'c2', null), + | null)); + |""".stripMargin)(_) } test("test array functions") { diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala index 557681558f56..52e122c2b8b9 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/VeloxLiteralSuite.scala @@ -136,5 +136,6 @@ class VeloxLiteralSuite extends VeloxWholeStageTransformerSuite { test("Literal Fallback") { validateFallbackResult("SELECT struct(cast(null as struct))") + validateFallbackResult("SELECT array(struct(1, 'a'), null)") } } diff --git a/cpp/velox/substrait/SubstraitToVeloxExpr.cc b/cpp/velox/substrait/SubstraitToVeloxExpr.cc index f795f9c9ebd5..8699907de45b 100644 --- a/cpp/velox/substrait/SubstraitToVeloxExpr.cc +++ b/cpp/velox/substrait/SubstraitToVeloxExpr.cc @@ -49,15 +49,18 @@ MapVectorPtr makeMapVector(const VectorPtr& keyVector, const VectorPtr& valueVec valueVector); } -RowVectorPtr makeRowVector(const std::vector& children) { +RowVectorPtr makeRowVector( + const std::vector& children, + std::vector&& names, + size_t length, + facebook::velox::memory::MemoryPool* pool) { std::vector> types; types.resize(children.size()); for (int i = 0; i < children.size(); i++) { types[i] = children[i]->type(); } - const size_t vectorSize = children.empty() ? 0 : children.front()->size(); - auto rowType = ROW(std::move(types)); - return std::make_shared(children[0]->pool(), rowType, BufferPtr(nullptr), vectorSize, children); + auto rowType = ROW(std::move(names), std::move(types)); + return std::make_shared(pool, rowType, BufferPtr(nullptr), length, children); } ArrayVectorPtr makeEmptyArrayVector(memory::MemoryPool* pool, const TypePtr& elementType) { @@ -73,7 +76,7 @@ MapVectorPtr makeEmptyMapVector(memory::MemoryPool* pool, const TypePtr& keyType } RowVectorPtr makeEmptyRowVector(memory::MemoryPool* pool) { - return makeRowVector({}); + return makeRowVector({}, {}, 0, pool); } template @@ -485,13 +488,20 @@ VectorPtr SubstraitVeloxExprConverter::literalsToVector( } RowVectorPtr SubstraitVeloxExprConverter::literalsToRowVector(const ::substrait::Expression::Literal& structLiteral) { - auto childSize = structLiteral.struct_().fields().size(); - if (childSize == 0) { + if (structLiteral.has_null()) { + VELOX_NYI("NULL for struct type is not supported."); + } + auto numFields = structLiteral.struct_().fields().size(); + if (numFields == 0) { return makeEmptyRowVector(pool_); } std::vector vectors; - vectors.reserve(structLiteral.struct_().fields().size()); - for (const auto& child : structLiteral.struct_().fields()) { + std::vector names; + vectors.reserve(numFields); + names.reserve(numFields); + for (auto i = 0; i < numFields; ++i) { + names.push_back("col_" + std::to_string(i)); + const auto& child = structLiteral.struct_().fields(i); auto typeCase = child.literal_type_case(); switch (typeCase) { case ::substrait::Expression_Literal::LiteralTypeCase::kIntervalDayToSecond: { @@ -530,7 +540,7 @@ RowVectorPtr SubstraitVeloxExprConverter::literalsToRowVector(const ::substrait: } } } - return makeRowVector(vectors); + return makeRowVector(vectors, std::move(names), 1, pool_); } core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr( diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala index 1b892ff59f3d..759f7cfad961 100644 --- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{FileSourceScanExec, LeafExecNode, SparkPlan} +import org.apache.spark.sql.execution.{FileSourceScanExec, GenerateExec, LeafExecNode, SparkPlan} import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec} import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -653,4 +653,8 @@ trait SparkPlanExecApi { generatorOutput: Seq[Attribute], child: SparkPlan ): GenerateExecTransformerBase + + def genPreProjectForGenerate(generate: GenerateExec): SparkPlan + + def genPostProjectForGenerate(generate: GenerateExec): SparkPlan } diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala b/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala index 285734f38534..f3e31346cd11 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformerBase.scala @@ -39,6 +39,7 @@ abstract class GenerateExecTransformerBase( generatorOutput: Seq[Attribute], child: SparkPlan) extends UnaryTransformSupport { + protected def doGeneratorValidate(generator: Generator, outer: Boolean): ValidationResult protected def getRelNode( diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPostProject.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPostProject.scala index 0a39ef8196b9..a77e063c5780 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPostProject.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPostProject.scala @@ -21,7 +21,7 @@ import io.glutenproject.utils.PullOutProjectHelper import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression, WindowExpression} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.{GenerateExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.window.WindowExec @@ -104,6 +104,9 @@ object PullOutPostProject extends Rule[SparkPlan] with PullOutProjectHelper { window.copy(windowExpression = newWindowExpressions.asInstanceOf[Seq[NamedExpression]]) ProjectExec(window.child.output ++ postWindowExpressions, newWindow) + case generate: GenerateExec => + BackendsApiManager.getSparkPlanExecApiInstance.genPostProjectForGenerate(generate) + case _ => plan } } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala index 440f609de92d..92d6c9fab87f 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala @@ -16,12 +16,13 @@ */ package io.glutenproject.extension.columnar +import io.glutenproject.backendsapi.BackendsApiManager import io.glutenproject.utils.PullOutProjectHelper import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Partial} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{ExpandExec, ProjectExec, SortExec, SparkPlan, TakeOrderedAndProjectExec} +import org.apache.spark.sql.execution.{ExpandExec, GenerateExec, ProjectExec, SortExec, SparkPlan, TakeOrderedAndProjectExec} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, TypedAggregateExpression} import org.apache.spark.sql.execution.window.WindowExec @@ -189,6 +190,10 @@ object PullOutPreProject extends Rule[SparkPlan] with PullOutProjectHelper { child = ProjectExec( eliminateProjectList(expand.child.outputSet, expressionMap.values.toSeq), expand.child)) + + case generate: GenerateExec => + BackendsApiManager.getSparkPlanExecApiInstance.genPreProjectForGenerate(generate) + case _ => plan } } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala index 8f3f01f9570b..b2591f048e84 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala @@ -54,6 +54,7 @@ class RewriteSparkPlanRulesManager(rewriteRules: Seq[Rule[SparkPlan]]) extends R case _: FilterExec => true case _: FileSourceScanExec => true case _: ExpandExec => true + case _: GenerateExec => true case _ => false } } @@ -77,7 +78,7 @@ class RewriteSparkPlanRulesManager(rewriteRules: Seq[Rule[SparkPlan]]) extends R // Some rewrite rules may generate new parent plan node, we should use transform to // rewrite the original plan. For example, PullOutPreProject and PullOutPostProject // will generate post-project plan node. - plan.transform { case p => rule.apply(p) } + plan.transformUp { case p => rule.apply(p) } } (rewrittenPlan, None) } catch { diff --git a/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala b/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala index a519772fc745..543a5413cdcc 100644 --- a/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala +++ b/gluten-core/src/main/scala/io/glutenproject/utils/PullOutProjectHelper.scala @@ -30,8 +30,8 @@ trait PullOutProjectHelper { private val generatedNameIndex = new AtomicInteger(0) - protected def generatePreAliasName = s"_pre_${generatedNameIndex.getAndIncrement()}" - protected def generatePostAliasName = s"_post_${generatedNameIndex.getAndIncrement()}" + protected def generatePreAliasName: String = s"_pre_${generatedNameIndex.getAndIncrement()}" + protected def generatePostAliasName: String = s"_post_${generatedNameIndex.getAndIncrement()}" /** * The majority of Expressions only support Attribute and BoundReference when converting them into @@ -57,12 +57,13 @@ trait PullOutProjectHelper { protected def replaceExpressionWithAttribute( expr: Expression, - projectExprsMap: mutable.HashMap[Expression, NamedExpression]): Expression = + projectExprsMap: mutable.HashMap[Expression, NamedExpression], + replaceBoundReference: Boolean = false): Expression = expr match { case alias: Alias => projectExprsMap.getOrElseUpdate(alias.child.canonicalized, alias).toAttribute case attr: Attribute => attr - case e: BoundReference => e + case e: BoundReference if !replaceBoundReference => e case other => projectExprsMap .getOrElseUpdate(other.canonicalized, Alias(other, generatePreAliasName)())