From e4c1fad1fbd5fcc726da38a667bea8d481a3fe5f Mon Sep 17 00:00:00 2001 From: Ran Tao Date: Mon, 7 Aug 2023 21:11:28 +0800 Subject: [PATCH] [CALCITE-5893] Wrong NULL operand behavior of ARRAY_CONTAINS/ARRAY_EXCEPT/ARRAY_INTERSECT In Spark Library --- .../calcite/sql/fun/SqlLibraryOperators.java | 4 +- .../type/ArrayElementOperandTypeChecker.java | 31 ++++++++++ .../sql/type/NotNullOperandTypeChecker.java | 56 +++++++++++++++++++ .../apache/calcite/sql/type/OperandTypes.java | 9 +++ .../apache/calcite/test/SqlOperatorTest.java | 33 +++++++++++ 5 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 core/src/main/java/org/apache/calcite/sql/type/NotNullOperandTypeChecker.java diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java index 86a72d2652f..8549e1bb921 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java @@ -1240,7 +1240,7 @@ private static RelDataType arrayCompactReturnType(SqlOperatorBinding opBinding) public static final SqlFunction ARRAY_CONTAINS = SqlBasicFunction.create(SqlKind.ARRAY_CONTAINS, ReturnTypes.BOOLEAN_NULLABLE, - OperandTypes.ARRAY_ELEMENT); + OperandTypes.ARRAY_ELEMENT_NONNULL); /** The "ARRAY_DISTINCT(array)" function. */ @LibraryOperator(libraries = {SPARK}) @@ -1255,6 +1255,7 @@ private static RelDataType arrayCompactReturnType(SqlOperatorBinding opBinding) SqlBasicFunction.create(SqlKind.ARRAY_EXCEPT, ReturnTypes.LEAST_RESTRICTIVE, OperandTypes.and( + OperandTypes.NONNULL_NONNULL, OperandTypes.SAME_SAME, OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY))); @@ -1287,6 +1288,7 @@ private static RelDataType arrayInsertReturnType(SqlOperatorBinding opBinding) { SqlBasicFunction.create(SqlKind.ARRAY_INTERSECT, ReturnTypes.LEAST_RESTRICTIVE, OperandTypes.and( + OperandTypes.NONNULL_NONNULL, OperandTypes.SAME_SAME, OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ARRAY))); diff --git a/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java b/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java index 8291b1681d2..bed38a73e65 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java +++ b/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java @@ -21,6 +21,7 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlOperandCountRange; import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlUtil; import com.google.common.collect.ImmutableList; @@ -31,11 +32,41 @@ * Parameter type-checking strategy where types must be Array and Array element type. */ public class ArrayElementOperandTypeChecker implements SqlOperandTypeChecker { + //~ Instance fields -------------------------------------------------------- + + private final boolean allowNullCheck; + private final boolean allowCast; + + //~ Constructors ----------------------------------------------------------- + + public ArrayElementOperandTypeChecker() { + this.allowNullCheck = false; + this.allowCast = false; + } + + public ArrayElementOperandTypeChecker(boolean allowNullCheck, boolean allowCast) { + this.allowNullCheck = allowNullCheck; + this.allowCast = allowCast; + } + //~ Methods ---------------------------------------------------------------- @Override public boolean checkOperandTypes( SqlCallBinding callBinding, boolean throwOnFailure) { + if (allowNullCheck) { + // no operand can be null for type-checking to succeed + for (SqlNode node : callBinding.operands()) { + if (SqlUtil.isNullLiteral(node, allowCast)) { + if (throwOnFailure) { + throw callBinding.getValidator().newValidationError(node, RESOURCE.nullIllegal()); + } else { + return false; + } + } + } + } + final SqlNode op0 = callBinding.operand(0); if (!OperandTypes.ARRAY.checkSingleOperandType( callBinding, diff --git a/core/src/main/java/org/apache/calcite/sql/type/NotNullOperandTypeChecker.java b/core/src/main/java/org/apache/calcite/sql/type/NotNullOperandTypeChecker.java new file mode 100644 index 00000000000..e6deba1163b --- /dev/null +++ b/core/src/main/java/org/apache/calcite/sql/type/NotNullOperandTypeChecker.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.sql.type; + +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlUtil; + +import static org.apache.calcite.util.Static.RESOURCE; + +/** + * Parameter type-checking strategy where all operand types must not be NULL. + */ +public class NotNullOperandTypeChecker extends SameOperandTypeChecker { + //~ Instance fields -------------------------------------------------------- + + private final boolean allowCast; + + //~ Constructors ----------------------------------------------------------- + + public NotNullOperandTypeChecker(final int nOperands, final boolean allowCast) { + super(nOperands); + this.allowCast = allowCast; + } + + //~ Methods ---------------------------------------------------------------- + + @Override public boolean checkOperandTypes(final SqlCallBinding callBinding, + final boolean throwOnFailure) { + // no operand can be null for type-checking to succeed + for (SqlNode node : callBinding.operands()) { + if (SqlUtil.isNullLiteral(node, allowCast)) { + if (throwOnFailure) { + throw callBinding.getValidator().newValidationError(node, RESOURCE.nullIllegal()); + } else { + return false; + } + } + } + return true; + } +} diff --git a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java index 4cdac722838..18767f5a7d1 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java +++ b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java @@ -614,6 +614,9 @@ public static SqlOperandTypeChecker variadic( public static final SqlOperandTypeChecker ARRAY_ELEMENT = new ArrayElementOperandTypeChecker(); + public static final SqlOperandTypeChecker ARRAY_ELEMENT_NONNULL = + new ArrayElementOperandTypeChecker(true, false); + public static final SqlOperandTypeChecker ARRAY_INSERT = new ArrayInsertOperandTypeChecker(); @@ -638,6 +641,12 @@ public static SqlOperandTypeChecker variadic( public static final SqlSingleOperandTypeChecker LITERAL = new LiteralOperandTypeChecker(false); + /** + * Operand type-checking strategy where all types must be non-NULL value. + */ + public static final SqlSingleOperandTypeChecker NONNULL_NONNULL = + new NotNullOperandTypeChecker(2, false); + /** * Operand type-checking strategy type must be a boolean non-NULL literal. */ diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index 0ca4b26f9e2..8e15b0af049 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -6419,6 +6419,11 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { f.checkType("array_contains(array[1, null], cast(null as integer))", "BOOLEAN"); f.checkFails("^array_contains(array[1, 2], true)^", "INTEGER is not comparable to BOOLEAN", false); + + // check null without cast + f.checkFails("array_contains(array[1, 2], ^null^)", "Illegal use of 'NULL'", false); + f.checkFails("array_contains(^null^, array[1, 2])", "Illegal use of 'NULL'", false); + f.checkFails("array_contains(^null^, null)", "Illegal use of 'NULL'", false); } /** Tests {@code ARRAY_DISTINCT} function from Spark. */ @@ -6779,6 +6784,20 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { f.checkNull("array_except(cast(null as integer array), array[1])"); f.checkNull("array_except(array[1], cast(null as integer array))"); f.checkNull("array_except(cast(null as integer array), cast(null as integer array))"); + + // check null without cast + f.checkFails("^array_except(array[1, 2], null)^", + "Cannot apply 'ARRAY_EXCEPT' to arguments of type 'ARRAY_EXCEPT\\(, " + + "\\)'\\. Supported form\\(s\\): 'ARRAY_EXCEPT\\(, " + + "\\)'", false); + f.checkFails("^array_except(null, array[1, 2])^", + "Cannot apply 'ARRAY_EXCEPT' to arguments of type 'ARRAY_EXCEPT\\(, " + + "\\)'\\. Supported form\\(s\\): 'ARRAY_EXCEPT\\(, " + + "\\)'", false); + f.checkFails("^array_except(null, null)^", + "Cannot apply 'ARRAY_EXCEPT' to arguments of type 'ARRAY_EXCEPT\\(, " + + "\\)'\\. Supported form\\(s\\): 'ARRAY_EXCEPT\\(, " + + "\\)'", false); } /** Tests {@code ARRAY_INSERT} function from Spark. */ @@ -6874,6 +6893,20 @@ void checkRegexpExtract(SqlOperatorFixture f0, FunctionAlias functionAlias) { f.checkNull("array_intersect(cast(null as integer array), array[1])"); f.checkNull("array_intersect(array[1], cast(null as integer array))"); f.checkNull("array_intersect(cast(null as integer array), cast(null as integer array))"); + + // check null without cast + f.checkFails("^array_intersect(array[1, 2], null)^", + "Cannot apply 'ARRAY_INTERSECT' to arguments of type 'ARRAY_INTERSECT\\(, " + + "\\)'\\. Supported form\\(s\\): 'ARRAY_INTERSECT\\(, " + + "\\)'", false); + f.checkFails("^array_intersect(null, array[1, 2])^", + "Cannot apply 'ARRAY_INTERSECT' to arguments of type 'ARRAY_INTERSECT\\(, " + + "\\)'\\. Supported form\\(s\\): 'ARRAY_INTERSECT\\(, " + + "\\)'", false); + f.checkFails("^array_intersect(null, null)^", + "Cannot apply 'ARRAY_INTERSECT' to arguments of type 'ARRAY_INTERSECT\\(, " + + "\\)'\\. Supported form\\(s\\): 'ARRAY_INTERSECT\\(, " + + "\\)'", false); } /** Tests {@code ARRAY_UNION} function from Spark. */