@@ -117,6 +117,113 @@ object ConstantFolding extends Rule[LogicalPlan] {
117117 }
118118}
119119
120+ /**
121+ * Helper trait for constant propagation logic that can be reused across different optimizer rules.
122+ *
123+ * This trait provides utilities to perform constant propagation: substituting attribute references
124+ * with their constant values when those values can be determined from equality predicates.
125+ *
126+ * Example usage in Filter:
127+ * {{{
128+ * Input: Filter(i = 5 AND j = i + 3)
129+ * Step 1: Extract mapping {i -> 5} from the equality predicate
130+ * Step 2: Substitute i with 5 in "j = i + 3" to get "j = 8"
131+ * Output: Filter(i = 5 AND j = 8)
132+ * }}}
133+ */
134+ trait ConstantPropagationHelper {
135+
136+ /**
137+ * Substitute attributes with their constant values in an expression.
138+ *
139+ * This method performs the actual substitution of attribute references with literal constants
140+ * within binary comparison expressions.
141+ *
142+ * Example:
143+ * {{{
144+ * expression = (j = i + 3)
145+ * constantsMap = {i -> 5}
146+ * result = (j = 5 + 3)
147+ * }}}
148+ *
149+ * @param expression the expression to transform
150+ * @param constantsMap map from attributes to their constant literal values
151+ * @param excludePredicates set of predicates to exclude from transformation (the source equality
152+ * predicates themselves)
153+ * @return transformed expression with attributes replaced by literals
154+ */
155+ protected def substituteConstants (
156+ expression : Expression ,
157+ constantsMap : AttributeMap [Literal ],
158+ excludePredicates : Set [Expression ] = Set .empty): Expression = {
159+ if (constantsMap.isEmpty) {
160+ expression
161+ } else {
162+ expression.transformWithPruning(_.containsPattern(BINARY_COMPARISON )) {
163+ case b : BinaryComparison if ! excludePredicates.contains(b) => b transform {
164+ case a : AttributeReference => constantsMap.getOrElse(a, a)
165+ }
166+ }
167+ }
168+ }
169+
170+ /**
171+ * Build a map of attribute => literal from a set of constraints.
172+ *
173+ * This method scans a set of constraint expressions (typically from WHERE clauses or
174+ * join children) and extracts all equality predicates that map an attribute to a constant.
175+ *
176+ * Example:
177+ * {{{
178+ * constraints = Set(i = 5, j = 10, k > 3, i + j = 15)
179+ * result = {i -> 5, j -> 10} // only extracts simple equality predicates
180+ * }}}
181+ */
182+ protected def buildConstantsMap (
183+ constraints : ExpressionSet ,
184+ nullIsFalse : Boolean ): AttributeMap [Literal ] = {
185+ val mappings = constraints.flatMap(extractConstantMapping(_, nullIsFalse))
186+ AttributeMap (mappings.toSeq)
187+ }
188+
189+ /**
190+ * Extract a constant mapping from an equality predicate if possible.
191+ *
192+ * Examples:
193+ * {{{
194+ * i = 5 => Some((i, 5))
195+ * 5 = i => Some((i, 5))
196+ * i <=> 5 => Some((i, 5))
197+ * i + 1 = 5 => None (not a simple attribute-to-literal mapping)
198+ * i = j => None (no literal involved)
199+ * i > 5 => None (not an equality predicate)
200+ * }}}
201+ */
202+ protected def extractConstantMapping (
203+ expr : Expression ,
204+ nullIsFalse : Boolean ): Option [(AttributeReference , Literal )] = expr match {
205+ case EqualTo (left : AttributeReference , right : Literal ) if safeToReplace(left, nullIsFalse) =>
206+ Some (left -> right)
207+ case EqualTo (left : Literal , right : AttributeReference ) if safeToReplace(right, nullIsFalse) =>
208+ Some (right -> left)
209+ case EqualNullSafe (left : AttributeReference , right : Literal )
210+ if safeToReplace(left, nullIsFalse) =>
211+ Some (left -> right)
212+ case EqualNullSafe (left : Literal , right : AttributeReference )
213+ if safeToReplace(right, nullIsFalse) =>
214+ Some (right -> left)
215+ case _ => None
216+ }
217+
218+ // We need to take into account if an attribute is nullable and the context of the conjunctive
219+ // expression. E.g. `SELECT * FROM t WHERE NOT(c = 1 AND c + 1 = 1)` where attribute `c` can be
220+ // substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable then the enclosing
221+ // NOT prevents us to do the substitution as NOT flips the context (`nullIsFalse`) of what a
222+ // null result of the enclosed expression means.
223+ private def safeToReplace (ar : AttributeReference , nullIsFalse : Boolean ): Boolean =
224+ ! ar.nullable || nullIsFalse
225+ }
226+
120227/**
121228 * Substitutes [[Attribute Attributes ]] which can be statically evaluated with their corresponding
122229 * value in conjunctive [[Expression Expressions ]]
@@ -131,7 +238,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
131238 * - Using this mapping, replace occurrence of the attributes with the corresponding constant values
132239 * in the AND node.
133240 */
134- object ConstantPropagation extends Rule [LogicalPlan ] {
241+ object ConstantPropagation extends Rule [LogicalPlan ] with ConstantPropagationHelper {
135242 def apply (plan : LogicalPlan ): LogicalPlan = plan.transformUpWithPruning(
136243 _.containsAllPatterns(LITERAL , FILTER , BINARY_COMPARISON ), ruleId) {
137244 case f : Filter =>
@@ -162,31 +269,31 @@ object ConstantPropagation extends Rule[LogicalPlan] {
162269 * 2. AttributeMap: propagated mapping of attribute => constant
163270 */
164271 private def traverse (condition : Expression , replaceChildren : Boolean , nullIsFalse : Boolean )
165- : (Option [Expression ], AttributeMap [(Literal , BinaryComparison )]) =
272+ : (Option [Expression ], AttributeMap [(Literal , BinaryComparison )]) =
166273 condition match {
167274 case _ if ! condition.containsAllPatterns(LITERAL , BINARY_COMPARISON ) =>
168275 (None , AttributeMap .empty)
169- case e @ EqualTo (left : AttributeReference , right : Literal )
170- if safeToReplace(left, nullIsFalse) =>
171- (None , AttributeMap (Map (left -> (right, e))))
172- case e @ EqualTo (left : Literal , right : AttributeReference )
173- if safeToReplace(right, nullIsFalse) =>
174- (None , AttributeMap (Map (right -> (left, e))))
175- case e @ EqualNullSafe (left : AttributeReference , right : Literal )
176- if safeToReplace(left, nullIsFalse) =>
177- (None , AttributeMap (Map (left -> (right, e))))
178- case e @ EqualNullSafe (left : Literal , right : AttributeReference )
179- if safeToReplace(right, nullIsFalse) =>
180- (None , AttributeMap (Map (right -> (left, e))))
276+ case e : BinaryComparison =>
277+ // Try to extract constant mapping from this equality predicate
278+ extractConstantMapping(e, nullIsFalse) match {
279+ case Some ((attr, literal)) => (None , AttributeMap (Map (attr -> (literal, e))))
280+ case None => (None , AttributeMap .empty)
281+ }
181282 case a : And =>
182283 val (newLeft, equalityPredicatesLeft) =
183284 traverse(a.left, replaceChildren = false , nullIsFalse)
184285 val (newRight, equalityPredicatesRight) =
185286 traverse(a.right, replaceChildren = false , nullIsFalse)
186287 val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
187288 val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
188- Some (And (replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
189- replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
289+ // Convert map to just literals for substitution
290+ val constantsMap = AttributeMap (equalityPredicates.map {
291+ case (attr, (lit, _)) => attr -> lit
292+ })
293+ val predicatesToExclude : Set [Expression ] = equalityPredicates.values.map(_._2).toSet
294+ Some (And (
295+ substituteConstants(newLeft.getOrElse(a.left), constantsMap, predicatesToExclude),
296+ substituteConstants(newRight.getOrElse(a.right), constantsMap, predicatesToExclude)))
190297 } else {
191298 if (newLeft.isDefined || newRight.isDefined) {
192299 Some (And (newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
@@ -211,26 +318,6 @@ object ConstantPropagation extends Rule[LogicalPlan] {
211318 (newChild.map(Not ), AttributeMap .empty)
212319 case _ => (None , AttributeMap .empty)
213320 }
214-
215- // We need to take into account if an attribute is nullable and the context of the conjunctive
216- // expression. E.g. `SELECT * FROM t WHERE NOT(c = 1 AND c + 1 = 1)` where attribute `c` can be
217- // substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable then the enclosing
218- // NOT prevents us to do the substitution as NOT flips the context (`nullIsFalse`) of what a
219- // null result of the enclosed expression means.
220- private def safeToReplace (ar : AttributeReference , nullIsFalse : Boolean ) =
221- ! ar.nullable || nullIsFalse
222-
223- private def replaceConstants (
224- condition : Expression ,
225- equalityPredicates : AttributeMap [(Literal , BinaryComparison )]): Expression = {
226- val constantsMap = AttributeMap (equalityPredicates.map { case (attr, (lit, _)) => attr -> lit })
227- val predicates = equalityPredicates.values.map(_._2).toSet
228- condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON )) {
229- case b : BinaryComparison if ! predicates.contains(b) => b transform {
230- case a : AttributeReference => constantsMap.getOrElse(a, a)
231- }
232- }
233- }
234321}
235322
236323/**
0 commit comments