Skip to content

Commit

Permalink
[CALCITE-6127] The spark array function gives NullPointerException wh…
Browse files Browse the repository at this point in the history
…en element is row type
  • Loading branch information
chucheng92 committed Dec 18, 2023
1 parent 08f6856 commit e066266
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1061,23 +1061,29 @@ static RelDataType deriveTypeSplit(SqlOperatorBinding operatorBinding,
private static RelDataType arrayReturnType(SqlOperatorBinding opBinding) {
final List<RelDataType> operandTypes = opBinding.collectOperandTypes();

// only numeric & character types check
// only numeric & character types check, this is a special spark array case
// the form like ARRAY(1, 2, '3') will return ["1", "2", "3"]
boolean hasNumeric = false;
boolean hasCharacter = false;
boolean hasOthers = false;
for (RelDataType type : operandTypes) {
SqlTypeFamily family = type.getSqlTypeName().getFamily();
requireNonNull(family, "array element type family");
// some types such as Row, the family is null, fallback to normal inferred type logic
if (family == null) {
hasOthers = true;
break;
}
// skip it because we allow NULL literal
if (SqlTypeUtil.isNull(type)) {
continue;
}
switch (family) {
case NUMERIC:
hasNumeric = true;
break;
case CHARACTER:
hasCharacter = true;
break;
case NULL:
// skip it becase we allow null
break;
default:
hasOthers = true;
break;
Expand Down Expand Up @@ -1113,7 +1119,7 @@ private static RelDataType arrayReturnType(SqlOperatorBinding opBinding) {
public static final SqlFunction ARRAY =
SqlBasicFunction.create("ARRAY",
SqlLibraryOperators::arrayReturnType,
OperandTypes.SAME_VARIADIC,
OperandTypes.ARRAY_FUNCTION,
SqlFunctionCategory.SYSTEM);

private static RelDataType mapReturnType(SqlOperatorBinding opBinding) {
Expand Down
56 changes: 56 additions & 0 deletions core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.calcite.util.ImmutableIntList;
Expand Down Expand Up @@ -560,6 +561,9 @@ public static SqlOperandTypeChecker variadic(
public static final SqlSingleOperandTypeChecker MAP =
family(SqlTypeFamily.MAP);

/**
 * Operand type checker for the ARRAY function, backed by
 * {@code ArrayFunctionOperandTypeChecker}; an empty argument list
 * (empty array) is accepted.
 */
public static final SqlOperandTypeChecker ARRAY_FUNCTION =
new ArrayFunctionOperandTypeChecker();

// Checker backed by ArrayElementOperandTypeChecker.
// NOTE(review): that class is not visible in this excerpt, so its exact
// rules are not described here - confirm against its definition.
public static final SqlOperandTypeChecker ARRAY_ELEMENT =
new ArrayElementOperandTypeChecker();

Expand Down Expand Up @@ -1225,6 +1229,58 @@ private static class MapFromEntriesOperandTypeChecker
}
}

/**
 * Operand type-checking strategy for an ARRAY function; it allows an empty array.
 *
 * <p>Each operand is checked against its predecessor: the pair is accepted
 * when {@code SqlTypeUtil.leastRestrictiveForComparison} returns a non-null
 * type for it.
 */
private static class ArrayFunctionOperandTypeChecker
    extends SameOperandTypeChecker {

  ArrayFunctionOperandTypeChecker() {
    // ARRAY takes a variable number of arguments, so -1 is passed as the
    // operand count; the real count is then taken from the call being checked.
    // For details see SameOperandTypeChecker#getOperandList.
    super(-1);
  }

  /**
   * Checks that every operand is compatible with the operand before it.
   *
   * @param operatorBinding binding that supplies operand count, operand types
   *     and the type factory
   * @param throwOnFailure whether to throw a validation error instead of
   *     returning {@code false}
   * @param callBinding call binding used to build the validation error; must
   *     be non-null when {@code throwOnFailure} is true
   * @return {@code true} if all adjacent operand pairs are compatible,
   *     {@code false} if some pair is not and {@code throwOnFailure} is false
   * @throws IllegalArgumentException if {@code throwOnFailure} is true and
   *     {@code callBinding} is null
   */
  @Override protected boolean checkOperandTypesImpl(
      SqlOperatorBinding operatorBinding,
      boolean throwOnFailure,
      @Nullable SqlCallBinding callBinding) {
    if (callBinding == null && throwOnFailure) {
      throw new IllegalArgumentException(
          "callBinding must be non-null in case throwOnFailure=true");
    }

    final int operandCount = operatorBinding.getOperandCount();
    // -1 means "not fixed": size the array by the number of actual operands.
    final RelDataType[] operandTypes =
        new RelDataType[nOperands == -1 ? operandCount : nOperands];
    final List<Integer> ordinals = getOperandList(operandCount);

    // First pass: collect the type of every operand under inspection.
    for (int ordinal : ordinals) {
      operandTypes[ordinal] = operatorBinding.getOperandType(ordinal);
    }

    // Second pass: compare each operand with the one just before it.
    for (int ordinal : ordinals) {
      if (ordinal <= 0) {
        // The first operand has no predecessor to compare against.
        continue;
      }
      // SqlTypeUtil.leastRestrictiveForComparison is used rather than
      // SqlTypeUtil.isComparable so that struct types and the NULL constant
      // are handled.
      // Details: https://issues.apache.org/jira/browse/CALCITE-6163
      final RelDataType commonType =
          SqlTypeUtil.leastRestrictiveForComparison(
              operatorBinding.getTypeFactory(),
              operandTypes[ordinal],
              operandTypes[ordinal - 1]);
      if (commonType != null) {
        continue;
      }
      if (throwOnFailure) {
        throw requireNonNull(callBinding, "callBinding").newValidationError(
            RESOURCE.needSameTypeParameter());
      }
      return false;
    }
    return true;
  }
}

/**
* Operand type-checking strategy for a MAP function, it allows empty map.
*/
Expand Down
28 changes: 28 additions & 0 deletions testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -10537,6 +10537,34 @@ private static void checkArrayConcatAggFuncFails(SqlOperatorFixture t) {
"[null, foo]", "CHAR(3) ARRAY NOT NULL");
f2.checkScalar("array(null)",
"[null]", "NULL ARRAY NOT NULL");
// check complex type
f2.checkScalar("array(row(1))", "[{1}]",
"RecordType(INTEGER NOT NULL EXPR$0) NOT NULL ARRAY NOT NULL");
f2.checkScalar("array(row(1, null))", "[{1, null}]",
"RecordType(INTEGER NOT NULL EXPR$0, NULL EXPR$1) NOT NULL ARRAY NOT NULL");
f2.checkScalar("array(row(null, 1))", "[{null, 1}]",
"RecordType(NULL EXPR$0, INTEGER NOT NULL EXPR$1) NOT NULL ARRAY NOT NULL");
f2.checkScalar("array(row(1, 2))", "[{1, 2}]",
"RecordType(INTEGER NOT NULL EXPR$0, INTEGER NOT NULL EXPR$1) NOT NULL ARRAY NOT NULL");
f2.checkScalar("array(row(1, 2), null)",
"[{1, 2}, null]", "RecordType(INTEGER EXPR$0, INTEGER EXPR$1) ARRAY NOT NULL");
f2.checkScalar("array(null, row(1, 2))",
"[null, {1, 2}]", "RecordType(INTEGER EXPR$0, INTEGER EXPR$1) ARRAY NOT NULL");
f2.checkScalar("array(row(1, null), row(2, null))", "[{1, null}, {2, null}]",
"RecordType(INTEGER NOT NULL EXPR$0, NULL EXPR$1) NOT NULL ARRAY NOT NULL");
f2.checkScalar("array(row(null, 1), row(null, 2))", "[{null, 1}, {null, 2}]",
"RecordType(NULL EXPR$0, INTEGER NOT NULL EXPR$1) NOT NULL ARRAY NOT NULL");
f2.checkScalar("array(row(1, null), row(null, 2))", "[{1, null}, {null, 2}]",
"RecordType(INTEGER EXPR$0, INTEGER EXPR$1) NOT NULL ARRAY NOT NULL");
f2.checkScalar("array(row(null, 1), row(2, null))", "[{null, 1}, {2, null}]",
"RecordType(INTEGER EXPR$0, INTEGER EXPR$1) NOT NULL ARRAY NOT NULL");
f2.checkScalar("array(row(1, 2), row(3, 4))", "[{1, 2}, {3, 4}]",
"RecordType(INTEGER NOT NULL EXPR$0, INTEGER NOT NULL EXPR$1) NOT NULL ARRAY NOT NULL");
// checkFails
f2.checkFails("^array(row(1), row(2, 3))^",
"Parameters must be of the same type", false);
f2.checkFails("^array(row(1), row(2, 3), null)^",
"Parameters must be of the same type", false);
// calcite default cast char type will fill extra spaces
f2.checkScalar("array(1, 2, 'Hi')",
"[1 , 2 , Hi]", "CHAR(2) NOT NULL ARRAY NOT NULL");
Expand Down

0 comments on commit e066266

Please sign in to comment.