Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.TableOutputResolver.DefaultValueFillMode.{NONE, RECURSE}
import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, CreateNamedStruct, Expression, GetStructField, If, IsNull, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical.Assignment
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
Expand Down Expand Up @@ -72,7 +74,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
assignments,
addError = err => errors += err,
colPath = Seq(attr.name),
coerceNestedTypes)
coerceNestedTypes,
fromStar)
}

if (errors.nonEmpty) {
Expand Down Expand Up @@ -156,7 +159,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
assignments: Seq[Assignment],
addError: String => Unit,
colPath: Seq[String],
coerceNestedTypes: Boolean = false): Expression = {
coerceNestedTypes: Boolean = false,
updateStar: Boolean = false): Expression = {

val (exactAssignments, otherAssignments) = assignments.partition { assignment =>
assignment.key.semanticEquals(colExpr)
Expand All @@ -178,11 +182,30 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
} else if (exactAssignments.isEmpty && fieldAssignments.isEmpty) {
TableOutputResolver.checkNullability(colExpr, col, conf, colPath)
} else if (exactAssignments.nonEmpty) {
val value = exactAssignments.head.value
val coerceMode = if (coerceNestedTypes) RECURSE else NONE
val resolvedValue = TableOutputResolver.resolveUpdate("", value, col, conf, addError,
colPath, coerceMode)
resolvedValue
if (updateStar) {
val value = exactAssignments.head.value
col.dataType match {
case structType: StructType =>
// Expand assignments to leaf fields
val structAssignment =
applyNestedFieldAssignments(col, colExpr, value, addError, colPath,
coerceNestedTypes)

// Wrap with null check for missing source fields
fixNullExpansion(col, value, structType, structAssignment,
colPath, addError)
case _ =>
// For non-struct types, resolve directly
val coerceMode = if (coerceNestedTypes) RECURSE else NONE
TableOutputResolver.resolveUpdate("", value, col, conf, addError, colPath,
coerceMode)
}
} else {
val value = exactAssignments.head.value
val coerceMode = if (coerceNestedTypes) RECURSE else NONE
TableOutputResolver.resolveUpdate("", value, col, conf, addError,
colPath, coerceMode)
}
} else {
applyFieldAssignments(col, colExpr, fieldAssignments, addError, colPath, coerceNestedTypes)
}
Expand Down Expand Up @@ -211,7 +234,64 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
case otherType =>
addError(
"Updating nested fields is only supported for StructType but " +
s"'${colPath.quoted}' is of type $otherType")
s"'${colPath.quoted}' is of type $otherType")
colExpr
}
}

/**
 * Builds the expression that assigns `value` to the struct column `col`, field by field,
 * descending into every nested struct field.
 *
 * For each field of the target struct, the matching source field is looked up by name using
 * `conf.resolver`. A matching source field is extracted from `value`; a target field with no
 * match in the source falls back to the target's current field value (passed through
 * `TableOutputResolver.checkNullability`). Struct-typed fields recurse into this method,
 * all other fields are resolved with `TableOutputResolver.resolveUpdate`. The per-field
 * results are reassembled with `toNamedStruct`.
 *
 * If `col` is not a struct, an error is reported through `addError` and `colExpr` is
 * returned unchanged. If `value` is not a struct, an error is reported for each target field
 * and a typed null literal is used as that field's source.
 *
 * NOTE(review): the last parameter is spelled `coerceNestedTyptes`; other methods in this
 * file spell the equivalent parameter `coerceNestedTypes`.
 *
 * @param col the target column (or nested field) attribute
 * @param colExpr the expression that reads the current value of the target
 * @param value the source expression being assigned
 * @param addError callback used to report resolution errors
 * @param colPath the path of the target column, used in error messages
 * @param coerceNestedTyptes when true, leaf fields are resolved with fill mode RECURSE,
 *                           otherwise NONE
 * @return the expression producing the new value for the target
 */
