Skip to content

Commit 2b09213

Browse files
committed
[SPARK-53996][SQL] Improve InferFiltersFromConstraints to infer filters from complex join expressions
1 parent 0e10341 commit 2b09213

File tree

3 files changed

+218
-46
lines changed

3 files changed

+218
-46
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,7 +1735,7 @@ object InferFiltersFromGenerate extends Rule[LogicalPlan] {
17351735
* Inner and LeftSemi joins.
17361736
*/
17371737
object InferFiltersFromConstraints extends Rule[LogicalPlan]
1738-
with PredicateHelper with ConstraintHelper {
1738+
with PredicateHelper with ConstraintHelper with ConstantPropagationHelper {
17391739

17401740
def apply(plan: LogicalPlan): LogicalPlan = {
17411741
if (conf.constraintPropagationEnabled) {
@@ -1786,11 +1786,45 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
17861786
left: LogicalPlan,
17871787
right: LogicalPlan,
17881788
conditionOpt: Option[Expression]): ExpressionSet = {
1789-
val baseConstraints = left.constraints.union(right.constraints)
1790-
.union(ExpressionSet(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil)))
1789+
// Get constraints from children
1790+
val childConstraints = left.constraints.union(right.constraints)
1791+
1792+
// Apply constant propagation to the join condition using constraints from children
1793+
// This rewrites conditions like "t1.a = t2.b + 2" to "t1.a = 1 + 2" when t2.b = 1 is known
1794+
val simplifiedCondition = conditionOpt.map { condition =>
1795+
propagateConstants(condition, childConstraints)
1796+
}
1797+
1798+
// Collect all constraints: from children + simplified join condition
1799+
val baseConstraints = childConstraints
1800+
.union(ExpressionSet(simplifiedCondition.map(splitConjunctivePredicates).getOrElse(Nil)))
1801+
1802+
// Infer additional constraints through transitive closure
17911803
baseConstraints.union(inferAdditionalConstraints(baseConstraints))
17921804
}
17931805

1806+
/**
1807+
* Apply constant propagation to an expression using a set of known constraints.
1808+
* This extracts attribute => literal mappings from the constraints and applies them
1809+
* to simplify the expression.
1810+
*
1811+
* @param expression the expression to simplify
1812+
* @param constraints known constraints (typically from WHERE clauses or join children)
1813+
* @return the simplified expression, or the original if no simplification is possible
1814+
*/
1815+
private def propagateConstants(
1816+
expression: Expression,
1817+
constraints: ExpressionSet): Expression = {
1818+
// Build attribute => literal map from constraints
1819+
val constantsMap = buildConstantsMap(constraints, nullIsFalse = true)
1820+
if (constantsMap.isEmpty) {
1821+
expression
1822+
} else {
1823+
// Apply the substitution
1824+
substituteConstants(expression, constantsMap)
1825+
}
1826+
}
1827+
17941828
private def inferNewFilter(plan: LogicalPlan, constraints: ExpressionSet): LogicalPlan = {
17951829
val newPredicates = constraints
17961830
.union(constructIsNotNullConstraints(constraints, plan.output))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 123 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,113 @@ object ConstantFolding extends Rule[LogicalPlan] {
117117
}
118118
}
119119

120+
/**
121+
* Helper trait for constant propagation logic that can be reused across different optimizer rules.
122+
*
123+
* This trait provides utilities to perform constant propagation: substituting attribute references
124+
* with their constant values when those values can be determined from equality predicates.
125+
*
126+
* Example usage in Filter:
127+
* {{{
128+
* Input: Filter(i = 5 AND j = i + 3)
129+
* Step 1: Extract mapping {i -> 5} from the equality predicate
130+
* Step 2: Substitute i with 5 in "j = i + 3" to get "j = 5 + 3"
131+
* Output: Filter(i = 5 AND j = 5 + 3)
132+
* }}}
133+
*/
134+
trait ConstantPropagationHelper {
135+
136+
/**
137+
* Substitute attributes with their constant values in an expression.
138+
*
139+
* This method performs the actual substitution of attribute references with literal constants
140+
* within binary comparison expressions.
141+
*
142+
* Example:
143+
* {{{
144+
* expression = (j = i + 3)
145+
* constantsMap = {i -> 5}
146+
* result = (j = 5 + 3)
147+
* }}}
148+
*
149+
* @param expression the expression to transform
150+
* @param constantsMap map from attributes to their constant literal values
151+
* @param excludePredicates set of predicates to exclude from transformation (the source equality
152+
* predicates themselves)
153+
* @return transformed expression with attributes replaced by literals
154+
*/
155+
protected def substituteConstants(
156+
expression: Expression,
157+
constantsMap: AttributeMap[Literal],
158+
excludePredicates: Set[Expression] = Set.empty): Expression = {
159+
if (constantsMap.isEmpty) {
160+
expression
161+
} else {
162+
expression.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
163+
case b: BinaryComparison if !excludePredicates.contains(b) => b transform {
164+
case a: AttributeReference => constantsMap.getOrElse(a, a)
165+
}
166+
}
167+
}
168+
}
169+
170+
/**
171+
* Build a map of attribute => literal from a set of constraints.
172+
*
173+
* This method scans a set of constraint expressions (typically from WHERE clauses or
174+
* join children) and extracts all equality predicates that map an attribute to a constant.
175+
*
176+
* Example:
177+
* {{{
178+
* constraints = Set(i = 5, j = 10, k > 3, i + j = 15)
179+
* result = {i -> 5, j -> 10} // only extracts simple equality predicates
180+
* }}}
181+
*/
182+
protected def buildConstantsMap(
183+
constraints: ExpressionSet,
184+
nullIsFalse: Boolean): AttributeMap[Literal] = {
185+
val mappings = constraints.flatMap(extractConstantMapping(_, nullIsFalse))
186+
AttributeMap(mappings.toSeq)
187+
}
188+
189+
/**
190+
* Extract a constant mapping from an equality predicate if possible.
191+
*
192+
* Examples:
193+
* {{{
194+
* i = 5 => Some((i, 5))
195+
* 5 = i => Some((i, 5))
196+
* i <=> 5 => Some((i, 5))
197+
* i + 1 = 5 => None (not a simple attribute-to-literal mapping)
198+
* i = j => None (no literal involved)
199+
* i > 5 => None (not an equality predicate)
200+
* }}}
201+
*/
202+
protected def extractConstantMapping(
203+
expr: Expression,
204+
nullIsFalse: Boolean): Option[(AttributeReference, Literal)] = expr match {
205+
case EqualTo(left: AttributeReference, right: Literal) if safeToReplace(left, nullIsFalse) =>
206+
Some(left -> right)
207+
case EqualTo(left: Literal, right: AttributeReference) if safeToReplace(right, nullIsFalse) =>
208+
Some(right -> left)
209+
case EqualNullSafe(left: AttributeReference, right: Literal)
210+
if safeToReplace(left, nullIsFalse) =>
211+
Some(left -> right)
212+
case EqualNullSafe(left: Literal, right: AttributeReference)
213+
if safeToReplace(right, nullIsFalse) =>
214+
Some(right -> left)
215+
case _ => None
216+
}
217+
218+
// We need to take into account if an attribute is nullable and the context of the conjunctive
219+
// expression. E.g. `SELECT * FROM t WHERE NOT(c = 1 AND c + 1 = 1)` where attribute `c` can be
220+
// substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable then the enclosing
221+
// NOT prevents us from doing the substitution as NOT flips the context (`nullIsFalse`) of what a
222+
// null result of the enclosed expression means.
223+
private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean): Boolean =
224+
!ar.nullable || nullIsFalse
225+
}
226+
120227
/**
121228
* Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding
122229
* value in conjunctive [[Expression Expressions]]
@@ -131,7 +238,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
131238
* - Using this mapping, replace occurrence of the attributes with the corresponding constant values
132239
* in the AND node.
133240
*/
134-
object ConstantPropagation extends Rule[LogicalPlan] {
241+
object ConstantPropagation extends Rule[LogicalPlan] with ConstantPropagationHelper {
135242
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
136243
_.containsAllPatterns(LITERAL, FILTER, BINARY_COMPARISON), ruleId) {
137244
case f: Filter =>
@@ -162,31 +269,31 @@ object ConstantPropagation extends Rule[LogicalPlan] {
162269
* 2. AttributeMap: propagated mapping of attribute => constant
163270
*/
164271
private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean)
165-
: (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) =
272+
: (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) =
166273
condition match {
167274
case _ if !condition.containsAllPatterns(LITERAL, BINARY_COMPARISON) =>
168275
(None, AttributeMap.empty)
169-
case e @ EqualTo(left: AttributeReference, right: Literal)
170-
if safeToReplace(left, nullIsFalse) =>
171-
(None, AttributeMap(Map(left -> (right, e))))
172-
case e @ EqualTo(left: Literal, right: AttributeReference)
173-
if safeToReplace(right, nullIsFalse) =>
174-
(None, AttributeMap(Map(right -> (left, e))))
175-
case e @ EqualNullSafe(left: AttributeReference, right: Literal)
176-
if safeToReplace(left, nullIsFalse) =>
177-
(None, AttributeMap(Map(left -> (right, e))))
178-
case e @ EqualNullSafe(left: Literal, right: AttributeReference)
179-
if safeToReplace(right, nullIsFalse) =>
180-
(None, AttributeMap(Map(right -> (left, e))))
276+
case e: BinaryComparison =>
277+
// Try to extract a constant mapping; only attribute-to-literal equalities yield one
278+
extractConstantMapping(e, nullIsFalse) match {
279+
case Some((attr, literal)) => (None, AttributeMap(Map(attr -> (literal, e))))
280+
case None => (None, AttributeMap.empty)
281+
}
181282
case a: And =>
182283
val (newLeft, equalityPredicatesLeft) =
183284
traverse(a.left, replaceChildren = false, nullIsFalse)
184285
val (newRight, equalityPredicatesRight) =
185286
traverse(a.right, replaceChildren = false, nullIsFalse)
186287
val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
187288
val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
188-
Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
189-
replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
289+
// Convert map to just literals for substitution
290+
val constantsMap = AttributeMap(equalityPredicates.map {
291+
case (attr, (lit, _)) => attr -> lit
292+
})
293+
val predicatesToExclude: Set[Expression] = equalityPredicates.values.map(_._2).toSet
294+
Some(And(
295+
substituteConstants(newLeft.getOrElse(a.left), constantsMap, predicatesToExclude),
296+
substituteConstants(newRight.getOrElse(a.right), constantsMap, predicatesToExclude)))
190297
} else {
191298
if (newLeft.isDefined || newRight.isDefined) {
192299
Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
@@ -211,26 +318,6 @@ object ConstantPropagation extends Rule[LogicalPlan] {
211318
(newChild.map(Not), AttributeMap.empty)
212319
case _ => (None, AttributeMap.empty)
213320
}
214-
215-
// We need to take into account if an attribute is nullable and the context of the conjunctive
216-
// expression. E.g. `SELECT * FROM t WHERE NOT(c = 1 AND c + 1 = 1)` where attribute `c` can be
217-
// substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable then the enclosing
218-
// NOT prevents us to do the substitution as NOT flips the context (`nullIsFalse`) of what a
219-
// null result of the enclosed expression means.
220-
private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) =
221-
!ar.nullable || nullIsFalse
222-
223-
private def replaceConstants(
224-
condition: Expression,
225-
equalityPredicates: AttributeMap[(Literal, BinaryComparison)]): Expression = {
226-
val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit })
227-
val predicates = equalityPredicates.values.map(_._2).toSet
228-
condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
229-
case b: BinaryComparison if !predicates.contains(b) => b transform {
230-
case a: AttributeReference => constantsMap.getOrElse(a, a)
231-
}
232-
}
233-
}
234321
}
235322

