diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a2dced57c7153..f75e076f676cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1735,7 +1735,7 @@ object InferFiltersFromGenerate extends Rule[LogicalPlan] { * Inner and LeftSemi joins. */ object InferFiltersFromConstraints extends Rule[LogicalPlan] - with PredicateHelper with ConstraintHelper { + with PredicateHelper with ConstraintHelper with ConstantPropagationHelper { def apply(plan: LogicalPlan): LogicalPlan = { if (conf.constraintPropagationEnabled) { @@ -1786,11 +1786,45 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan] left: LogicalPlan, right: LogicalPlan, conditionOpt: Option[Expression]): ExpressionSet = { - val baseConstraints = left.constraints.union(right.constraints) - .union(ExpressionSet(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil))) + // Get constraints from children + val childConstraints = left.constraints.union(right.constraints) + + // Apply constant propagation to the join condition using constraints from children + // This simplifies conditions like "t1.a = t2.b + 2" to "t1.a = 3" when t2.b = 1 is known + val simplifiedCondition = conditionOpt.map { condition => + propagateConstants(condition, childConstraints) + } + + // Collect all constraints: from children + simplified join condition + val baseConstraints = childConstraints + .union(ExpressionSet(simplifiedCondition.map(splitConjunctivePredicates).getOrElse(Nil))) + + // Infer additional constraints through transitive closure baseConstraints.union(inferAdditionalConstraints(baseConstraints)) } + /** + * Apply constant propagation to an expression using a set of known constraints. + * This extracts attribute => literal mappings from the constraints and applies them + * to simplify the expression. + * + * @param expression the expression to simplify + * @param constraints known constraints (typically from WHERE clauses or join children) + * @return the simplified expression, or the original if no simplification is possible + */ + private def propagateConstants( + expression: Expression, + constraints: ExpressionSet): Expression = { + // Build attribute => literal map from constraints + val constantsMap = buildConstantsMap(constraints, nullIsFalse = true) + if (constantsMap.isEmpty) { + expression + } else { + // Apply the substitution + substituteConstants(expression, constantsMap) + } + } + private def inferNewFilter(plan: LogicalPlan, constraints: ExpressionSet): LogicalPlan = { val newPredicates = constraints .union(constructIsNotNullConstraints(constraints, plan.output)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 71eb3e5ea2bd7..385d99137c549 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -117,6 +117,113 @@ object ConstantFolding extends Rule[LogicalPlan] { } } +/** + * Helper trait for constant propagation logic that can be reused across different optimizer rules. + * + * This trait provides utilities to perform constant propagation: substituting attribute references + * with their constant values when those values can be determined from equality predicates. + * + * Example usage in Filter: + * {{{ + * Input: Filter(i = 5 AND j = i + 3) + * Step 1: Extract mapping {i -> 5} from the equality predicate + * Step 2: Substitute i with 5 in "j = i + 3" to get "j = 8" + * Output: Filter(i = 5 AND j = 8) + * }}} + */ +trait ConstantPropagationHelper { + + /** + * Substitute attributes with their constant values in an expression. + * + * This method performs the actual substitution of attribute references with literal constants + * within binary comparison expressions. + * + * Example: + * {{{ + * expression = (j = i + 3) + * constantsMap = {i -> 5} + * result = (j = 5 + 3) + * }}} + * + * @param expression the expression to transform + * @param constantsMap map from attributes to their constant literal values + * @param excludePredicates set of predicates to exclude from transformation (the source equality + * predicates themselves) + * @return transformed expression with attributes replaced by literals + */ + protected def substituteConstants( + expression: Expression, + constantsMap: AttributeMap[Literal], + excludePredicates: Set[Expression] = Set.empty): Expression = { + if (constantsMap.isEmpty) { + expression + } else { + expression.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) { + case b: BinaryComparison if !excludePredicates.contains(b) => b transform { + case a: AttributeReference => constantsMap.getOrElse(a, a) + } + } + } + } + + /** + * Build a map of attribute => literal from a set of constraints. + * + * This method scans a set of constraint expressions (typically from WHERE clauses or + * join children) and extracts all equality predicates that map an attribute to a constant. + * + * Example: + * {{{ + * constraints = Set(i = 5, j = 10, k > 3, i + j = 15) + * result = {i -> 5, j -> 10} // only extracts simple equality predicates + * }}} + */ + protected def buildConstantsMap( + constraints: ExpressionSet, + nullIsFalse: Boolean): AttributeMap[Literal] = { + val mappings = constraints.flatMap(extractConstantMapping(_, nullIsFalse)) + AttributeMap(mappings.toSeq) + } + + /** + * Extract a constant mapping from an equality predicate if possible. + * + * Examples: + * {{{ + * i = 5 => Some((i, 5)) + * 5 = i => Some((i, 5)) + * i <=> 5 => Some((i, 5)) + * i + 1 = 5 => None (not a simple attribute-to-literal mapping) + * i = j => None (no literal involved) + * i > 5 => None (not an equality predicate) + * }}} + */ + protected def extractConstantMapping( + expr: Expression, + nullIsFalse: Boolean): Option[(AttributeReference, Literal)] = expr match { + case EqualTo(left: AttributeReference, right: Literal) if safeToReplace(left, nullIsFalse) => + Some(left -> right) + case EqualTo(left: Literal, right: AttributeReference) if safeToReplace(right, nullIsFalse) => + Some(right -> left) + case EqualNullSafe(left: AttributeReference, right: Literal) + if safeToReplace(left, nullIsFalse) => + Some(left -> right) + case EqualNullSafe(left: Literal, right: AttributeReference) + if safeToReplace(right, nullIsFalse) => + Some(right -> left) + case _ => None + } + + // We need to take into account if an attribute is nullable and the context of the conjunctive + // expression. E.g. `SELECT * FROM t WHERE NOT(c = 1 AND c + 1 = 1)` where attribute `c` can be + // substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable then the enclosing + // NOT prevents us to do the substitution as NOT flips the context (`nullIsFalse`) of what a + // null result of the enclosed expression means. + private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean): Boolean = + !ar.nullable || nullIsFalse +} + /** * Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding * value in conjunctive [[Expression Expressions]] @@ -131,7 +238,7 @@ object ConstantFolding extends Rule[LogicalPlan] { * - Using this mapping, replace occurrence of the attributes with the corresponding constant values * in the AND node. */ -object ConstantPropagation extends Rule[LogicalPlan] { +object ConstantPropagation extends Rule[LogicalPlan] with ConstantPropagationHelper { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( _.containsAllPatterns(LITERAL, FILTER, BINARY_COMPARISON), ruleId) { case f: Filter => @@ -162,22 +269,16 @@ object ConstantPropagation extends Rule[LogicalPlan] { * 2. AttributeMap: propagated mapping of attribute => constant */ private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean) - : (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) = + : (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) = condition match { case _ if !condition.containsAllPatterns(LITERAL, BINARY_COMPARISON) => (None, AttributeMap.empty) - case e @ EqualTo(left: AttributeReference, right: Literal) - if safeToReplace(left, nullIsFalse) => - (None, AttributeMap(Map(left -> (right, e)))) - case e @ EqualTo(left: Literal, right: AttributeReference) - if safeToReplace(right, nullIsFalse) => - (None, AttributeMap(Map(right -> (left, e)))) - case e @ EqualNullSafe(left: AttributeReference, right: Literal) - if safeToReplace(left, nullIsFalse) => - (None, AttributeMap(Map(left -> (right, e)))) - case e @ EqualNullSafe(left: Literal, right: AttributeReference) - if safeToReplace(right, nullIsFalse) => - (None, AttributeMap(Map(right -> (left, e)))) + case e: BinaryComparison => + // Try to extract constant mapping from this equality predicate + extractConstantMapping(e, nullIsFalse) match { + case Some((attr, literal)) => (None, AttributeMap(Map(attr -> (literal, e)))) + case None => (None, AttributeMap.empty) + } case a: And => val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false, nullIsFalse) @@ -185,8 +286,14 @@ object ConstantPropagation extends Rule[LogicalPlan] { traverse(a.right, replaceChildren = false, nullIsFalse) val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) { - Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates), - replaceConstants(newRight.getOrElse(a.right), equalityPredicates))) + // Convert map to just literals for substitution + val constantsMap = AttributeMap(equalityPredicates.map { + case (attr, (lit, _)) => attr -> lit + }) + val predicatesToExclude: Set[Expression] = equalityPredicates.values.map(_._2).toSet + Some(And( + substituteConstants(newLeft.getOrElse(a.left), constantsMap, predicatesToExclude), + substituteConstants(newRight.getOrElse(a.right), constantsMap, predicatesToExclude))) } else { if (newLeft.isDefined || newRight.isDefined) { Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right))) @@ -211,26 +318,6 @@ object ConstantPropagation extends Rule[LogicalPlan] { (newChild.map(Not), AttributeMap.empty) case _ => (None, AttributeMap.empty) } - - // We need to take into account if an attribute is nullable and the context of the conjunctive - // expression. E.g. `SELECT * FROM t WHERE NOT(c = 1 AND c + 1 = 1)` where attribute `c` can be - // substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable then the enclosing - // NOT prevents us to do the substitution as NOT flips the context (`nullIsFalse`) of what a - // null result of the enclosed expression means. - private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) = - !ar.nullable || nullIsFalse - - private def replaceConstants( - condition: Expression, - equalityPredicates: AttributeMap[(Literal, BinaryComparison)]): Expression = { - val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit }) - val predicates = equalityPredicates.values.map(_._2).toSet - condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) { - case b: BinaryComparison if !predicates.contains(b) => b transform { - case a: AttributeReference => constantsMap.getOrElse(a, a) - } - } - } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index d8d8a2b333bcd..b65503c4422d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -201,7 +201,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("constraints should be inferred from aliased literals") { val originalLeft = testRelation.subquery("left").as("left") val optimizedLeft = testRelation.subquery("left") - .where(IsNotNull($"a") && $"a" <=> 2).as("left") + .where(IsNotNull($"a") && $"a" === 2).as("left") val right = Project(Seq(Literal(2).as("two")), testRelation.subquery("right")).as("right") @@ -213,6 +213,23 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(Optimize.execute(original.analyze), correct.analyze) } + test("single inner join: infer constraints in condition with complex expressions") { + val leftRelation = testRelation.subquery("x") + val rightRelation = testRelation.subquery("y") + + val left = leftRelation.where($"a" === 1) + val right = rightRelation + + testConstraintsAfterJoin( + left, + right, + leftRelation.where(IsNotNull($"a") && $"a" === 1), + rightRelation.where(IsNotNull($"b") && $"b" === Add(1, 2)), + Inner, + Some("y.b".attr === "x.a".attr + 2) + ) + } + test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { val x = testRelation.subquery("x") val y = testRelation.subquery("y") @@ -274,17 +291,21 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val testRelation1 = LocalRelation($"a".int) val testRelation2 = LocalRelation($"b".long) val originalLeft = testRelation1.subquery("left") - val originalRight = testRelation2.where($"b" === 1L).subquery("right") + val originalRight = testRelation2.where($"b" > 100L).subquery("right") - val left = testRelation1.where(IsNotNull($"a") && $"a".cast(LongType) === 1L) + val left = testRelation1.where(IsNotNull($"a") && $"a".cast(LongType) > 100L) .subquery("left") - val right = testRelation2.where(IsNotNull($"b") && $"b" === 1L).subquery("right") + val right = testRelation2.where(IsNotNull($"b") && $"b" > 100L).subquery("right") + // CAST(a AS BIGINT) = b with b > 100 + // Should infer: CAST(a AS BIGINT) > 100 Seq(Some("left.a".attr.cast(LongType) === "right.b".attr), Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition => testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition) } + // a = CAST(b AS INT) with b > 100 + // Should NOT infer new filter Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)), Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition => testConstraintsAfterJoin( @@ -300,30 +321,60 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("Constraints shouldn't be inferred from cast equality constraint(filter lower data type)") { val testRelation1 = LocalRelation($"a".int) val testRelation2 = LocalRelation($"b".long) - val originalLeft = testRelation1.where($"a" === 1).subquery("left") + val originalLeft = testRelation1.where($"a" > 50).subquery("left") val originalRight = testRelation2.subquery("right") - val left = testRelation1.where(IsNotNull($"a") && $"a" === 1).subquery("left") + val left = testRelation1.where(IsNotNull($"a") && $"a" > 50).subquery("left") val right = testRelation2.where(IsNotNull($"b")).subquery("right") + // CAST(a AS BIGINT) = b with a > 50 + // Should NOT infer new filter Seq(Some("left.a".attr.cast(LongType) === "right.b".attr), Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition => testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition) } + // a = CAST(b AS INT) with a > 50 + // Should infer: CAST(b AS INT) > 50 Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)), Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition => testConstraintsAfterJoin( originalLeft, originalRight, left, - testRelation2.where(IsNotNull($"b") && $"b".attr.cast(IntegerType) === 1) + testRelation2.where(IsNotNull($"b") && $"b".attr.cast(IntegerType) > 50) .subquery("right"), Inner, condition) } } + test("constant propagation through cast in join condition") { + val testRelation1 = LocalRelation($"a".int) + val testRelation2 = LocalRelation($"b".long) + + val originalLeft = testRelation1.subquery("left") + val originalRight = testRelation2.where($"b" === 1L).subquery("right") + + val left = testRelation1.where(IsNotNull($"a") && + $"a" === Literal(1L).cast(IntegerType)).subquery("left") + val right = testRelation2.where(IsNotNull($"b") && $"b" === 1L).subquery("right") + + // Test constant propagation: b = 1 propagates through CAST + // JOIN ON a = CAST(b AS INT) with b = 1 + // Should infer: a = CAST(1 AS INT) + Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)), + Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition => + testConstraintsAfterJoin( + originalLeft, + originalRight, + left, + right, + Inner, + condition) + } + } + test("SPARK-36978: IsNotNull constraints on structs should apply at the member field " + "instead of the root level nested type") { val structTestRelation = LocalRelation($"a".struct(StructType(