private def applyNestedFieldAssignments(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: this is like applyFieldAssignment above, but recurses to all nested fields

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

applyFieldAssignments is not recursive? what's its behavior?

Copy link
Member Author

@szehon-ho szehon-ho Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iiuc, it is not , it just looks at missing assignments for the first level of schema and does some validation (like no two assignments to same field), and then uses TableOutputResolver to fill missing ones.

col: Attribute,
colExpr: Expression,
value: Expression,
addError: String => Unit,
colPath: Seq[String],
coerceNestedTyptes: Boolean): Expression = {

col.dataType match {
case structType: StructType =>
// Each field of the target struct as an attribute; its index is used as the struct
// ordinal when reading the field from `colExpr` below.
val fieldAttrs = DataTypeUtils.toAttributes(structType)

val updatedFieldExprs = fieldAttrs.zipWithIndex.map { case (fieldAttr, ordinal) =>
val fieldPath = colPath :+ fieldAttr.name
// Reads the current value of this field from the target.
val targetFieldExpr = GetStructField(colExpr, ordinal, Some(fieldAttr.name))

// Try to find a corresponding field in the source value by name
val sourceFieldValue: Expression = value.dataType match {
case valueStructType: StructType =>
valueStructType.fields.find(f => conf.resolver(f.name, fieldAttr.name)) match {
case Some(matchingField) =>
// Found matching field in source, extract it
val fieldIndex = valueStructType.fieldIndex(matchingField.name)
GetStructField(value, fieldIndex, Some(matchingField.name))
case None =>
// Field doesn't exist in source, use target's current value with null check
TableOutputResolver.checkNullability(targetFieldExpr, fieldAttr, conf, fieldPath)
}
case _ =>
// Value is not a struct, cannot extract field
// The error is reported once per target field, and a typed null stands in for the
// missing source so that resolution can continue.
addError(s"Cannot assign non-struct value to struct field '${fieldPath.quoted}'")
Literal(null, fieldAttr.dataType)
}

// Recurse or resolve based on field type
fieldAttr.dataType match {
// NOTE(review): `nestedStructType` is bound but not referenced in this branch.
case nestedStructType: StructType =>
// Field is a struct, recurse
applyNestedFieldAssignments(fieldAttr, targetFieldExpr, sourceFieldValue,
addError, fieldPath, coerceNestedTyptes)
case _ =>
// Field is not a struct, resolve with TableOutputResolver
val coerceMode = if (coerceNestedTyptes) RECURSE else NONE
TableOutputResolver.resolveUpdate("", sourceFieldValue, fieldAttr, conf, addError,
fieldPath, coerceMode)
}
}
// Reassemble the per-field expressions into a struct of the target type.
toNamedStruct(structType, updatedFieldExprs)

case otherType =>
// Target is not a struct: report the error and leave the target value unchanged.
addError(
"Updating nested fields is only supported for StructType but " +
s"'${colPath.quoted}' is of type $otherType")
colExpr
}
}
Expand All @@ -223,6 +303,77 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
CreateNamedStruct(namedStructExprs)
}

/**
 * Checks, recursively, whether the target struct has any field that the source struct lacks.
 *
 * Fields are matched by name using `conf.resolver`. The result is true if some target field
 * has no matching source field, or if a struct-typed target field has a matching source
 * field whose type, checked recursively, is missing fields. A non-struct target field that
 * has a matching source field never counts as extra.
 *
 * @param targetType the struct type of the target column or nested field
 * @param sourceType the type of the source value; expected to be a StructType
 * @return true if the target has at least one field absent from the source
 * @throws SparkException internal error if `sourceType` is not a StructType, which includes
 *                        the recursive case of a struct target field matched to a
 *                        non-struct source field
 */
private def hasExtraTargetFields(targetType: StructType, sourceType: DataType): Boolean = {
sourceType match {
case sourceStructType: StructType =>
// `exists` stops at the first target field found to be extra.
targetType.fields.exists { targetField =>
sourceStructType.fields.find(f => conf.resolver(f.name, targetField.name)) match {
case Some(sourceField) =>
// Check nested structs recursively
(targetField.dataType, sourceField.dataType) match {
case (targetNested: StructType, sourceNested) =>
hasExtraTargetFields(targetNested, sourceNested)
case _ => false
}
case None => true // target has extra field not in source
}
}
case _ =>
Copy link
Contributor

@cloud-fan cloud-fan Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch means the types do not match and we will fail later?

Copy link
Member Author

@szehon-ho szehon-ho Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes actually its covered by the test https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala#L2915. The error is thrown earlier during schema evolution evaluation as you cant merge something non-struct with struct https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala#L1023

changed to throw exception instead of return false.

// Should be caught earlier
throw SparkException.internalError(
s"Source type must be StructType but found: $sourceType")
}
}

/**
 * Guards a field-by-field struct assignment against unwanted null expansion.
 *
 * UPDATE SET * assigns struct fields one at a time so that existing target fields are
 * preserved. Taken alone, that turns a null source struct into a struct whose fields are
 * all null. For a nullable target column this method wraps `structAssignment` so that the
 * result is null when:
 *
 *  - the target has fields that the source lacks, and both the source struct and the
 *    target struct are null; or
 *  - the target has no such extra fields, and the source struct is null.
 *
 * In every other case `structAssignment` is returned as the value, which keeps a struct of
 * nulls when the source was a struct of nulls or when extra target fields exist and the
 * target is not null.
 *
 * For a non-nullable target column the result is the source value wrapped in a
 * not-null assertion.
 *
 * @param col the target column attribute
 * @param value the source value expression
 * @param structType the target struct type
 * @param structAssignment the struct assignment result to wrap
 * @param colPath the column path for error reporting (not used by this implementation)
 * @param addError error reporting function (not used by this implementation)
 * @return the wrapped expression with null checks
 */
private def fixNullExpansion(
    col: Attribute,
    value: Expression,
    structType: StructType,
    structAssignment: Expression,
    colPath: Seq[String],
    addError: String => Unit): Expression = {
  if (col.nullable) {
    val sourceIsNull = IsNull(value)
    val resultIsNull =
      if (hasExtraTargetFields(structType, value.dataType)) {
        // Extra target fields: null only if the source and the target struct are both null.
        And(sourceIsNull, IsNull(col))
      } else {
        // Schemas match: null whenever the source struct is null.
        sourceIsNull
      }

    If(resultIsNull, Literal(null, structAssignment.dataType), structAssignment)
  } else {
    // StoreAssignmentPolicy.LEGACY is not allowed in DSv2, so a non-nullable column always
    // gets a null check.
    // NOTE(review): this branch returns the asserted source `value` and does not use
    // `structAssignment` - confirm that this is intended.
    AssertNotNull(value)
  }
}

/**
* Checks whether assignments are aligned and compatible with table columns.
*
Expand Down
Loading