236323
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
201201
test("constraints should be inferred from aliased literals") {
202202
val originalLeft = testRelation.subquery("left").as("left")
203203
val optimizedLeft = testRelation.subquery("left")
204-
.where(IsNotNull($"a") && $"a" <=> 2).as("left")
204+
.where(IsNotNull($"a") && $"a" === 2).as("left")
205205

206206
val right = Project(Seq(Literal(2).as("two")),
207207
testRelation.subquery("right")).as("right")
@@ -213,6 +213,23 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
213213
comparePlans(Optimize.execute(original.analyze), correct.analyze)
214214
}
215215

216+
test("single inner join: infer constraints in condition with complex expressions") {
217+
val leftRelation = testRelation.subquery("x")
218+
val rightRelation = testRelation.subquery("y")
219+
220+
val left = leftRelation.where($"a" === 1)
221+
val right = rightRelation
222+
223+
testConstraintsAfterJoin(
224+
left,
225+
right,
226+
leftRelation.where(IsNotNull($"a") && $"a" === 1),
227+
rightRelation.where(IsNotNull($"b") && $"b" === Add(1, 2)),
228+
Inner,
229+
Some("y.b".attr === "x.a".attr + 2)
230+
)
231+
}
232+
216233
test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") {
217234
val x = testRelation.subquery("x")
218235
val y = testRelation.subquery("y")
@@ -274,17 +291,21 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
274291
val testRelation1 = LocalRelation($"a".int)
275292
val testRelation2 = LocalRelation($"b".long)
276293
val originalLeft = testRelation1.subquery("left")
277-
val originalRight = testRelation2.where($"b" === 1L).subquery("right")
294+
val originalRight = testRelation2.where($"b" > 100L).subquery("right")
278295

