From 2fede357fa86f2888741f50376bdcb31fbf02dd7 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Wed, 10 Dec 2025 12:08:45 -0800 Subject: [PATCH] Extract PipeExpression validation into a shared validateAndStripPipeExpression method --- .../catalyst/expressions/pipeOperators.scala | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/pipeOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/pipeOperators.scala index b2bb949c9e5e..cfbd403d66fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/pipeOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/pipeOperators.scala @@ -71,26 +71,44 @@ case object ValidateAndStripPipeExpressions extends Rule[LogicalPlan] { case node: LogicalPlan => node.resolveExpressions { case p: PipeExpression if p.child.resolved => - // Once the child expression is resolved, we can perform the necessary invariant checks - // and then remove this expression, replacing it with the child expression instead. - val firstAggregateFunction: Option[AggregateFunction] = findFirstAggregate(p.child) - if (p.isAggregate && firstAggregateFunction.isEmpty) { - throw QueryCompilationErrors - .pipeOperatorAggregateExpressionContainsNoAggregateFunction(p.child) - } else if (!p.isAggregate) { - // For non-aggregate clauses, only allow aggregate functions in SELECT. - // All other clauses (EXTEND, SET, etc.) disallow aggregates. - val aggregateAllowed = p.clause == PipeOperators.selectClause - if (!aggregateAllowed) { - firstAggregateFunction.foreach { a => - throw QueryCompilationErrors.pipeOperatorContainsAggregateFunction(a, p.clause) - } - } - } - p.child + validateAndStripPipeExpression(p, p.child) } } + /** + * Validates aggregate function constraints for a [[PipeExpression]] and returns the resolved + * child expression (stripping the [[PipeExpression]] wrapper). 
+ * + * This method is shared between the fixed-point analyzer rule and the single-pass resolver. + * + * @param pipeExpression The [[PipeExpression]] containing metadata about the pipe clause. + * @param resolvedChild The resolved child expression to validate and return. + * @return The resolved child expression after validation. + */ + def validateAndStripPipeExpression( + pipeExpression: PipeExpression, + resolvedChild: Expression): Expression = { + val firstAggregateFunction: Option[AggregateFunction] = findFirstAggregate(resolvedChild) + if (pipeExpression.isAggregate && firstAggregateFunction.isEmpty) { + throw QueryCompilationErrors + .pipeOperatorAggregateExpressionContainsNoAggregateFunction(resolvedChild) + } + if (!pipeExpression.isAggregate) { + // For non-aggregate clauses, only allow aggregate functions in SELECT. + // All other clauses (EXTEND, SET, etc.) disallow aggregates. + val aggregateAllowed = pipeExpression.clause == PipeOperators.selectClause + if (!aggregateAllowed) { + firstAggregateFunction.foreach { a => + throw QueryCompilationErrors.pipeOperatorContainsAggregateFunction( + a, + pipeExpression.clause + ) + } + } + } + resolvedChild + } + /** Returns the first aggregate function in the given expression, or None if not found. */ private def findFirstAggregate(e: Expression): Option[AggregateFunction] = e match { case a: AggregateFunction =>