From 12cd99c8b1397b643b6d842f58d0331f54b7cb09 Mon Sep 17 00:00:00 2001 From: Ran Tao Date: Fri, 1 Mar 2024 16:51:37 +0800 Subject: [PATCH] Refine the javadoc and NPE and some other logic --- .../calcite/sql/fun/SqlLibraryOperators.java | 32 +++++--- .../sql/validate/SqlValidatorUtil.java | 76 ++++++++++++------- .../apache/calcite/test/SqlOperatorTest.java | 11 ++- 3 files changed, 75 insertions(+), 44 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java index 06a8dab24f8..ac14eabfc5d 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java @@ -1181,20 +1181,26 @@ private static RelDataType arrayAppendPrependReturnType(SqlOperatorBinding opBin final RelDataType arrayType = opBinding.collectOperandTypes().get(0); final RelDataType componentType = arrayType.getComponentType(); final RelDataType elementType = opBinding.collectOperandTypes().get(1); + requireNonNull(componentType, () -> "componentType of " + arrayType); + RelDataType type = opBinding.getTypeFactory().leastRestrictive( ImmutableList.of(componentType, elementType)); + requireNonNull(type, "inferred array element type"); + if (elementType.isNullable()) { type = opBinding.getTypeFactory().createTypeWithNullability(type, true); } - if (!componentType.isNullable() && !componentType.equalsSansFieldNames(elementType)) { - SqlValidatorUtil. - adjustTypeForArrayFunctionConstructor(type, opBinding, 0); + + // make explicit CAST for array elements and inserted element to the biggest type + // if array component type not equals to inserted element type + if (!componentType.equalsSansFieldNames(elementType)) { + // 0, 1 is the operand index to be CAST + // For array_append/array_prepend, 0 is the array arg and 1 is the inserted element SqlValidatorUtil. - adjustTypeForArrayFunctionConstructor(type, opBinding, 1); + adjustTypeForArrayOperationFunction(type, opBinding, 0, 1); } - requireNonNull(type, "inferred array element type"); return SqlTypeUtil.createArrayType(opBinding.getTypeFactory(), type, arrayType.isNullable()); } @@ -1270,23 +1276,27 @@ private static RelDataType arrayInsertReturnType(SqlOperatorBinding opBinding) { final RelDataType arrayType = opBinding.collectOperandTypes().get(0); final RelDataType componentType = arrayType.getComponentType(); final RelDataType elementType = opBinding.collectOperandTypes().get(2); + requireNonNull(componentType, () -> "componentType of " + arrayType); + // we don't need to do leastRestrictive on componentType and elementType, // because in operand checker we limit the elementType must equals array component type. // So we use componentType directly. RelDataType type = opBinding.getTypeFactory().leastRestrictive( ImmutableList.of(componentType, elementType)); + requireNonNull(type, "inferred array element type"); + if (elementType.isNullable()) { type = opBinding.getTypeFactory().createTypeWithNullability(type, true); } - - if (!componentType.isNullable() && !componentType.equalsSansFieldNames(elementType)) { - SqlValidatorUtil. - adjustTypeForArrayFunctionConstructor(type, opBinding, 0); + // make explicit CAST for array elements and inserted element to the biggest type + // if array component type not equals to inserted element type + if (!componentType.equalsSansFieldNames(elementType)) { + // 0, 2 is the operand index to be CAST + // For array_insert, 0 is the array arg and 2 is the inserted element SqlValidatorUtil. - adjustTypeForArrayFunctionConstructor(type, opBinding, 2); + adjustTypeForArrayOperationFunction(type, opBinding, 0, 2); } - requireNonNull(type, "inferred array element type"); return SqlTypeUtil.createArrayType(opBinding.getTypeFactory(), type, arrayType.isNullable()); } diff --git a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java index 9a36da34c3d..19984b1774c 100644 --- a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java +++ b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java @@ -1328,32 +1328,48 @@ public static void adjustTypeForArrayConstructor( } /** - * When the array element does not equal the biggest type, make explicit casting. + * Adjusts the types of specified operands in an array operation to match a given target type. + * This is particularly useful in the context of SQL operations involving array functions, + * where it's necessary to ensure that all operands have consistent types for the operation + * to be valid. * - * @param componentType derived array component type - * @param opBinding description of call - * @param index index of opBinding + *

This method operates on the assumption that the operands to be adjusted are part of a + * {@link SqlCall}, which is bound within a {@link SqlOperatorBinding}. The operands to be + * cast are identified by their indexes within the {@code operands} list of the {@link SqlCall}. + * The method performs a dynamic check to determine if an operand is a basic call to an array. + * If so, it casts each element within the array to the target type. + * Otherwise, it casts the operand itself to the target type. + * + *

Example usage: For an operation like {@code array_append(array(1,2), cast(2 as tinyint))}, + * if targetType is double, this method would ensure that the elements of the + * first array and the second operand are cast to double. + * + * @param targetType The target {@link RelDataType} to which the operands should be cast. + * @param opBinding The {@link SqlOperatorBinding} context, which provides access to the + * {@link SqlCall} and its operands. + * @param indexes The indexes of the operands within the {@link SqlCall} that need to be + * adjusted to the target type. + * @throws NullPointerException if {@code targetType} is {@code null}. */ - public static void adjustTypeForArrayFunctionConstructor( - RelDataType componentType, SqlOperatorBinding opBinding, int index) { + public static void adjustTypeForArrayOperationFunction( + RelDataType targetType, SqlOperatorBinding opBinding, int... indexes) { if (opBinding instanceof SqlCallBinding) { - requireNonNull(componentType, "array component type"); - adjustTypeForMultisetConstructor( - componentType, (SqlCallBinding) opBinding, index); - } - } - - private static void adjustTypeForMultisetConstructor( - RelDataType evenType, SqlCallBinding sqlCallBinding, int index) { - SqlCall call = sqlCallBinding.getCall(); - List operands = call.getOperandList(); - RelDataType elementType = evenType; - if (index == 0) { - call.setOperand(index, arrayToCast(operands.get(index), elementType)); - } else { - call.setOperand(index, castTo(operands.get(index), elementType)); + requireNonNull(targetType, "array function target type"); + SqlCall call = ((SqlCallBinding) opBinding).getCall(); + List operands = call.getOperandList(); + for (int idx : indexes) { + SqlNode operand = operands.get(idx); + if (operand instanceof SqlBasicCall + // not use SqlKind to compare because some other array function forms + // such as spark array, the SqlKind is other function. + // however, the name is same for those different array forms. + && "ARRAY".equals(((SqlBasicCall) operand).getOperator().getName())) { + call.setOperand(idx, castArrayElementTo(operand, targetType)); + } else { + call.setOperand(idx, castTo(operand, targetType)); + } + } } - } /** @@ -1429,16 +1445,18 @@ private static SqlNode castTo(SqlNode node, RelDataType type) { } /** - * Creates a CAST operation that converts the given Array in {@link SqlNode} to the specified {@link RelDataType}. - * This method uses the {@link SqlStdOperatorTable#CAST} operator to create a new {@link SqlCall} - * node representing a CAST operation. The original 'node' is cast to the desired 'type', - * preserving the nullability of the 'type'. + * Creates a CAST operation that cast each element of the given {@link SqlNode} to the + * specified type. The {@link SqlNode} representing an array and a {@link RelDataType} + * representing the target type. This method uses the {@link SqlStdOperatorTable#CAST} + * operator to create a new {@link SqlCall} node representing a CAST operation. + * Each element of original 'node' is cast to the desired 'type', preserving the + * nullability of the 'type'. * - * @param node the {@link SqlNode} which is to be cast - * @param type the target {@link RelDataType} to which 'node' should be cast + * @param node the {@link SqlNode} the sqlnode representing an array + * @param type the target {@link RelDataType} the target type * @return a new {@link SqlNode} representing the CAST operation */ - private static SqlNode arrayToCast(SqlNode node, RelDataType type) { + private static SqlNode castArrayElementTo(SqlNode node, RelDataType type) { int i = 0; for (SqlNode operand : ((SqlBasicCall) node).getOperandList()) { SqlNode castedOperand = diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index 9a89d71e008..2b040b2ac3f 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -6341,8 +6341,9 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { f.checkType("array_append(cast(null as integer array), 1)", "INTEGER NOT NULL ARRAY"); f.checkFails("^array_append(array[1, 2], true)^", "INTEGER is not comparable to BOOLEAN", false); - // cast biggest type - f.checkScalar("array_append(array(1), cast(2 as tinyint))", "[1, 2]", + + // element cast to the biggest type + f.checkScalar("array_append(array(cast(1 as tinyint)), 2)", "[1, 2]", "INTEGER NOT NULL ARRAY NOT NULL"); f.checkScalar("array_append(array(cast(1 as double)), cast(2 as float))", "[1.0, 2.0]", "DOUBLE NOT NULL ARRAY NOT NULL"); @@ -6617,7 +6618,8 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { f.checkType("array_prepend(cast(null as integer array), 1)", "INTEGER NOT NULL ARRAY"); f.checkFails("^array_prepend(array[1, 2], true)^", "INTEGER is not comparable to BOOLEAN", false); - // cast biggest type + + // element cast to the biggest type f.checkScalar("array_prepend(array(1), cast(3 as float))", "[3.0, 1.0]", "FLOAT NOT NULL ARRAY NOT NULL"); f.checkScalar("array_prepend(array(1), cast(3 as bigint))", "[3, 1]", @@ -6922,7 +6924,8 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { "(INTEGER NOT NULL, CHAR(1) NOT NULL) MAP NOT NULL ARRAY NOT NULL"); f1.checkScalar("array_insert(array[map[1, 'a']], -1, map[2, 'b'])", "[{2=b}, {1=a}]", "(INTEGER NOT NULL, CHAR(1) NOT NULL) MAP NOT NULL ARRAY NOT NULL"); - // cast biggest type + + // element cast to the biggest type f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as tinyint))", "[1, 2, 4, 3]", "INTEGER NOT NULL ARRAY NOT NULL"); f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as double))",