From eb92f96a747b3ba1301c7ee8499392260e65faa0 Mon Sep 17 00:00:00 2001 From: caicancai <2356672992@qq.com> Date: Fri, 3 May 2024 21:16:13 +0800 Subject: [PATCH] [CALCITE-6397] Add NVL2 function (enabled in Oracle, Spark library) --- .../java/org/apache/calcite/sql/SqlKind.java | 7 +++- .../calcite/sql/fun/SqlLibraryOperators.java | 8 ++++ .../apache/calcite/sql/type/OperandTypes.java | 13 ++++++ .../apache/calcite/sql/type/ReturnTypes.java | 10 +++++ .../sql2rel/StandardConvertletTable.java | 24 +++++++++++ site/_docs/reference.md | 1 + .../apache/calcite/test/SqlOperatorTest.java | 41 +++++++++++++++++++ 7 files changed, 102 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java b/core/src/main/java/org/apache/calcite/sql/SqlKind.java index ae44f495a98..30fc858e288 100644 --- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java +++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java @@ -421,10 +421,13 @@ public enum SqlKind { /** {@code DECODE} function (Oracle). */ DECODE, - /** {@code NVL} function (Oracle). */ + /** {@code NVL} function (Oracle, Spark). */ NVL, - /** {@code GREATEST} function (Oracle). */ + /** {@code NVL2} function (Oracle, Spark). */ + NVL2, + + /** {@code GREATEST} function (Oracle, Spark). */ GREATEST, /** The two-argument {@code CONCAT} function (Oracle). */ diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java index bf75070e484..efff32c25d6 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java @@ -271,6 +271,14 @@ private static SqlCall transformConvert(SqlValidator validator, SqlCall call) { .andThen(SqlTypeTransforms.TO_NULLABLE_ALL), OperandTypes.SAME_SAME); + /** The "NVL2(value, value, value)" function. */ + @LibraryOperator(libraries = {ORACLE, SPARK}) + public static final SqlBasicFunction NVL2 = + SqlBasicFunction.create(SqlKind.NVL2, + ReturnTypes.NVL2_RESTRICTIVE + .andThen(SqlTypeTransforms.TO_NULLABLE_ALL), + OperandTypes.SECOND_THIRD_SAME); + /** The "IFNULL(value, value)" function. */ @LibraryOperator(libraries = {BIG_QUERY, SPARK}) public static final SqlFunction IFNULL = NVL.withName("IFNULL"); diff --git a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java index c7a85cd66e6..08b63011542 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java +++ b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java @@ -986,6 +986,19 @@ public static SqlSingleOperandTypeChecker same(int operandCount, public static final SqlSingleOperandTypeChecker ANY_ANY = family(SqlTypeFamily.ANY, SqlTypeFamily.ANY); + + /** + * Operand type-checking strategy where the second and third operands must be comparable. + * This is used when the operator has three operands and only the + * second and third operands need to be comparable. + */ + public static final SqlSingleOperandTypeChecker SECOND_THIRD_SAME = + new SameOperandTypeChecker(3) { + @Override protected List getOperandList(int operandCount) { + // Only check the second and third operands + return ImmutableList.of(1, 2); + } + }; public static final SqlSingleOperandTypeChecker ANY_IGNORE = family(SqlTypeFamily.ANY, SqlTypeFamily.IGNORE); public static final SqlSingleOperandTypeChecker IGNORE_ANY = diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java index eb7f9e44717..511fe2eff7b 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java +++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java @@ -35,6 +35,7 @@ import org.apache.calcite.util.Util; import java.util.AbstractList; +import java.util.Arrays; import java.util.List; import java.util.function.UnaryOperator; @@ -551,6 +552,15 @@ public static SqlCall stripSeparator(SqlCall call) { opBinding -> opBinding.getTypeFactory().leastRestrictive( opBinding.collectOperandTypes()); + /** + * Type-inference strategy for NVL2 function. It returns the least restrictive type + * between the second and third operands. + */ + public static final SqlReturnTypeInference NVL2_RESTRICTIVE = opBinding -> { + return opBinding.getTypeFactory().leastRestrictive( + Arrays.asList(opBinding.getOperandType(1), opBinding.getOperandType(2))); + }; + /** * Type-inference strategy that returns the type of the first operand, unless it * is an integer type, in which case the return type is DOUBLE. diff --git a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java index f3360d2e31c..82e1ceeaeae 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java @@ -232,6 +232,7 @@ private StandardConvertletTable() { registerOp(operator, StandardConvertletTable::convertQuantifyOperator)); registerOp(SqlLibraryOperators.NVL, StandardConvertletTable::convertNvl); + registerOp(SqlLibraryOperators.NVL2, StandardConvertletTable::convertNvl2); registerOp(SqlLibraryOperators.DECODE, StandardConvertletTable::convertDecode); registerOp(SqlLibraryOperators.IF, StandardConvertletTable::convertIf); @@ -421,6 +422,29 @@ private static RexNode convertNvl(SqlRexContext cx, SqlCall call) { operand1))); } + /** Converts a call to the {@code NVL2} function. */ + private static RexNode convertNvl2(SqlRexContext cx, SqlCall call) { + final RexBuilder rexBuilder = cx.getRexBuilder(); + final List operands = + convertOperands(cx, call, call.getOperandList(), SqlOperandTypeChecker.Consistency.NONE); + final RelDataType type = cx.getValidator().getValidatedNodeType(call); + + // Create a CASE expression equivalent to the NVL2 function + // NVL2(x, y, z) is equivalent to CASE WHEN x IS NOT NULL THEN y ELSE z END + return rexBuilder.makeCall(type, SqlStdOperatorTable.CASE, + ImmutableList.of( + rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, + operands.get(0)), + rexBuilder.makeCast( + cx.getTypeFactory() + .createTypeWithNullability(type, operands.get(1).getType().isNullable()), + operands.get(1)), + rexBuilder.makeCast( + cx.getTypeFactory() + .createTypeWithNullability(type, operands.get(2).getType().isNullable()), + operands.get(2)))); + } + /** Converts a call to the INSTR function. * INSTR(string, substring, position, occurrence) is equivalent to * POSITION(substring, string, position, occurrence) */ diff --git a/site/_docs/reference.md b/site/_docs/reference.md index 8a23f9338c2..4faca2f7e67 100644 --- a/site/_docs/reference.md +++ b/site/_docs/reference.md @@ -2812,6 +2812,7 @@ In the following: | b m p s | MD5(string) | Calculates an MD5 128-bit checksum of *string* and returns it as a hex string | m | MONTHNAME(date) | Returns the name, in the connection's locale, of the month in *datetime*; for example, it returns '二月' for both DATE '2020-02-10' and TIMESTAMP '2020-02-10 10:10:10' | o s | NVL(value1, value2) | Returns *value1* if *value1* is not null, otherwise *value2* +| o s | NVL2(value1, value2, value3) | Returns *value2* if *value1* is not null, otherwise *value3* | b | OFFSET(index) | When indexing an array, wrapping *index* in `OFFSET` returns the value at the 0-based *index*; throws error if *index* is out of bounds | b | ORDINAL(index) | Similar to `OFFSET` except *index* begins at 1 | b | PARSE_DATE(format, string) | Uses format specified by *format* to convert *string* representation of date to a DATE value diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index 2ac1c50744b..daf139a75c3 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -10705,6 +10705,47 @@ void assertSubFunReturns(boolean binary, String s, int start, checkNvl(f, FunctionAlias.of(SqlLibraryOperators.NVL)); } + /** Test case for + * [CALCITE-6397] + * Add NVL2 function (enabled in Oracle, Spark library) . + */ + @Test void testNvl2Func() { + final SqlOperatorFixture f = fixture(); + f.setFor(SqlLibraryOperators.NVL2, VmName.EXPAND); + f.checkFails("^nvl2(NULL, 2, 1)^", + "No match found for function signature " + + "NVL2\\(, , \\)", false); + + final Consumer consumer = f12 -> { + f12.checkScalar("nvl2(NULL, 2, 1)", "1", "INTEGER NOT NULL"); + f12.checkScalar("nvl2(true, true, false)", true, "BOOLEAN NOT NULL"); + f12.checkScalar("nvl2(false, true, false)", true, "BOOLEAN NOT NULL"); + f12.checkScalar("nvl2(NULL, true, false)", false, "BOOLEAN NOT NULL"); + f12.checkScalar("nvl2(3, 2, 1)", "2", "INTEGER NOT NULL"); + f12.checkScalar("nvl2(3, 'a', 'b')", "a", "CHAR(1) NOT NULL"); + f12.checkScalar("nvl2(NULL, 'a', 'b')", "b", "CHAR(1) NOT NULL"); + f12.checkScalar("nvl2(NULL, 'ab', 'de')", "de", "CHAR(2) NOT NULL"); + f12.checkScalar("nvl2('ab', 'abc', 'def')", "abc", "CHAR(3) NOT NULL"); + f12.checkScalar("nvl2('a', 3, 2)", "3", "INTEGER NOT NULL"); + f12.checkScalar("NVL2(NULL, 3.0, 4.0)", "4.0", "DECIMAL(2, 1) NOT NULL"); + f12.checkScalar("NVL2('abc', 3.0, 4.0)", "3.0", "DECIMAL(2, 1) NOT NULL"); + f12.checkScalar("NVL2(1, 3.0, 2.111)", "3.0", "DECIMAL(4, 3) NOT NULL"); + f12.checkScalar("NVL2(NULL, 3.0, 2.111)", "2.111", "DECIMAL(4, 3) NOT NULL"); + f12.checkScalar("NVL2(3.111, 3.1415926, 2.111)", "3.1415926", "DECIMAL(8, 7) NOT NULL"); + + f12.checkNull("nvl2('ab', CAST(NULL AS VARCHAR(6)), 'def')"); + f12.checkNull("nvl2(NULL, 'abc', NULL)"); + f12.checkNull("nvl2(NULL, NULL, NULL)"); + + f12.checkFails("^NVL2(2.0, 2.0, true)^", "Parameters must be of the same type", false); + f12.checkFails("^NVL2(NULL, 2.0, true)^", "Parameters must be of the same type", false); + f12.checkFails("^NVL2(2.0, 1, true)^", "Parameters must be of the same type", false); + f12.checkFails("^NVL2(NULL, 1, true)^", "Parameters must be of the same type", false); + }; + f.forEachLibrary(list(SqlLibrary.ORACLE, SqlLibrary.SPARK), consumer); + + } + /** Tests the {@code NVL} and {@code IFNULL} operators. */ void checkNvl(SqlOperatorFixture f0, FunctionAlias functionAlias) { final SqlFunction function = functionAlias.function;