-
Notifications
You must be signed in to change notification settings - Fork 608
[VL] Implement push partial agg thru expand rule #12052
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
zhouyuan
wants to merge
1
commit into
apache:main
Choose a base branch
from
zhouyuan:wip_push_agg_thru_expand
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+312
−0
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
295 changes: 295 additions & 0 deletions
295
...ait/src/main/scala/org/apache/gluten/extension/columnar/PushPartialAggThroughExpand.scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,295 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.gluten.extension.columnar | ||
|
|
||
| import org.apache.gluten.config.GlutenConfig | ||
|
|
||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.expressions.aggregate._ | ||
| import org.apache.spark.sql.catalyst.rules.Rule | ||
| import org.apache.spark.sql.execution.{ExpandExec, SparkPlan} | ||
| import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec | ||
| import org.apache.spark.sql.execution.aggregate.HashAggregateExec | ||
|
|
||
/**
 * Physical plan rule that pushes a partial [[HashAggregateExec]] through an [[ExpandExec]] so that
 * pre-aggregation happens on the original (un-expanded) rows.
 *
 * Actual Q67 physical plan produced by Spark:
 *
 *   HashAggregateExec (final)
 *     Exchange (shuffle by [grouping_keys..., spark_grouping_id])
 *       HashAggregateExec (partial)       <-- sees 9x expanded rows
 *         ExpandExec (9 projections for ROLLUP)
 *           Project
 *             BroadcastHashJoin ...
 *
 * After this rule:
 *
 *   HashAggregateExec (final)
 *     Exchange
 *       HashAggregateExec (partial-merge) <-- merges per (grouping_keys, gid)
 *         ExpandExec (augmented)          <-- pass-through + null-fill
 *           HashAggregateExec (partial)   <-- pre-agg on original rows, no gid
 *             Project
 *               BroadcastHashJoin ...
 */
