From 984f1a7bb97ee6a97a5ea241c74722957a706066 Mon Sep 17 00:00:00 2001
From: ashrithb
Date: Sun, 7 Dec 2025 01:12:50 -0500
Subject: [PATCH 1/2] [SPARK-54372][SQL] ANSI mode should reject avg/sum on timestamp types

---
 .../catalyst/analysis/AnsiTypeCoercion.scala  |  17 ++-
 .../analysis/TypeCoercionHelper.scala         | 117 ++++++++++++++++++
 .../sql/catalyst/rules/RuleIdCollection.scala |   1 +
 .../analysis/AnsiTypeCoercionSuite.scala      |  29 +++++
 .../catalyst/analysis/TypeCoercionSuite.scala |  21 ++++
 5 files changed, 184 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index e23e7561f0e3..38ab4f56deb7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -84,7 +84,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
       InConversion ::
       PromoteStrings ::
       DecimalPrecision ::
-      FunctionArgumentConversion ::
+      AnsiFunctionArgumentConversion ::
       ConcatCoercion ::
       MapZipWithCoercion ::
       EltCoercion ::
@@ -98,6 +98,21 @@ object AnsiTypeCoercion extends TypeCoercionBase {
       WindowFrameCoercion ::
       GetDateFieldOperations :: Nil) :: Nil
 
+  /**
+   * ANSI-compliant function argument type coercion rule.
+   * Unlike the default [[FunctionArgumentConversion]], this rule does NOT implicitly cast
+   * timestamp types to double for aggregate functions like Sum and Average.
+   * This ensures that in ANSI mode, applying these aggregate functions to timestamp types
+   * results in a proper type error, which is the expected behavior per the SQL standard.
+   */
+  object AnsiFunctionArgumentConversion extends TypeCoercionRule {
+    override val transform: PartialFunction[Expression, Expression] = {
+      // Skip nodes whose children have not been resolved yet.
+      case e if !e.childrenResolved => e
+      case withChildrenResolved => AnsiFunctionArgumentTypeCoercion(withChildrenResolved)
+    }
+  }
+
   val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
     case (t1, t2) if t1 == t2 => Some(t1)
     case (NullType, t1) => Some(t1)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala
index 0e7d44e98bfb..8133d5e3d03a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionHelper.scala
@@ -419,6 +419,123 @@ abstract class TypeCoercionHelper {
     }
   }
 
+  /**
+   * ANSI-compliant type coercion helper for function arguments.
+   * Unlike [[FunctionArgumentTypeCoercion]], this does NOT implicitly cast
+   * timestamp types to double for aggregate functions like Sum and Average.
+   * In ANSI mode, applying aggregate functions to timestamp types should result
+   * in a type error, as this is not standard SQL behavior.
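+   * For example, `Average(ts)` and `Sum(ts)` over a TimestampType attribute are left
+   * unchanged here, so their own input type checks can reject the query, rather than
+   * being rewritten to operate on `Cast(ts, DoubleType)`.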
+   */
+  object AnsiFunctionArgumentTypeCoercion {
+    def apply(expression: Expression): Expression = expression match {
+      case a @ CreateArray(children, _) if !haveSameType(children.map(_.dataType)) =>
+        val types = children.map(_.dataType)
+        findWiderCommonType(types) match {
+          case Some(finalDataType) => a.copy(children.map(castIfNotSameType(_, finalDataType)))
+          case None => a
+        }
+
+      case c @ Concat(children)
+          if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
+            !haveSameType(c.inputTypesForMerging) =>
+        val types = children.map(_.dataType)
+        findWiderCommonType(types) match {
+          case Some(finalDataType) => Concat(children.map(castIfNotSameType(_, finalDataType)))
+          case None => c
+        }
+
+      case aj @ ArrayJoin(arr, d, nr)
+          if !AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)).
+            acceptsType(arr.dataType) &&
+            ArrayType.acceptsType(arr.dataType) =>
+        val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull
+        implicitCast(arr, ArrayType(StringType, containsNull)) match {
+          case Some(castedArr) => ArrayJoin(castedArr, d, nr)
+          case None => aj
+        }
+
+      case s @ Sequence(_, _, _, timeZoneId)
+          if !haveSameType(s.coercibleChildren.map(_.dataType)) =>
+        val types = s.coercibleChildren.map(_.dataType)
+        findWiderCommonType(types) match {
+          case Some(widerDataType) => s.castChildrenTo(widerDataType)
+          case None => s
+        }
+
+      case m @ MapConcat(children)
+          if children.forall(c => MapType.acceptsType(c.dataType)) &&
+            !haveSameType(m.inputTypesForMerging) =>
+        val types = children.map(_.dataType)
+        findWiderCommonType(types) match {
+          case Some(finalDataType) => MapConcat(children.map(castIfNotSameType(_, finalDataType)))
+          case None => m
+        }
+
+      case m @ CreateMap(children, _)
+          if m.keys.length == m.values.length &&
+            (!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) =>
+        val keyTypes = m.keys.map(_.dataType)
+        val newKeys = findWiderCommonType(keyTypes) match {
+          case Some(finalDataType) => m.keys.map(castIfNotSameType(_, finalDataType))
+          case None => m.keys
+        }
+
+        val valueTypes = m.values.map(_.dataType)
+        val newValues = findWiderCommonType(valueTypes) match {
+          case Some(finalDataType) => m.values.map(castIfNotSameType(_, finalDataType))
+          case None => m.values
+        }
+
+        m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
+
+      // NOTE: Unlike FunctionArgumentTypeCoercion, we intentionally do NOT cast
+      // timestamp types to double for Sum and Average in ANSI mode.
+      // This allows the type check in Sum/Average to fail with a proper error message.
+
+      // Coalesce should return the first non-null value, which could be any column
+      // from the list. So we need to make sure the return type is deterministic and
+      // compatible with every child column.
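+      // For example, Coalesce(intCol, doubleCol) is rewritten to
+      // Coalesce(Cast(intCol, DoubleType), doubleCol).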
+      case c @ Coalesce(es) if !haveSameType(c.inputTypesForMerging) =>
+        val types = es.map(_.dataType)
+        findWiderCommonType(types) match {
+          case Some(finalDataType) =>
+            Coalesce(es.map(castIfNotSameType(_, finalDataType)))
+          case None =>
+            c
+        }
+
+      // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if
+      // we need to truncate, but we should not promote one side to string if the other side is
+      // string.
+      case g @ Greatest(children) if !haveSameType(g.inputTypesForMerging) =>
+        val types = children.map(_.dataType)
+        findWiderTypeWithoutStringPromotion(types) match {
+          case Some(finalDataType) => Greatest(children.map(castIfNotSameType(_, finalDataType)))
+          case None => g
+        }
+
+      case l @ Least(children) if !haveSameType(l.inputTypesForMerging) =>
+        val types = children.map(_.dataType)
+        findWiderTypeWithoutStringPromotion(types) match {
+          case Some(finalDataType) => Least(children.map(castIfNotSameType(_, finalDataType)))
+          case None => l
+        }
+
+      case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType =>
+        NaNvl(l, Cast(r, DoubleType))
+      case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
+        NaNvl(Cast(l, DoubleType), r)
+      case NaNvl(l, r) if r.dataType == NullType => NaNvl(l, Cast(r, l.dataType))
+
+      case r: RandStr if r.length.dataType != IntegerType =>
+        implicitCast(r.length, IntegerType).map { casted =>
+          r.copy(length = casted)
+        }.getOrElse(r)
+
+      case other => other
+    }
+  }
+
   /**
    * Type coercion helper that matches against [[Concat]] expressions in order to type coerce
    * expression's children to expected types.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index f094d7e93ec5..cb610b4f142a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -189,6 +189,7 @@ object RuleIdCollection {
     rulesNeedingIds = rulesNeedingIds ++ {
       // In the production code path, the following rules are run in CombinedTypeCoercionRule, and
       // hence we only need to add them for unit testing.
+      "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$AnsiFunctionArgumentConversion" ::
       "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$DateTimeOperations" ::
       "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$GetDateFieldOperations" ::
       "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$PromoteStringLiterals" ::
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
index fa5027ce259d..f67d2dc1b0c6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
@@ -1103,4 +1103,33 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase {
     shouldNotCast(ArrayType(ArrayType(IntegerType)),
       AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)))
   }
+
+  test("SPARK-54372: ANSI mode should not implicitly cast timestamp to double for avg/sum") {
+    import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
+
+    val timestampAttr = AttributeReference("ts", TimestampType)()
+
+    // In ANSI mode, AnsiFunctionArgumentConversion should NOT cast timestamp to double.
+    // The timestamp expression should remain unchanged, allowing the type check in
+    // Average/Sum to fail with a proper error message.
+    ruleTest(
+      AnsiTypeCoercion.AnsiFunctionArgumentConversion,
+      Average(timestampAttr),
+      Average(timestampAttr) // Should remain unchanged (not cast to double)
+    )
+
+    ruleTest(
+      AnsiTypeCoercion.AnsiFunctionArgumentConversion,
+      Sum(timestampAttr),
+      Sum(timestampAttr) // Should remain unchanged (not cast to double)
+    )
+
+    // Verify that numeric types still work correctly with avg
+    val intAttr = AttributeReference("i", IntegerType)()
+    ruleTest(
+      AnsiTypeCoercion.AnsiFunctionArgumentConversion,
+      Average(intAttr),
+      Average(intAttr) // Numeric types should pass through unchanged
+    )
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index e6a9690ad757..f2ea3eb5a0b9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -1796,6 +1796,27 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase {
     assert(wp1.isInstanceOf[Project])
     assert(wp1.expressions.forall(!_.exists(_ == t1.output.head)))
   }
+
+  test("SPARK-54372: Non-ANSI mode implicitly casts timestamp to double for avg/sum") {
+    import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}
+
+    val timestampAttr = AttributeReference("ts", TimestampType)()
+
+    // In non-ANSI mode (Hive compatibility), FunctionArgumentConversion casts timestamp to double.
+    // This is legacy behavior that allows avg(timestamp) to work, but returns a double (epoch
+    // seconds) instead of a timestamp.
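+    // (Casting a TimestampType value to DoubleType yields fractional seconds since
+    // the epoch, e.g. 1970-01-01 00:00:10 UTC becomes 10.0.)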
+    ruleTest(
+      FunctionArgumentConversion,
+      Average(timestampAttr),
+      Average(Cast(timestampAttr, DoubleType)) // Timestamp should be cast to double
+    )
+
+    ruleTest(
+      FunctionArgumentConversion,
+      Sum(timestampAttr),
+      Sum(Cast(timestampAttr, DoubleType)) // Timestamp should be cast to double
+    )
+  }
 }

From 856efec7d1538e100363a7ff5663d69473568c0e Mon Sep 17 00:00:00 2001
From: ashrithb
Date: Sun, 7 Dec 2025 01:21:29 -0500
Subject: [PATCH 2/2] Trigger CI
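
A minimal end-to-end sketch of the behavior this series targets (illustrative only:
it assumes a local SparkSession, and the exact error class raised by Average/Sum's
input type check is not asserted here):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    spark.sql("CREATE OR REPLACE TEMP VIEW t AS SELECT TIMESTAMP'2020-01-01 00:00:00' AS ts")

    // Legacy (non-ANSI) behavior: avg(ts) is coerced to avg(cast(ts AS DOUBLE))
    // and returns epoch seconds as a DOUBLE.
    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT avg(ts) FROM t").show()

    // With this patch, ANSI mode rejects the query during analysis instead of
    // silently coercing the timestamp.
    spark.conf.set("spark.sql.ansi.enabled", "true")
    try spark.sql("SELECT avg(ts) FROM t").show()
    catch { case e: org.apache.spark.sql.AnalysisException => println(e.getMessage) }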