279-
val left = testRelation1.where(IsNotNull($"a") && $"a".cast(LongType) === 1L)
296+
val left = testRelation1.where(IsNotNull($"a") && $"a".cast(LongType) > 100L)
280297
.subquery("left")
281-
val right = testRelation2.where(IsNotNull($"b") && $"b" === 1L).subquery("right")
298+
val right = testRelation2.where(IsNotNull($"b") && $"b" > 100L).subquery("right")
282299

300+
// CAST(a AS BIGINT) = b with b > 100
301+
// Should infer: CAST(a AS BIGINT) > 100
283302
Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
284303
Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
285304
testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
286305
}
287306

307+
// a = CAST(b AS INT) with b > 100
308+
// Should NOT infer new filter
288309
Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
289310
Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
290311
testConstraintsAfterJoin(
@@ -300,30 +321,60 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
300321
test("Constraints shouldn't be inferred from cast equality constraint(filter lower data type)") {
301322
val testRelation1 = LocalRelation($"a".int)
302323
val testRelation2 = LocalRelation($"b".long)
303-
val originalLeft = testRelation1.where($"a" === 1).subquery("left")
324+
val originalLeft = testRelation1.where($"a" > 50).subquery("left")
304325
val originalRight = testRelation2.subquery("right")
305326

306-
val left = testRelation1.where(IsNotNull($"a") && $"a" === 1).subquery("left")
327+
val left = testRelation1.where(IsNotNull($"a") && $"a" > 50).subquery("left")
307328
val right = testRelation2.where(IsNotNull($"b")).subquery("right")
308329

330+
// CAST(a AS BIGINT) = b with a > 50
331+
// Should NOT infer new filter
309332
Seq(Some("left.a".attr.cast(LongType) === "right.b".attr),
310333
Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition =>
311334
testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition)
312335
}
313336

337+
// a = CAST(b AS INT) with a > 50
338+
// Should infer: CAST(b AS INT) > 50
314339
Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
315340
Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
316341
testConstraintsAfterJoin(
317342
originalLeft,
318343
originalRight,
319344
left,
320-
testRelation2.where(IsNotNull($"b") && $"b".attr.cast(IntegerType) === 1)
345+
testRelation2.where(IsNotNull($"b") && $"b".attr.cast(IntegerType) > 50)
321346
.subquery("right"),
322347
Inner,
323348
condition)
324349
}
325350
}
326351

352+
test("constant propagation through cast in join condition") {
353+
val testRelation1 = LocalRelation($"a".int)
354+
val testRelation2 = LocalRelation($"b".long)
355+
356+
val originalLeft = testRelation1.subquery("left")
357+
val originalRight = testRelation2.where($"b" === 1L).subquery("right")
358+
359+
val left = testRelation1.where(IsNotNull($"a") &&
360+
$"a" === Literal(1L).cast(IntegerType)).subquery("left")
361+
val right = testRelation2.where(IsNotNull($"b") && $"b" === 1L).subquery("right")
362+
363+
// Test constant propagation: b = 1 propagates through CAST
364+
// JOIN ON a = CAST(b AS INT) with b = 1
365+
// Should infer: a = CAST(1 AS INT)
366+
Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)),
367+
Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition =>
368+
testConstraintsAfterJoin(
369+
originalLeft,
370+
originalRight,
371+
left,
372+
right,
373+
Inner,
374+
condition)
375+
}
376+
}
377+
327378
test("SPARK-36978: IsNotNull constraints on structs should apply at the member field " +
328379
"instead of the root level nested type") {
329380
val structTestRelation = LocalRelation($"a".struct(StructType(

0 commit comments

Comments
 (0)