Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,26 +71,44 @@ case object ValidateAndStripPipeExpressions extends Rule[LogicalPlan] {
case node: LogicalPlan =>
node.resolveExpressions {
case p: PipeExpression if p.child.resolved =>
// Once the child expression is resolved, we can perform the necessary invariant checks
// and then remove this expression, replacing it with the child expression instead.
val firstAggregateFunction: Option[AggregateFunction] = findFirstAggregate(p.child)
if (p.isAggregate && firstAggregateFunction.isEmpty) {
throw QueryCompilationErrors
.pipeOperatorAggregateExpressionContainsNoAggregateFunction(p.child)
} else if (!p.isAggregate) {
// For non-aggregate clauses, only allow aggregate functions in SELECT.
// All other clauses (EXTEND, SET, etc.) disallow aggregates.
val aggregateAllowed = p.clause == PipeOperators.selectClause
if (!aggregateAllowed) {
firstAggregateFunction.foreach { a =>
throw QueryCompilationErrors.pipeOperatorContainsAggregateFunction(a, p.clause)
}
}
}
p.child
validateAndStripPipeExpression(p, p.child)
}
}

/**
* Validates aggregate function constraints for a [[PipeExpression]] and returns the resolved
* child expression (stripping the [[PipeExpression]] wrapper).
*
* This method is shared between the fixed-point analyzer rule and the single-pass resolver.
*
* @param pipeExpression The [[PipeExpression]] containing metadata about the pipe clause.
* @param resolvedChild The resolved child expression to validate and return.
* @return The resolved child expression after validation.
*/
def validateAndStripPipeExpression(
pipeExpression: PipeExpression,
resolvedChild: Expression): Expression = {
val firstAggregateFunction: Option[AggregateFunction] = findFirstAggregate(resolvedChild)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If possible, we should extract this outside this method and then pass it as an arg, because it is recursive.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for review, I looked at this, it turns out we are always going to want to use the same expression-finding logic, so should be OK to leave this hard-coded in here.

if (pipeExpression.isAggregate && firstAggregateFunction.isEmpty) {
throw QueryCompilationErrors
.pipeOperatorAggregateExpressionContainsNoAggregateFunction(resolvedChild)
}
if (!pipeExpression.isAggregate) {
// For non-aggregate clauses, only allow aggregate functions in SELECT.
// All other clauses (EXTEND, SET, etc.) disallow aggregates.
val aggregateAllowed = pipeExpression.clause == PipeOperators.selectClause
if (!aggregateAllowed) {
firstAggregateFunction.foreach { a =>
throw QueryCompilationErrors.pipeOperatorContainsAggregateFunction(
a,
pipeExpression.clause
)
}
}
}
resolvedChild
}

/** Returns the first aggregate function in the given expression, or None if not found. */
private def findFirstAggregate(e: Expression): Option[AggregateFunction] = e match {
case a: AggregateFunction =>
Expand Down