Skip to content

Commit 6d72c4c

Browse files
committed
[SPARK-53996][SQL] Improve InferFiltersFromConstraints to infer filters from complex join expressions
1 parent 0e10341 commit 6d72c4c

File tree

2 files changed

+74
-2
lines changed

2 files changed

+74
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,23 @@ trait ConstraintHelper {
5757

5858
/**
5959
* Infers an additional set of constraints from a given set of equality constraints.
60-
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
61-
* additional constraint of the form `b = 5`.
60+
*
61+
* This method performs two main types of inference:
62+
* 1. Attribute-to-attribute: For example, if an operator has constraints
63+
* of the form (`a = 5`, `a = b`), this returns an additional constraint of the form `b = 5`.
64+
* 2. Constant propagation: If the constraints contain both an equality to a constant and a
65+
* complex expression, such as `a = 5` and `b = a + 3`, it will infer `b = 5 + 3`
66+
* by substituting the constant into the expression.
67+
*
68+
* @param constraints The set of input constraints
69+
* @return A new set of inferred constraints
6270
*/
6371
def inferAdditionalConstraints(constraints: ExpressionSet): ExpressionSet = {
6472
var inferredConstraints = ExpressionSet()
6573
// IsNotNull should be constructed by `constructIsNotNullConstraints`.
6674
val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])
75+
76+
// Step 1: Infer attribute-to-attribute equalities
6777
predicates.foreach {
6878
case eq @ EqualTo(l: Attribute, r: Attribute) =>
6979
// Also remove EqualNullSafe with the same l and r to avoid Once strategy's idempotence
@@ -77,6 +87,43 @@ trait ConstraintHelper {
7787
inferredConstraints ++= replaceConstraints(predicates - eq - EqualNullSafe(l, r), l, r)
7888
case _ => // No inference
7989
}
90+
91+
// Step 2: Infer by constant substitution (e.g., a = 5, b = a + 3 => b = 5 + 3)
92+
val equalityPredicates = predicates.toSeq.flatMap {
93+
case e @ EqualTo(left: AttributeReference, right: Literal) => Some(((left, right), e))
94+
case e @ EqualTo(left: Literal, right: AttributeReference) => Some(((right, left), e))
95+
case _ => None
96+
}
97+
if (equalityPredicates.nonEmpty) {
98+
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
99+
val predicateSet = equalityPredicates.map(_._2).toSet
100+
def replaceConstantsInExpression(expression: Expression) = expression transform {
101+
case a: AttributeReference =>
102+
constantsMap.get(a) match {
103+
case Some(literal) => literal
104+
case None => a
105+
}
106+
}
107+
predicates.foreach { cond =>
108+
val replaced = cond transform {
109+
// attribute equality is handled above, no need to replace
110+
case e @ EqualTo(_: Attribute, _: Attribute) => e
111+
case e @ EqualTo(_: Cast, _: Attribute) => e
112+
case e @ EqualTo(_: Attribute, _: Cast) => e
113+
114+
case e @ EqualTo(_, _) if !predicateSet.contains(e) => replaceConstantsInExpression(e)
115+
}
116+
// Avoid inferring tautologies like 1 = 1
117+
val isTautology = replaced match {
118+
case EqualTo(left: Expression, right: Expression) if left.foldable && right.foldable =>
119+
left.eval() == right.eval()
120+
case _ => false
121+
}
122+
if (!constraints.contains(replaced) && !isTautology) {
123+
inferredConstraints += replaced
124+
}
125+
}
126+
}
80127
inferredConstraints -- constraints
81128
}
82129

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
6363
comparePlans(optimized, correctAnswer)
6464
}
6565

66+
test("filter: filter out constraints in condition with complex expression") {
67+
val originalQuery = testRelation.where($"a" === 1 && $"b" === $"a" + 2).analyze
68+
val correctAnswer = testRelation.where(IsNotNull($"a") && IsNotNull($"b") &&
69+
$"a" === 1 && $"b" === $"a" + 2 && $"b" === Add(1, 2)).analyze
70+
val optimized = Optimize.execute(originalQuery)
71+
comparePlans(optimized, correctAnswer)
72+
}
73+
6674
test("single inner join: filter out values on either side on equi-join keys") {
6775
val x = testRelation.subquery("x")
6876
val y = testRelation.subquery("y")
@@ -213,6 +221,23 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
213221
comparePlans(Optimize.execute(original.analyze), correct.analyze)
214222
}
215223

224+
test("single inner join: infer constraints in condition with complex expressions") {
225+
val leftRelation = testRelation.subquery("x")
226+
val rightRelation = testRelation.subquery("y")
227+
228+
val left = leftRelation.where($"a" === 1)
229+
val right = rightRelation
230+
231+
testConstraintsAfterJoin(
232+
left,
233+
right,
234+
leftRelation.where(IsNotNull($"a") && $"a" === 1),
235+
rightRelation.where(IsNotNull($"b") && $"b" === Add(1, 2)),
236+
Inner,
237+
Some("y.b".attr === "x.a".attr + 2)
238+
)
239+
}
240+
216241
test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
217242
val x = testRelation.subquery("x")
218243
val y = testRelation.subquery("y")

0 commit comments

Comments
 (0)