Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
InConversion ::
PromoteStrings ::
DecimalPrecision ::
FunctionArgumentConversion ::
AnsiFunctionArgumentConversion ::
ConcatCoercion ::
MapZipWithCoercion ::
EltCoercion ::
Expand All @@ -98,6 +98,21 @@ object AnsiTypeCoercion extends TypeCoercionBase {
WindowFrameCoercion ::
GetDateFieldOperations :: Nil) :: Nil

/**
 * Function-argument coercion rule used by the ANSI type coercion rule set.
 *
 * It differs from the default [[FunctionArgumentConversion]] in one respect: aggregate
 * functions such as Sum and Average do NOT get their timestamp argument implicitly cast
 * to double. The argument is left as-is so that, in ANSI mode, applying these aggregates
 * to a timestamp ends in a proper type error, as expected per SQL standards.
 */
object AnsiFunctionArgumentConversion extends TypeCoercionRule {
  override val transform: PartialFunction[Expression, Expression] = {
    // Coerce only once every child has been resolved.
    case resolved if resolved.childrenResolved => AnsiFunctionArgumentTypeCoercion(resolved)
    // Nodes whose children are not resolved yet are handed back untouched.
    case pending => pending
  }
}

val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,123 @@ abstract class TypeCoercionHelper {
}
}

/**
 * Function-argument type coercion helper for ANSI mode.
 *
 * Each case below handles one kind of expression and, where a suitable wider type can be
 * found, casts that expression's arguments to it. In contrast to
 * [[FunctionArgumentTypeCoercion]], there is deliberately no case for Sum or Average:
 * a timestamp argument is NOT implicitly cast to double, because applying those aggregates
 * to a timestamp is not standard SQL and should end in a type error in ANSI mode.
 */
object AnsiFunctionArgumentTypeCoercion {
  /**
   * Returns `expression` with its arguments cast to a common type when one of the cases
   * applies and a target type is found; otherwise returns `expression` itself.
   */
  def apply(expression: Expression): Expression = expression match {
    case array @ CreateArray(elements, _) if !haveSameType(elements.map(_.dataType)) =>
      findWiderCommonType(elements.map(_.dataType))
        .map(widened => array.copy(elements.map(castIfNotSameType(_, widened))))
        .getOrElse(array)

    case concat @ Concat(inputs)
        if inputs.forall(input => ArrayType.acceptsType(input.dataType)) &&
          !haveSameType(concat.inputTypesForMerging) =>
      findWiderCommonType(inputs.map(_.dataType))
        .map(widened => Concat(inputs.map(castIfNotSameType(_, widened))))
        .getOrElse(concat)

    case join @ ArrayJoin(array, delimiter, nullReplacement)
        if !AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true))
          .acceptsType(array.dataType) &&
          ArrayType.acceptsType(array.dataType) =>
      // Keep the nullability of the source array's elements on the string array target.
      val elementsNullable = array.dataType.asInstanceOf[ArrayType].containsNull
      implicitCast(array, ArrayType(StringType, elementsNullable))
        .map(stringArray => ArrayJoin(stringArray, delimiter, nullReplacement))
        .getOrElse(join)

    case sequence @ Sequence(_, _, _, _)
        if !haveSameType(sequence.coercibleChildren.map(_.dataType)) =>
      findWiderCommonType(sequence.coercibleChildren.map(_.dataType))
        .map(widened => sequence.castChildrenTo(widened))
        .getOrElse(sequence)

    case mapConcat @ MapConcat(inputs)
        if inputs.forall(input => MapType.acceptsType(input.dataType)) &&
          !haveSameType(mapConcat.inputTypesForMerging) =>
      findWiderCommonType(inputs.map(_.dataType))
        .map(widened => MapConcat(inputs.map(castIfNotSameType(_, widened))))
        .getOrElse(mapConcat)

    case createMap @ CreateMap(_, _)
        if createMap.keys.length == createMap.values.length &&
          !(haveSameType(createMap.keys.map(_.dataType)) &&
            haveSameType(createMap.values.map(_.dataType))) =>
      // Keys and values are widened independently of each other.
      val widenedKeys = findWiderCommonType(createMap.keys.map(_.dataType))
        .map(keyType => createMap.keys.map(castIfNotSameType(_, keyType)))
        .getOrElse(createMap.keys)
      val widenedValues = findWiderCommonType(createMap.values.map(_.dataType))
        .map(valueType => createMap.values.map(castIfNotSameType(_, valueType)))
        .getOrElse(createMap.values)

      // Interleave back into the key, value, key, value, ... child order.
      createMap.copy(widenedKeys.zip(widenedValues).flatMap { case (key, value) =>
        Seq(key, value)
      })

    // NOTE: Sum and Average are intentionally absent here. Unlike
    // FunctionArgumentTypeCoercion, ANSI mode does not cast their timestamp argument to
    // double, so the type check in Sum/Average can fail with a proper error message.

    // Coalesce returns the first non-null value, which may come from any of its children,
    // so the result type has to be deterministic and compatible with every child.
    case coalesce @ Coalesce(candidates) if !haveSameType(coalesce.inputTypesForMerging) =>
      findWiderCommonType(candidates.map(_.dataType))
        .map(widened => Coalesce(candidates.map(castIfNotSameType(_, widened))))
        .getOrElse(coalesce)

    // For `Greatest` and `Least` the wider type should handle decimal types even if that
    // needs truncation, but one side should not be promoted to string just because the
    // other side is a string.
    case greatest @ Greatest(operands) if !haveSameType(greatest.inputTypesForMerging) =>
      findWiderTypeWithoutStringPromotion(operands.map(_.dataType))
        .map(widened => Greatest(operands.map(castIfNotSameType(_, widened))))
        .getOrElse(greatest)

    case least @ Least(operands) if !haveSameType(least.inputTypesForMerging) =>
      findWiderTypeWithoutStringPromotion(operands.map(_.dataType))
        .map(widened => Least(operands.map(castIfNotSameType(_, widened))))
        .getOrElse(least)

    // NaNvl: bring a float/double pair up to double, and give a NullType right-hand side
    // the type of the left-hand side.
    case NaNvl(left, right) if left.dataType == DoubleType && right.dataType == FloatType =>
      NaNvl(left, Cast(right, DoubleType))
    case NaNvl(left, right) if left.dataType == FloatType && right.dataType == DoubleType =>
      NaNvl(Cast(left, DoubleType), right)
    case NaNvl(left, right) if right.dataType == NullType =>
      NaNvl(left, Cast(right, left.dataType))

    case randStr: RandStr if randStr.length.dataType != IntegerType =>
      implicitCast(randStr.length, IntegerType) match {
        case Some(intLength) => randStr.copy(length = intLength)
        case None => randStr
      }

    // Anything else needs no function-argument coercion.
    case unchanged => unchanged
  }
}

