@@ -57,13 +57,23 @@ trait ConstraintHelper {
5757
5858  /**  
5959   * Infers an additional set of constraints from a given set of equality constraints. 
60-    * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an 
61-    * additional constraint of the form `b = 5`. 
60+    * 
61+    * This method performs two main types of inference: 
62+    *  1. Attribute-to-attribute: For example, if an operator has constraints 
63+    *     of the form (`a = 5`, `a = b`), this returns an additional constraint of the form `b = 5`. 
64+    *  2. Constant propagation: If the constraints contain both an equality to a constant and a 
65+    *     complex expression, such as `a = 5` and `b = a + 3`, it will infer `b = 5 + 3` 
66+    *     by substituting the constant into the expression. 
67+    * 
68+    * @param  constraints  The set of input constraints 
69+    * @return  A new set of inferred constraints 
6270   */  
6371  def  inferAdditionalConstraints (constraints : ExpressionSet ):  ExpressionSet  =  {
6472    var  inferredConstraints  =  ExpressionSet ()
6573    //  IsNotNull should be constructed by `constructIsNotNullConstraints`.
6674    val  predicates  =  constraints.filterNot(_.isInstanceOf [IsNotNull ])
75+ 
76+     //  Step 1: Infer attribute-to-attribute equalities
6777    predicates.foreach {
6878      case  eq @  EqualTo (l : Attribute , r : Attribute ) => 
6979        //  Also remove EqualNullSafe with the same l and r to avoid Once strategy's idempotence
@@ -77,6 +87,43 @@ trait ConstraintHelper {
7787        inferredConstraints ++=  replaceConstraints(predicates -  eq -  EqualNullSafe (l, r), l, r)
7888      case  _ =>  //  No inference
7989    }
90+ 
91+     //  Step 2: Infer by constant substitution (e.g., a = 5, b = a + 3 => b = 5 + 3)
92+     val  equalityPredicates  =  predicates.toSeq.flatMap {
93+       case  e @  EqualTo (left : AttributeReference , right : Literal ) =>  Some (((left, right), e))
94+       case  e @  EqualTo (left : Literal , right : AttributeReference ) =>  Some (((right, left), e))
95+       case  _ =>  None 
96+     }
97+     if  (equalityPredicates.nonEmpty) {
98+       val  constantsMap  =  AttributeMap (equalityPredicates.map(_._1))
99+       val  predicateSet  =  equalityPredicates.map(_._2).toSet
100+       def  replaceConstantsInExpression (expression : Expression ) =  expression transform {
101+         case  a : AttributeReference  => 
102+           constantsMap.get(a) match  {
103+             case  Some (literal) =>  literal
104+             case  None  =>  a
105+           }
106+       }
107+       predicates.foreach { cond => 
108+         val  replaced  =  cond transform {
109+           //  attribute equality is handled above, no need to replace
110+           case  e @  EqualTo (_ : Attribute , _ : Attribute ) =>  e
111+           case  e @  EqualTo (_ : Cast , _ : Attribute ) =>  e
112+           case  e @  EqualTo (_ : Attribute , _ : Cast ) =>  e
113+ 
114+           case  e @  EqualTo (_, _) if  ! predicateSet.contains(e) =>  replaceConstantsInExpression(e)
115+         }
116+         //  Avoid inferring tautologies like 1 = 1
117+         val  isTautology  =  replaced match  {
118+           case  EqualTo (left : Expression , right : Expression ) if  left.foldable &&  right.foldable => 
119+             left.eval() ==  right.eval()
120+           case  _ =>  false 
121+         }
122+         if  (! constraints.contains(replaced) &&  ! isTautology) {
123+           inferredConstraints +=  replaced
124+         }
125+       }
126+     }
80127    inferredConstraints --  constraints
81128  }
82129
0 commit comments