diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index 38b5b48f77b..0c700d9de32 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -1199,13 +1199,31 @@ private static RelDataType arrayAppendPrependReturnType(SqlOperatorBinding opBin
return arrayType;
}
final RelDataType elementType = opBinding.collectOperandTypes().get(1);
+ requireNonNull(componentType, () -> "componentType of " + arrayType);
+
RelDataType type =
opBinding.getTypeFactory().leastRestrictive(
ImmutableList.of(componentType, elementType));
+ requireNonNull(type, "inferred array element type");
+
if (elementType.isNullable()) {
type = opBinding.getTypeFactory().createTypeWithNullability(type, true);
}
- requireNonNull(type, "inferred array element type");
+
+ // make explicit CAST for array elements and inserted element to the biggest type
+ // if array component type not equals to inserted element type
+ if (!componentType.equalsSansFieldNames(elementType)) {
+ // 0, 1 is the operand index to be CAST
+ // For array_append/array_prepend, 0 is the array arg and 1 is the inserted element
+ if (componentType.equalsSansFieldNames(type)) {
+ SqlValidatorUtil.
+ adjustTypeForArrayFunctions(type, opBinding, 1);
+ } else {
+ SqlValidatorUtil.
+ adjustTypeForArrayFunctions(type, opBinding, 0);
+ }
+ }
+
return SqlTypeUtil.createArrayType(opBinding.getTypeFactory(), type, arrayType.isNullable());
}
@@ -1282,14 +1300,32 @@ private static RelDataType arrayInsertReturnType(SqlOperatorBinding opBinding) {
final RelDataType arrayType = opBinding.collectOperandTypes().get(0);
final RelDataType componentType = arrayType.getComponentType();
final RelDataType elementType = opBinding.collectOperandTypes().get(2);
+ requireNonNull(componentType, () -> "componentType of " + arrayType);
+
// we don't need to do leastRestrictive on componentType and elementType,
// because in operand checker we limit the elementType must equals array component type.
// So we use componentType directly.
- RelDataType type = componentType;
+ RelDataType type =
+ opBinding.getTypeFactory().leastRestrictive(
+ ImmutableList.of(componentType, elementType));
+ requireNonNull(type, "inferred array element type");
+
if (elementType.isNullable()) {
type = opBinding.getTypeFactory().createTypeWithNullability(type, true);
}
- requireNonNull(type, "inferred array element type");
+ // make explicit CAST for array elements and inserted element to the biggest type
+ // if array component type not equals to inserted element type
+ if (!componentType.equalsSansFieldNames(elementType)) {
+ // 0, 2 is the operand index to be CAST
+ // For array_insert, 0 is the array arg and 2 is the inserted element
+ if (componentType.equalsSansFieldNames(type)) {
+ SqlValidatorUtil.
+ adjustTypeForArrayFunctions(type, opBinding, 2);
+ } else {
+ SqlValidatorUtil.
+ adjustTypeForArrayFunctions(type, opBinding, 0);
+ }
+ }
return SqlTypeUtil.createArrayType(opBinding.getTypeFactory(), type, arrayType.isNullable());
}
diff --git a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java
index fd18c2d7b18..9aeec2da20f 100644
--- a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java
+++ b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java
@@ -37,6 +37,7 @@
import org.apache.calcite.schema.Table;
import org.apache.calcite.schema.impl.AbstractSchema;
import org.apache.calcite.schema.impl.AbstractTable;
+import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlDataTypeSpec;
@@ -1326,6 +1327,51 @@ public static void adjustTypeForArrayConstructor(
}
}
+ /**
+ * Adjusts the types of specified operands in an array operation to match a given target type.
+ * This is particularly useful in the context of SQL operations involving array functions,
+ * where it's necessary to ensure that all operands have consistent types for the operation
+ * to be valid.
+ *
+ *
This method operates on the assumption that the operands to be adjusted are part of a
+ * {@link SqlCall}, which is bound within a {@link SqlOperatorBinding}. The operands to be
+ * cast are identified by their indexes within the {@code operands} list of the {@link SqlCall}.
+ * The method performs a dynamic check to determine if an operand is a basic call to an array.
+ * If so, it casts each element within the array to the target type.
+ * Otherwise, it casts the operand itself to the target type.
+ *
+ *
Example usage: For an operation like {@code array_append(array(1,2), cast(2 as tinyint))},
+ * if targetType is double, this method would ensure that the elements of the
+ * first array and the second operand are cast to double.
+ *
+ * @param targetType The target {@link RelDataType} to which the operands should be cast.
+ * @param opBinding The {@link SqlOperatorBinding} context, which provides access to the
+ * {@link SqlCall} and its operands.
+ * @param indexes The indexes of the operands within the {@link SqlCall} that need to be
+ * adjusted to the target type.
+ * @throws NullPointerException if {@code targetType} is {@code null}.
+ */
+ public static void adjustTypeForArrayFunctions(
+ RelDataType targetType, SqlOperatorBinding opBinding, int... indexes) {
+ if (opBinding instanceof SqlCallBinding) {
+ requireNonNull(targetType, "array function target type");
+ SqlCall call = ((SqlCallBinding) opBinding).getCall();
+ List operands = call.getOperandList();
+ for (int idx : indexes) {
+ SqlNode operand = operands.get(idx);
+ if (operand instanceof SqlBasicCall
+ // not use SqlKind to compare because some other array function forms
+ // such as spark array, the SqlKind is other function.
+ // however, the name is same for those different array forms.
+ && "ARRAY".equals(((SqlBasicCall) operand).getOperator().getName())) {
+ call.setOperand(idx, castArrayElementTo(operand, targetType));
+ } else {
+ call.setOperand(idx, castTo(operand, targetType));
+ }
+ }
+ }
+ }
+
/**
* When the map key or value does not equal the map component key type or value type,
* make explicit casting.
@@ -1398,6 +1444,30 @@ private static SqlNode castTo(SqlNode node, RelDataType type) {
SqlTypeUtil.convertTypeToSpec(type).withNullable(type.isNullable()));
}
+ /**
+ * Creates a CAST operation that cast each element of the given {@link SqlNode} to the
+ * specified type. The {@link SqlNode} representing an array and a {@link RelDataType}
+ * representing the target type. This method uses the {@link SqlStdOperatorTable#CAST}
+ * operator to create a new {@link SqlCall} node representing a CAST operation.
+ * Each element of original 'node' is cast to the desired 'type', preserving the
+ * nullability of the 'type'.
+ *
+ * @param node the {@link SqlNode} the sqlnode representing an array
+ * @param type the target {@link RelDataType} the target type
+ * @return a new {@link SqlNode} representing the CAST operation
+ */
+ private static SqlNode castArrayElementTo(SqlNode node, RelDataType type) {
+ int i = 0;
+ for (SqlNode operand : ((SqlBasicCall) node).getOperandList()) {
+ SqlNode castedOperand =
+ SqlStdOperatorTable.CAST.createCall(SqlParserPos.ZERO,
+ operand,
+ SqlTypeUtil.convertTypeToSpec(type).withNullable(type.isNullable()));
+ ((SqlBasicCall) node).setOperand(i++, castedOperand);
+ }
+ return node;
+ }
+
//~ Inner Classes ----------------------------------------------------------
/**
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index ebf23f290a3..01f308adfe0 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -6412,6 +6412,42 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) {
f.checkType("array_append(cast(null as integer array), 1)", "INTEGER NOT NULL ARRAY");
f.checkFails("^array_append(array[1, 2], true)^",
"INTEGER is not comparable to BOOLEAN", false);
+
+ // element cast to the biggest type
+ f.checkScalar("array_append(array(cast(1 as tinyint)), 2)", "[1, 2]",
+ "INTEGER NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_append(array(cast(1 as double)), cast(2 as float))", "[1.0, 2.0]",
+ "DOUBLE NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_append(array(1), cast(2 as float))", "[1.0, 2.0]",
+ "FLOAT NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_append(array(1), cast(2 as double))", "[1.0, 2.0]",
+ "DOUBLE NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_append(array(1), cast(2 as bigint))", "[1, 2]",
+ "BIGINT NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_append(array(1, 2), cast(3 as double))", "[1.0, 2.0, 3.0]",
+ "DOUBLE NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_append(array(1, 2), cast(3 as float))", "[1.0, 2.0, 3.0]",
+ "FLOAT NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_append(array(1, 2), cast(3 as bigint))", "[1, 2, 3]",
+ "BIGINT NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_append(array(1, 2), cast(null as double))", "[1.0, 2.0, null]",
+ "DOUBLE ARRAY NOT NULL");
+ f.checkScalar("array_append(array(1, 2), cast(null as float))", "[1.0, 2.0, null]",
+ "FLOAT ARRAY NOT NULL");
+ f.checkScalar("array_append(array(1), cast(null as bigint))", "[1, null]",
+ "BIGINT ARRAY NOT NULL");
+ f.checkScalar("array_append(array(1), cast(100 as decimal))", "[1, 100]",
+ "DECIMAL(19, 0) NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_append(array(1), 10e6)", "[1.0, 1.0E7]",
+ "DOUBLE NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_append(array(), cast(null as double))", "[null]",
+ "DOUBLE ARRAY NOT NULL");
+ f.checkScalar("array_append(array(), cast(null as float))", "[null]",
+ "FLOAT ARRAY NOT NULL");
+ f.checkScalar("array_append(array(), cast(null as tinyint))", "[null]",
+ "TINYINT ARRAY NOT NULL");
+ f.checkScalar("array_append(array(), cast(null as bigint))", "[null]",
+ "BIGINT ARRAY NOT NULL");
}
/** Tests {@code ARRAY_COMPACT} function from Spark. */
@@ -6648,7 +6684,7 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) {
"NULL ARRAY NOT NULL");
f.checkScalar("array_prepend(array(), null)", "[null]",
"UNKNOWN ARRAY NOT NULL");
- f.checkScalar("array_append(array(), 1)", "[1]",
+ f.checkScalar("array_prepend(array(), 1)", "[1]",
"INTEGER NOT NULL ARRAY NOT NULL");
f.checkScalar("array_prepend(array[array[1, 2]], array[3, 4])", "[[3, 4], [1, 2]]",
"INTEGER NOT NULL ARRAY NOT NULL ARRAY NOT NULL");
@@ -6658,6 +6694,40 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) {
f.checkType("array_prepend(cast(null as integer array), 1)", "INTEGER NOT NULL ARRAY");
f.checkFails("^array_prepend(array[1, 2], true)^",
"INTEGER is not comparable to BOOLEAN", false);
+
+ // element cast to the biggest type
+ f.checkScalar("array_prepend(array(1), cast(3 as float))", "[3.0, 1.0]",
+ "FLOAT NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(1), cast(3 as bigint))", "[3, 1]",
+ "BIGINT NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(2), cast(3 as double))", "[3.0, 2.0]",
+ "DOUBLE NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(1, 2), cast(3 as float))", "[3.0, 1.0, 2.0]",
+ "FLOAT NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(2, 1), cast(3 as double))", "[3.0, 2.0, 1.0]",
+ "DOUBLE NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(1, 2), cast(3 as tinyint))", "[3, 1, 2]",
+ "INTEGER NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(1, 2), cast(3 as bigint))", "[3, 1, 2]",
+ "BIGINT NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(1, 2), cast(null as double))", "[null, 1.0, 2.0]",
+ "DOUBLE ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(1, 2), cast(null as float))", "[null, 1.0, 2.0]",
+ "FLOAT ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(1), cast(null as bigint))", "[null, 1]",
+ "BIGINT ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(1), cast(100 as decimal))", "[100, 1]",
+ "DECIMAL(19, 0) NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(1), 10e6)", "[1.0E7, 1.0]",
+ "DOUBLE NOT NULL ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(), cast(null as double))", "[null]",
+ "DOUBLE ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(), cast(null as float))", "[null]",
+ "FLOAT ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(), cast(null as tinyint))", "[null]",
+ "TINYINT ARRAY NOT NULL");
+ f.checkScalar("array_prepend(array(), cast(null as bigint))", "[null]",
+ "BIGINT ARRAY NOT NULL");
}
/** Tests {@code ARRAY_REMOVE} function from Spark. */
@@ -6944,6 +7014,22 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) {
"(INTEGER NOT NULL, CHAR(1) NOT NULL) MAP NOT NULL ARRAY NOT NULL");
f1.checkScalar("array_insert(array[map[1, 'a']], -1, map[2, 'b'])", "[{2=b}, {1=a}]",
"(INTEGER NOT NULL, CHAR(1) NOT NULL) MAP NOT NULL ARRAY NOT NULL");
+
+ // element cast to the biggest type
+ f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as tinyint))",
+ "[1, 2, 4, 3]", "INTEGER NOT NULL ARRAY NOT NULL");
+ f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as double))",
+ "[1.0, 2.0, 4.0, 3.0]", "DOUBLE NOT NULL ARRAY NOT NULL");
+ f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as float))",
+ "[1.0, 2.0, 4.0, 3.0]", "FLOAT NOT NULL ARRAY NOT NULL");
+ f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as bigint))",
+ "[1, 2, 4, 3]", "BIGINT NOT NULL ARRAY NOT NULL");
+ f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(null as bigint))",
+ "[1, 2, null, 3]", "BIGINT ARRAY NOT NULL");
+ f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(null as float))",
+ "[1.0, 2.0, null, 3.0]", "FLOAT ARRAY NOT NULL");
+ f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(null as tinyint))",
+ "[1, 2, null, 3]", "INTEGER ARRAY NOT NULL");
}
/** Tests {@code ARRAY_INTERSECT} function from Spark. */