From f0dc2b0aea46b1fd3f37e0cc126edaf82ade2344 Mon Sep 17 00:00:00 2001 From: caicancai <2356672992@qq.com> Date: Mon, 26 Feb 2024 23:14:21 +0800 Subject: [PATCH] [CALCITE-5976] Function ARRAY_PREPEND/ARRAY_APPEND/ARRAY_INSERT gives exception when inserted element type not equals array component type Co-authored-by: Ran Tao --- .../calcite/sql/fun/SqlLibraryOperators.java | 42 ++++++++- .../sql/validate/SqlValidatorUtil.java | 70 +++++++++++++++ .../apache/calcite/test/SqlOperatorTest.java | 88 ++++++++++++++++++- 3 files changed, 196 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java index 38b5b48f77b..0c700d9de32 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java @@ -1199,13 +1199,31 @@ private static RelDataType arrayAppendPrependReturnType(SqlOperatorBinding opBin return arrayType; } final RelDataType elementType = opBinding.collectOperandTypes().get(1); + requireNonNull(componentType, () -> "componentType of " + arrayType); + RelDataType type = opBinding.getTypeFactory().leastRestrictive( ImmutableList.of(componentType, elementType)); + requireNonNull(type, "inferred array element type"); + if (elementType.isNullable()) { type = opBinding.getTypeFactory().createTypeWithNullability(type, true); } - requireNonNull(type, "inferred array element type"); + + // make explicit CAST for array elements and inserted element to the biggest type + // if array component type not equals to inserted element type + if (!componentType.equalsSansFieldNames(elementType)) { + // 0, 1 is the operand index to be CAST + // For array_append/array_prepend, 0 is the array arg and 1 is the inserted element + if (componentType.equalsSansFieldNames(type)) { + SqlValidatorUtil. 
+ adjustTypeForArrayFunctions(type, opBinding, 1); + } else { + SqlValidatorUtil. + adjustTypeForArrayFunctions(type, opBinding, 0); + } + } + return SqlTypeUtil.createArrayType(opBinding.getTypeFactory(), type, arrayType.isNullable()); } @@ -1282,14 +1300,32 @@ private static RelDataType arrayInsertReturnType(SqlOperatorBinding opBinding) { final RelDataType arrayType = opBinding.collectOperandTypes().get(0); final RelDataType componentType = arrayType.getComponentType(); final RelDataType elementType = opBinding.collectOperandTypes().get(2); + requireNonNull(componentType, () -> "componentType of " + arrayType); + // we don't need to do leastRestrictive on componentType and elementType, // because in operand checker we limit the elementType must equals array component type. // So we use componentType directly. - RelDataType type = componentType; + RelDataType type = + opBinding.getTypeFactory().leastRestrictive( + ImmutableList.of(componentType, elementType)); + requireNonNull(type, "inferred array element type"); + if (elementType.isNullable()) { type = opBinding.getTypeFactory().createTypeWithNullability(type, true); } - requireNonNull(type, "inferred array element type"); + // make explicit CAST for array elements and inserted element to the biggest type + // if array component type not equals to inserted element type + if (!componentType.equalsSansFieldNames(elementType)) { + // 0, 2 is the operand index to be CAST + // For array_insert, 0 is the array arg and 2 is the inserted element + if (componentType.equalsSansFieldNames(type)) { + SqlValidatorUtil. + adjustTypeForArrayFunctions(type, opBinding, 2); + } else { + SqlValidatorUtil. 
+ adjustTypeForArrayFunctions(type, opBinding, 0); + } + } return SqlTypeUtil.createArrayType(opBinding.getTypeFactory(), type, arrayType.isNullable()); } diff --git a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java index fd18c2d7b18..9aeec2da20f 100644 --- a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java +++ b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java @@ -37,6 +37,7 @@ import org.apache.calcite.schema.Table; import org.apache.calcite.schema.impl.AbstractSchema; import org.apache.calcite.schema.impl.AbstractTable; +import org.apache.calcite.sql.SqlBasicCall; import org.apache.calcite.sql.SqlCall; import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlDataTypeSpec; @@ -1326,6 +1327,51 @@ public static void adjustTypeForArrayConstructor( } } + /** + * Adjusts the types of specified operands in an array operation to match a given target type. + * This is particularly useful in the context of SQL operations involving array functions, + * where it's necessary to ensure that all operands have consistent types for the operation + * to be valid. + * + *
<p>
This method operates on the assumption that the operands to be adjusted are part of a + * {@link SqlCall}, which is bound within a {@link SqlOperatorBinding}. The operands to be + * cast are identified by their indexes within the {@code operands} list of the {@link SqlCall}. + * The method performs a dynamic check to determine if an operand is a basic call to an array. + * If so, it casts each element within the array to the target type. + * Otherwise, it casts the operand itself to the target type. + * + *
<p>
Example usage: For an operation like {@code array_append(array(1,2), cast(2 as tinyint))}, + * if targetType is double, this method would ensure that the elements of the + * first array and the second operand are cast to double. + * + * @param targetType The target {@link RelDataType} to which the operands should be cast. + * @param opBinding The {@link SqlOperatorBinding} context, which provides access to the + * {@link SqlCall} and its operands. + * @param indexes The indexes of the operands within the {@link SqlCall} that need to be + * adjusted to the target type. + * @throws NullPointerException if {@code targetType} is {@code null}. + */ + public static void adjustTypeForArrayFunctions( + RelDataType targetType, SqlOperatorBinding opBinding, int... indexes) { + if (opBinding instanceof SqlCallBinding) { + requireNonNull(targetType, "array function target type"); + SqlCall call = ((SqlCallBinding) opBinding).getCall(); + List operands = call.getOperandList(); + for (int idx : indexes) { + SqlNode operand = operands.get(idx); + if (operand instanceof SqlBasicCall + // not use SqlKind to compare because some other array function forms + // such as spark array, the SqlKind is other function. + // however, the name is same for those different array forms. + && "ARRAY".equals(((SqlBasicCall) operand).getOperator().getName())) { + call.setOperand(idx, castArrayElementTo(operand, targetType)); + } else { + call.setOperand(idx, castTo(operand, targetType)); + } + } + } + } + /** * When the map key or value does not equal the map component key type or value type, * make explicit casting. @@ -1398,6 +1444,30 @@ private static SqlNode castTo(SqlNode node, RelDataType type) { SqlTypeUtil.convertTypeToSpec(type).withNullable(type.isNullable())); } + /** + * Creates a CAST operation that cast each element of the given {@link SqlNode} to the + * specified type. The {@link SqlNode} representing an array and a {@link RelDataType} + * representing the target type. 
This method uses the {@link SqlStdOperatorTable#CAST} + * operator to create a new {@link SqlCall} node representing a CAST operation. + * Each element of original 'node' is cast to the desired 'type', preserving the + * nullability of the 'type'. + * + * @param node the {@link SqlNode} the sqlnode representing an array + * @param type the target {@link RelDataType} the target type + * @return a new {@link SqlNode} representing the CAST operation + */ + private static SqlNode castArrayElementTo(SqlNode node, RelDataType type) { + int i = 0; + for (SqlNode operand : ((SqlBasicCall) node).getOperandList()) { + SqlNode castedOperand = + SqlStdOperatorTable.CAST.createCall(SqlParserPos.ZERO, + operand, + SqlTypeUtil.convertTypeToSpec(type).withNullable(type.isNullable())); + ((SqlBasicCall) node).setOperand(i++, castedOperand); + } + return node; + } + //~ Inner Classes ---------------------------------------------------------- /** diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index ebf23f290a3..01f308adfe0 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -6412,6 +6412,42 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { f.checkType("array_append(cast(null as integer array), 1)", "INTEGER NOT NULL ARRAY"); f.checkFails("^array_append(array[1, 2], true)^", "INTEGER is not comparable to BOOLEAN", false); + + // element cast to the biggest type + f.checkScalar("array_append(array(cast(1 as tinyint)), 2)", "[1, 2]", + "INTEGER NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_append(array(cast(1 as double)), cast(2 as float))", "[1.0, 2.0]", + "DOUBLE NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_append(array(1), cast(2 as float))", "[1.0, 2.0]", + "FLOAT NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_append(array(1), cast(2 as 
double))", "[1.0, 2.0]", + "DOUBLE NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_append(array(1), cast(2 as bigint))", "[1, 2]", + "BIGINT NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_append(array(1, 2), cast(3 as double))", "[1.0, 2.0, 3.0]", + "DOUBLE NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_append(array(1, 2), cast(3 as float))", "[1.0, 2.0, 3.0]", + "FLOAT NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_append(array(1, 2), cast(3 as bigint))", "[1, 2, 3]", + "BIGINT NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_append(array(1, 2), cast(null as double))", "[1.0, 2.0, null]", + "DOUBLE ARRAY NOT NULL"); + f.checkScalar("array_append(array(1, 2), cast(null as float))", "[1.0, 2.0, null]", + "FLOAT ARRAY NOT NULL"); + f.checkScalar("array_append(array(1), cast(null as bigint))", "[1, null]", + "BIGINT ARRAY NOT NULL"); + f.checkScalar("array_append(array(1), cast(100 as decimal))", "[1, 100]", + "DECIMAL(19, 0) NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_append(array(1), 10e6)", "[1.0, 1.0E7]", + "DOUBLE NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_append(array(), cast(null as double))", "[null]", + "DOUBLE ARRAY NOT NULL"); + f.checkScalar("array_append(array(), cast(null as float))", "[null]", + "FLOAT ARRAY NOT NULL"); + f.checkScalar("array_append(array(), cast(null as tinyint))", "[null]", + "TINYINT ARRAY NOT NULL"); + f.checkScalar("array_append(array(), cast(null as bigint))", "[null]", + "BIGINT ARRAY NOT NULL"); } /** Tests {@code ARRAY_COMPACT} function from Spark. 
*/ @@ -6648,7 +6684,7 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { "NULL ARRAY NOT NULL"); f.checkScalar("array_prepend(array(), null)", "[null]", "UNKNOWN ARRAY NOT NULL"); - f.checkScalar("array_append(array(), 1)", "[1]", + f.checkScalar("array_prepend(array(), 1)", "[1]", "INTEGER NOT NULL ARRAY NOT NULL"); f.checkScalar("array_prepend(array[array[1, 2]], array[3, 4])", "[[3, 4], [1, 2]]", "INTEGER NOT NULL ARRAY NOT NULL ARRAY NOT NULL"); @@ -6658,6 +6694,40 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { f.checkType("array_prepend(cast(null as integer array), 1)", "INTEGER NOT NULL ARRAY"); f.checkFails("^array_prepend(array[1, 2], true)^", "INTEGER is not comparable to BOOLEAN", false); + + // element cast to the biggest type + f.checkScalar("array_prepend(array(1), cast(3 as float))", "[3.0, 1.0]", + "FLOAT NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(1), cast(3 as bigint))", "[3, 1]", + "BIGINT NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(2), cast(3 as double))", "[3.0, 2.0]", + "DOUBLE NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(1, 2), cast(3 as float))", "[3.0, 1.0, 2.0]", + "FLOAT NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(2, 1), cast(3 as double))", "[3.0, 2.0, 1.0]", + "DOUBLE NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(1, 2), cast(3 as tinyint))", "[3, 1, 2]", + "INTEGER NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(1, 2), cast(3 as bigint))", "[3, 1, 2]", + "BIGINT NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(1, 2), cast(null as double))", "[null, 1.0, 2.0]", + "DOUBLE ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(1, 2), cast(null as float))", "[null, 1.0, 2.0]", + "FLOAT ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(1), cast(null as bigint))", "[null, 1]", + "BIGINT ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(1), 
cast(100 as decimal))", "[100, 1]", + "DECIMAL(19, 0) NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(1), 10e6)", "[1.0E7, 1.0]", + "DOUBLE NOT NULL ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(), cast(null as double))", "[null]", + "DOUBLE ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(), cast(null as float))", "[null]", + "FLOAT ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(), cast(null as tinyint))", "[null]", + "TINYINT ARRAY NOT NULL"); + f.checkScalar("array_prepend(array(), cast(null as bigint))", "[null]", + "BIGINT ARRAY NOT NULL"); } /** Tests {@code ARRAY_REMOVE} function from Spark. */ @@ -6944,6 +7014,22 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { "(INTEGER NOT NULL, CHAR(1) NOT NULL) MAP NOT NULL ARRAY NOT NULL"); f1.checkScalar("array_insert(array[map[1, 'a']], -1, map[2, 'b'])", "[{2=b}, {1=a}]", "(INTEGER NOT NULL, CHAR(1) NOT NULL) MAP NOT NULL ARRAY NOT NULL"); + + // element cast to the biggest type + f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as tinyint))", + "[1, 2, 4, 3]", "INTEGER NOT NULL ARRAY NOT NULL"); + f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as double))", + "[1.0, 2.0, 4.0, 3.0]", "DOUBLE NOT NULL ARRAY NOT NULL"); + f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as float))", + "[1.0, 2.0, 4.0, 3.0]", "FLOAT NOT NULL ARRAY NOT NULL"); + f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as bigint))", + "[1, 2, 4, 3]", "BIGINT NOT NULL ARRAY NOT NULL"); + f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(null as bigint))", + "[1, 2, null, 3]", "BIGINT ARRAY NOT NULL"); + f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(null as float))", + "[1.0, 2.0, null, 3.0]", "FLOAT ARRAY NOT NULL"); + f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(null as tinyint))", + "[1, 2, null, 3]", "INTEGER ARRAY NOT NULL"); } /** Tests {@code ARRAY_INTERSECT} function from Spark. */