/**
* Type coercion helper that matches against [[Concat]] expressions in order to type coerce
* expression's children to expected types.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ object RuleIdCollection {
rulesNeedingIds = rulesNeedingIds ++ {
// In the production code path, the following rules are run in CombinedTypeCoercionRule, and
// hence we only need to add them for unit testing.
"org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$AnsiFunctionArgumentConversion" ::
"org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$DateTimeOperations" ::
"org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$GetDateFieldOperations" ::
"org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$PromoteStringLiterals" ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1103,4 +1103,33 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase {
shouldNotCast(ArrayType(ArrayType(IntegerType)),
AbstractArrayType(StringTypeWithCollation(supportsTrimCollation = true)))
}

test("SPARK-54372: ANSI mode should not implicitly cast timestamp to double for avg/sum") {
  import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}

  val rule = AnsiTypeCoercion.AnsiFunctionArgumentConversion
  val ts = AttributeReference("ts", TimestampType)()
  val int = AttributeReference("i", IntegerType)()

  // Each pair is (input, expected). Under the ANSI rule every input must come back
  // unchanged:
  //  - avg/sum over a timestamp must NOT get a cast to double, which leaves the type
  //    check in Average/Sum free to fail with a proper error message;
  //  - avg over a numeric type still passes through untouched.
  Seq(
    Average(ts) -> Average(ts),
    Sum(ts) -> Sum(ts),
    Average(int) -> Average(int)
  ).foreach { case (input, expected) =>
    ruleTest(rule, input, expected)
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1796,6 +1796,27 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase {
assert(wp1.isInstanceOf[Project])
assert(wp1.expressions.forall(!_.exists(_ == t1.output.head)))
}

test("SPARK-54372: Non-ANSI mode implicitly casts timestamp to double for avg/sum") {
  import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum}

  val rule = FunctionArgumentConversion
  val ts = AttributeReference("ts", TimestampType)()

  // Each pair is (input, expected). In non-ANSI mode (Hive compatibility) the rule wraps
  // the timestamp argument in a cast to double. This legacy behavior lets avg(timestamp)
  // work, but the result is a double (epoch seconds) instead of a timestamp.
  Seq(
    Average(ts) -> Average(Cast(ts, DoubleType)),
    Sum(ts) -> Sum(Cast(ts, DoubleType))
  ).foreach { case (input, expected) =>
    ruleTest(rule, input, expected)
  }
}
}


Expand Down