| object PushPartialAggThroughExpand extends Rule[SparkPlan] with Logging { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you interested in contributing this optimization to Apache Spark? |
||
| override def apply(plan: SparkPlan): SparkPlan = { | ||
| if (!GlutenConfig.get.pushAggregateThroughExpandEnabled) { | ||
| return plan | ||
| } | ||
| logInfo( | ||
| s"PushPartialAggThroughExpand rule is enabled. Plan root: ${plan.getClass.getSimpleName}") | ||
|
|
||
| // Skip AdaptiveSparkPlanExec - it will be optimized during AQE execution | ||
| // This rule should run before AQE is inserted | ||
| plan match { | ||
| case _: AdaptiveSparkPlanExec => | ||
| logInfo("Skipping AdaptiveSparkPlanExec - rule should run before AQE") | ||
| return plan | ||
| case _ => | ||
| } | ||
|
|
||
| val result = plan.transformUp { | ||
| case agg: HashAggregateExec => | ||
| logInfo(s"Found HashAggregateExec: aggExprs=${agg.aggregateExpressions.size}, " + | ||
| s"modes=${agg.aggregateExpressions.map(_.mode).mkString(",")}, " + | ||
| s"child=${agg.child.getClass.getSimpleName}") | ||
|
|
||
| if ( | ||
| agg.aggregateExpressions.nonEmpty && | ||
| agg.aggregateExpressions.forall(_.mode == Partial) && | ||
| agg.child.isInstanceOf[ExpandExec] | ||
| ) { | ||
| logInfo(s"Found partial HashAggregate with Expand child") | ||
| val expand = agg.child.asInstanceOf[ExpandExec] | ||
| if (isEligible(agg, expand)) { | ||
| logInfo(s"Pushing partial aggregation through Expand") | ||
| rewrite(agg, expand) | ||
| } else { | ||
| logInfo(s"Not eligible for optimization") | ||
| agg | ||
| } | ||
| } else { | ||
| agg | ||
| } | ||
| case other => | ||
| other | ||
| } | ||
|
|
||
| logInfo(s"PushPartialAggThroughExpand rule finished") | ||
| result | ||
| } | ||
|
|
||
| // ------------------------------------------------------------------------- | ||
| // Eligibility | ||
| // ------------------------------------------------------------------------- | ||
|
|
||
| private def isEligible(agg: HashAggregateExec, expand: ExpandExec): Boolean = { | ||
| logInfo(s"Checking eligibility with ${agg.aggregateExpressions.size} aggregates") | ||
|
|
||
| // 1. All aggregate functions must be DeclarativeAggregate (decomposable). | ||
| // TypedImperativeAggregate has an opaque buffer we cannot split. | ||
| val allDeclarative = agg.aggregateExpressions.forall( | ||
| _.aggregateFunction.isInstanceOf[DeclarativeAggregate]) | ||
| logInfo(s"Check 1 - All DeclarativeAggregate: $allDeclarative") | ||
| if (!allDeclarative) { | ||
| logInfo(s"Failed: Not all aggregates are DeclarativeAggregate") | ||
| return false | ||
| } | ||
|
|
||
| val expandChildOutputSet = expand.child.outputSet | ||
| val childOutputStr = expandChildOutputSet.map(a => s"${a.name}#${a.exprId}").mkString(", ") | ||
| logInfo(s"Expand child output: $childOutputStr") | ||
|
|
||
| // Build a mapping from Expand output attributes to child attributes by name | ||
| // Expand creates new attributes but they correspond to child attributes by name | ||
| val expandOutputToChildAttr = expand.output.flatMap { | ||
| expandAttr => expandChildOutputSet.find(_.name == expandAttr.name).map(expandAttr -> _) | ||
| }.toMap | ||
|
|
||
| // 2. At least one grouping key must originate from Expand's child | ||
| // (i.e. something meaningful to pre-aggregate on before expansion). | ||
| val hasPreExpandKey = agg.groupingExpressions.exists { | ||
| e => | ||
| e.references.forall { | ||
| ref => | ||
| expandOutputToChildAttr.get(ref).exists( | ||
| childAttr => | ||
| expandChildOutputSet.contains(childAttr)) | ||
| } | ||
| } | ||
| logInfo(s"Check 2 - Has pre-expand grouping key: $hasPreExpandKey") | ||
| val groupingStr = agg.groupingExpressions.map { | ||
| e => | ||
| val refs = e.references.map(r => s"${r.name}#${r.exprId}").mkString(",") | ||
| s"${e.sql} [refs: $refs]" | ||
| }.mkString("; ") | ||
| logInfo(s"Grouping expressions: $groupingStr") | ||
| if (!hasPreExpandKey) { | ||
| logInfo(s"Failed: No grouping key from before Expand") | ||
| return false | ||
| } | ||
|
|
||
| // 3. The inputs to the aggregate functions (the measure expressions) must | ||
| // all come from Expand's child - Expand must pass them through unchanged. | ||
| // In Q67: ss_sales_price * ss_quantity are in every Expand projection | ||
| // unchanged, so this passes. If a measure referenced a null-filled | ||
| // column it would be wrong to pre-aggregate. | ||
| val expandInjectedAttrs = expand.output.toSet -- expandChildOutputSet | ||
| logInfo(s"Expand injected attrs: ${expandInjectedAttrs.map(_.name).mkString(", ")}") | ||
| val measureRefs = agg.aggregateExpressions.flatMap( | ||
| _.aggregateFunction.children.flatMap(_.references) | ||
| ) | ||
| logInfo(s"Measure references: ${measureRefs.map(_.name).mkString(", ")}") | ||
| val measureUsesInjected = measureRefs.exists(expandInjectedAttrs.contains) | ||
| logInfo(s"Check 3 - Measures use injected attrs: $measureUsesInjected") | ||
| if (measureUsesInjected) { | ||
| logInfo(s"Failed: Measure expressions reference Expand-injected attributes") | ||
| return false | ||
| } | ||
|
|
||
| // 4. No DISTINCT aggregates - distinct requires all raw values. | ||
| val hasDistinct = agg.aggregateExpressions.exists(_.isDistinct) | ||
| logInfo(s"Check 4 - Has DISTINCT: $hasDistinct") | ||
| if (hasDistinct) { | ||
| logInfo(s"Failed: Contains DISTINCT aggregates") | ||
| return false | ||
| } | ||
|
|
||
| logInfo(s"All eligibility checks passed!") | ||
| true | ||
| } | ||
|
|
||
| // ------------------------------------------------------------------------- | ||
| // Rewrite | ||
| // ------------------------------------------------------------------------- | ||
|
|
||
| private def rewrite( | ||
| partialAgg: HashAggregateExec, | ||
| expand: ExpandExec | ||
| ): SparkPlan = { | ||
| val expandChild = expand.child | ||
| val expandChildOutputSet = expandChild.outputSet | ||
|
|
||
| // Map from original child attributes to their names for lookup | ||
| val childAttrByName = expandChildOutputSet.map(a => a.name -> a).toMap | ||
|
|
||
| // Identify which attributes are used in grouping (dimensions) vs aggregation (measures) | ||
| val groupingAttrNames = partialAgg.groupingExpressions.flatMap(_.references).map(_.name).toSet | ||
| val measureAttrNames = partialAgg.aggregateExpressions | ||
| .flatMap(_.aggregateFunction.children.flatMap(_.references)) | ||
| .map(_.name).toSet | ||
|
|
||
| // Dimension attributes from child (used in grouping, not in measures) | ||
| val dimensionAttrs = expandChildOutputSet.filter( | ||
| a => | ||
| groupingAttrNames.contains(a.name) && !measureAttrNames.contains(a.name)) | ||
|
|
||
| logInfo(s"Dimension attrs: ${dimensionAttrs.map(_.name).mkString(", ")}") | ||
| logInfo(s"Measure attr names: ${measureAttrNames.mkString(", ")}") | ||
|
|
||
| // ---- Step 1: lower partial agg (new, below Expand) -------------------- | ||
| // Groups on dimension columns only, aggregates the measures | ||
|
|
||
| val lowerGroupingExprs = dimensionAttrs.toSeq.map(a => a: NamedExpression) | ||
|
|
||
| val lowerAggExprs: Seq[AggregateExpression] = | ||
| partialAgg.aggregateExpressions.map { | ||
| ae => ae.copy(mode = Partial, resultId = NamedExpression.newExprId) | ||
| } | ||
|
|
||
| val lowerDeclarativeAggs: Seq[DeclarativeAggregate] = | ||
| lowerAggExprs.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) | ||
|
|
||
| // Buffer attributes produced by the lower agg | ||
| val lowerBufferAttrs: Seq[Attribute] = | ||
| lowerDeclarativeAggs.flatMap(_.aggBufferAttributes) | ||
|
|
||
| // Result: dimension attrs + buffer attrs | ||
| val lowerResultExprs: Seq[NamedExpression] = | ||
| lowerGroupingExprs.map(_.toAttribute) ++ lowerBufferAttrs | ||
|
|
||
| val lowerAgg = HashAggregateExec( | ||
| requiredChildDistributionExpressions = None, | ||
| isStreaming = false, | ||
| numShufflePartitions = None, | ||
| groupingExpressions = lowerGroupingExprs, | ||
| aggregateExpressions = lowerAggExprs, | ||
| aggregateAttributes = lowerAggExprs.map(_.resultAttribute), | ||
| initialInputBufferOffset = 0, | ||
| resultExpressions = lowerResultExprs, | ||
| child = expandChild | ||
| ) | ||
|
|
||
| // Build mapping from child attr names to lower agg output attrs | ||
| val lowerAggOutputByName = lowerAgg.output.map(a => a.name -> a).toMap | ||
|
|
||
| // ---- Step 2: augmented ExpandExec (on top of lowerAgg) ---------------- | ||
| // Rewrite projections to reference lower agg output instead of original child | ||
|
|
||
| val augmentedProjections: Seq[Seq[Expression]] = | ||
| expand.projections.map { | ||
| proj => | ||
| val rewrittenProj = proj.flatMap { | ||
| case a: Attribute if measureAttrNames.contains(a.name) => | ||
| // Measure columns are now aggregated into buffers - skip them | ||
| None | ||
| case a: Attribute if lowerAggOutputByName.contains(a.name) => | ||
| // Map to corresponding attribute from lower agg output | ||
| Some(lowerAggOutputByName(a.name)) | ||
| case Literal(null, dt) => | ||
| // Null literals for null-padded columns - keep as is | ||
| Some(Literal(null, dt)) | ||
| case other => | ||
| // Other expressions (like gid literal) - keep as is | ||
| Some(other) | ||
| } | ||
| rewrittenProj ++ lowerBufferAttrs // Append buffer attrs to pass through | ||
| } | ||
|
|
||
| // Build new output schema: only dimension attrs (not measures) + buffer attrs | ||
| val augmentedOutput: Seq[Attribute] = expand.output.filterNot( | ||
| a => | ||
| measureAttrNames.contains(a.name)) ++ lowerBufferAttrs | ||
|
|
||
| val augmentedExpand = ExpandExec(augmentedProjections, augmentedOutput, lowerAgg) | ||
|
|
||
| // ---- Step 3: upper merge agg (replaces the original partial agg) ------ | ||
| // | ||
| // Grouping keys: same as original partialAgg (includes gid - which now | ||
| // exists in the augmented Expand output). | ||
| // | ||
| // Agg mode: Partial -> PartialMerge. In PartialMerge mode, Spark's | ||
| // DeclarativeAggregate reads from inputAggBufferAttributes instead of | ||
| // evaluating the original expressions against raw input rows. | ||
| // | ||
| // initialInputBufferOffset: number of positions in the input row before | ||
| // the agg buffer starts. The augmented Expand output is: | ||
| // [expand original output (grouping cols + gid), lowerBufferAttrs] | ||
| // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^ | ||
| // = expand.output.size positions buffer starts here | ||
| val upperInputBufferOffset = expand.output.size | ||
|
|
||
| val upperAggExprs: Seq[AggregateExpression] = | ||
| partialAgg.aggregateExpressions.map(_.copy(mode = PartialMerge)) | ||
|
|
||
| val upperAgg = partialAgg.copy( | ||
| aggregateExpressions = upperAggExprs, | ||
| aggregateAttributes = upperAggExprs.map(_.resultAttribute), | ||
| initialInputBufferOffset = upperInputBufferOffset, | ||
| child = augmentedExpand | ||
| ) | ||
|
|
||
| upperAgg | ||
| } | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems useful for vanilla Spark as well? Should we consider merging it into the upstream Spark?