diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java index 758b62fe31af..24dfff276973 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java @@ -1029,9 +1029,7 @@ private static RelDataType arrayAppendPrependReturnType(SqlOperatorBinding opBin public static final SqlFunction ARRAY_CONTAINS = SqlBasicFunction.create(SqlKind.ARRAY_CONTAINS, ReturnTypes.BOOLEAN_NULLABLE, - OperandTypes.and( - OperandTypes.NONNULL_NONNULL_NOT_CAST, - OperandTypes.ARRAY_ELEMENT)); + OperandTypes.ARRAY_ELEMENT_NON_NULL); /** The "ARRAY_DISTINCT(array)" function. */ @LibraryOperator(libraries = {SPARK}) @@ -1046,7 +1044,7 @@ private static RelDataType arrayAppendPrependReturnType(SqlOperatorBinding opBin SqlBasicFunction.create(SqlKind.ARRAY_EXCEPT, ReturnTypes.LEAST_RESTRICTIVE, OperandTypes.and( - OperandTypes.NONNULL_NONNULL_NOT_CAST, + OperandTypes.NONNULL_NONNULL, OperandTypes.SAME_SAME, OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY))); @@ -1056,7 +1054,7 @@ private static RelDataType arrayAppendPrependReturnType(SqlOperatorBinding opBin SqlBasicFunction.create(SqlKind.ARRAY_INTERSECT, ReturnTypes.LEAST_RESTRICTIVE, OperandTypes.and( - OperandTypes.NONNULL_NONNULL_NOT_CAST, + OperandTypes.NONNULL_NONNULL, OperandTypes.SAME_SAME, OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY))); diff --git a/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java b/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java index 8291b1681d28..bed38a73e65a 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java +++ b/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java @@ -21,6 +21,7 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperandCountRange; import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlUtil; import com.google.common.collect.ImmutableList; @@ -31,11 +32,41 @@ * Parameter type-checking strategy where types must be Array and Array element type. */ public class ArrayElementOperandTypeChecker implements SqlOperandTypeChecker { + //~ Instance fields -------------------------------------------------------- + + private final boolean allowNullCheck; + private final boolean allowCast; + + //~ Constructors ----------------------------------------------------------- + + public ArrayElementOperandTypeChecker() { + this.allowNullCheck = false; + this.allowCast = false; + } + + public ArrayElementOperandTypeChecker(boolean allowNullCheck, boolean allowCast) { + this.allowNullCheck = allowNullCheck; + this.allowCast = allowCast; + } + //~ Methods ---------------------------------------------------------------- @Override public boolean checkOperandTypes( SqlCallBinding callBinding, boolean throwOnFailure) { + if (allowNullCheck) { + // no operand can be null for type-checking to succeed + for (SqlNode node : callBinding.operands()) { + if (SqlUtil.isNullLiteral(node, allowCast)) { + if (throwOnFailure) { + throw callBinding.getValidator().newValidationError(node, RESOURCE.nullIllegal()); + } else { + return false; + } + } + } + } + final SqlNode op0 = callBinding.operand(0); if (!OperandTypes.ARRAY.checkSingleOperandType( callBinding, diff --git a/core/src/main/java/org/apache/calcite/sql/type/NullOperandTypeChecker.java b/core/src/main/java/org/apache/calcite/sql/type/NullOperandTypeChecker.java index cdd7862b670f..7810f8d2914f 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/NullOperandTypeChecker.java +++ b/core/src/main/java/org/apache/calcite/sql/type/NullOperandTypeChecker.java @@ -23,8 +23,7 @@ import static org.apache.calcite.util.Static.RESOURCE; /** - * Parameter type-checking strategy type must not be a NULL (including NULL, - * CAST(NULL as ...) but not CAST(CAST(NULL as ...) AS ...)). + * Parameter type-checking strategy where all operand types must not be NULL. */ public class NullOperandTypeChecker extends SameOperandTypeChecker { //~ Instance fields -------------------------------------------------------- @@ -42,7 +41,7 @@ public NullOperandTypeChecker(final int nOperands, final boolean allowCast) { @Override public boolean checkOperandTypes(final SqlCallBinding callBinding, final boolean throwOnFailure) { - // all operands can't be null + // no operand can be null for type-checking to succeed for (SqlNode node : callBinding.operands()) { if (SqlUtil.isNullLiteral(node, allowCast)) { if (throwOnFailure) { diff --git a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java index 8a7520018f6f..93f0489a36af 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java +++ b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java @@ -527,6 +527,9 @@ public static SqlOperandTypeChecker variadic( public static final SqlOperandTypeChecker ARRAY_ELEMENT = new ArrayElementOperandTypeChecker(); + public static final SqlOperandTypeChecker ARRAY_ELEMENT_NON_NULL = + new ArrayElementOperandTypeChecker(true, false); + public static final SqlSingleOperandTypeChecker MAP_FROM_ENTRIES = new MapFromEntriesOperandTypeChecker(); @@ -543,9 +546,9 @@ public static SqlOperandTypeChecker variadic( new LiteralOperandTypeChecker(false); /** - * Operand type-checking strategy type must be a non-NULL value without cast. + * Operand type-checking strategy type must be a non-NULL value. */ - public static final SqlSingleOperandTypeChecker NONNULL_NONNULL_NOT_CAST = + public static final SqlSingleOperandTypeChecker NONNULL_NONNULL = new NullOperandTypeChecker(2, false); /** diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index fb71d9cc2548..2fb2cd6fe194 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -5716,20 +5716,13 @@ private static void checkIf(SqlOperatorFixture f) { // library (i.e. "fun=flink") we could add a function with Flink behavior. f.checkNull("array_contains(array[1, null], cast(null as integer))"); f.checkType("array_contains(array[1, null], cast(null as integer))", "BOOLEAN"); - f.checkFails("^array_contains(array[1, 2], true)^", "Cannot apply 'ARRAY_CONTAINS' to arguments" - + " of type 'ARRAY_CONTAINS\\(, \\)'\\. Supported form\\(s\\): " - + "'ARRAY_CONTAINS\\(, \\)'", false); + f.checkFails("^array_contains(array[1, 2], true)^", + "INTEGER is not comparable to BOOLEAN", false); // check null without cast - f.checkFails("^array_contains(array[1, null], null)^", "Cannot apply 'ARRAY_CONTAINS' to " - + "arguments of type 'ARRAY_CONTAINS\\(, \\)'\\. Supported form\\(s\\):" - + " 'ARRAY_CONTAINS\\(, \\)'", false); - f.checkFails("^array_contains(null, array[1, null])^", "Cannot apply 'ARRAY_CONTAINS' to " - + "arguments of type 'ARRAY_CONTAINS\\(, \\)'\\. Supported form\\(s\\):" - + " 'ARRAY_CONTAINS\\(, \\)'", false); - f.checkFails("^array_contains(array[1, 2], null)^", "Cannot apply 'ARRAY_CONTAINS' to " - + "arguments of type 'ARRAY_CONTAINS\\(, \\)'\\. Supported form\\(s\\):" - + " 'ARRAY_CONTAINS\\(, \\)'", false); + f.checkFails("array_contains(array[1, 2], ^null^)", "Illegal use of 'NULL'", false); + f.checkFails("array_contains(^null^, array[1, 2])", "Illegal use of 'NULL'", false); + f.checkFails("array_contains(^null^, null)", "Illegal use of 'NULL'", false); } /** Tests {@code ARRAY_DISTINCT} function from Spark. */ @@ -6017,6 +6010,10 @@ private static void checkIf(SqlOperatorFixture f) { "Cannot apply 'ARRAY_EXCEPT' to arguments of type 'ARRAY_EXCEPT\\(, " + "\\)'\\. Supported form\\(s\\): 'ARRAY_EXCEPT\\(, " + "\\)'", false); + f.checkFails("^array_except(null, null)^", + "Cannot apply 'ARRAY_EXCEPT' to arguments of type 'ARRAY_EXCEPT\\(, " + + "\\)'\\. Supported form\\(s\\): 'ARRAY_EXCEPT\\(, " + + "\\)'", false); } /** Tests {@code ARRAY_INTERSECT} function from Spark. */ @@ -6047,6 +6044,10 @@ private static void checkIf(SqlOperatorFixture f) { "Cannot apply 'ARRAY_INTERSECT' to arguments of type 'ARRAY_INTERSECT\\(, " + "\\)'\\. Supported form\\(s\\): 'ARRAY_INTERSECT\\(, " + "\\)'", false); + f.checkFails("^array_intersect(null, null)^", + "Cannot apply 'ARRAY_INTERSECT' to arguments of type 'ARRAY_INTERSECT\\(, " + + "\\)'\\. Supported form\\(s\\): 'ARRAY_INTERSECT\\(, " + + "\\)'", false); } /** Tests {@code ARRAY_UNION} function from Spark. */