Skip to content

Commit

Permalink
[CALCITE-6041] MAP sub-query gives NullPointerException
Browse files Browse the repository at this point in the history
  • Loading branch information
chucheng92 committed Oct 18, 2023
1 parent c6031ca commit d894e59
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
Expand Down Expand Up @@ -88,39 +89,73 @@ public static Collect create(RelNode input, RelDataType rowType) {
getRowType(),
JavaRowFormat.LIST);

final SqlTypeName collectionType = getCollectionType();

// final Enumerable child = <<child adapter>>;
// final Enumerable<Object[]> converted = child.select(<<conversion code>>);
// final List<Object[]> list = converted.toList();
// if collectionType is ARRAY or MULTISET: final List<Object[]> list = converted.toList();
// if collectionType is MAP: final Map<Object, Object> map = converted.toMap();
Expression child_ =
builder.append(
"child", result.block);

RelDataType collectionComponentType =
requireNonNull(rowType().getFieldList().get(0).getType().getComponentType());
RelDataType childRecordType = result.physType.getRowType().getFieldList().get(0).getType();

Expression conv_ = child_;
if (!SqlTypeUtil.sameNamedType(collectionComponentType, childRecordType)) {
// In the internal representation of multisets , every element must be a record. In case the
// result above is a scalar type we have to wrap it around a physical type capable of
// representing records. For this reason the following conversion is necessary.
// REVIEW zabetak January 7, 2019: If we can ensure that the input to this operator
// has the correct physical type (e.g., respecting the Prefer.ARRAY above)
// then this conversion can be removed.
conv_ =
builder.append(
"converted", result.physType.convertTo(child_, JavaRowFormat.ARRAY));
}
Expression collectionExpr;
switch (collectionType) {
case ARRAY:
case MULTISET:
RelDataType collectionComponentType =
requireNonNull(rowType().getFieldList().get(0).getType().getComponentType());
RelDataType childRecordType = result.physType.getRowType().getFieldList().get(0).getType();

if (!SqlTypeUtil.sameNamedType(collectionComponentType, childRecordType)) {
// In the internal representation of multisets , every element must be a record. In case the
// result above is a scalar type we have to wrap it around a physical type capable of
// representing records. For this reason the following conversion is necessary.
// REVIEW zabetak January 7, 2019: If we can ensure that the input to this operator
// has the correct physical type (e.g., respecting the Prefer.ARRAY above)
// then this conversion can be removed.
conv_ =
builder.append(
"converted", result.physType.convertTo(child_, JavaRowFormat.ARRAY));
}

collectionExpr =
builder.append("list",
Expressions.call(conv_,
BuiltInMethod.ENUMERABLE_TO_LIST.method));
break;
case MAP:
// Convert input 'Object[]' to MAP data, we don't specify a comparator, just
// keep the original order of this map. (the inner map is a LinkedHashMap)
ParameterExpression input = Expressions.parameter(Object.class, "input");

Expression list_ =
builder.append("list",
Expressions.call(conv_,
BuiltInMethod.ENUMERABLE_TO_LIST.method));
// keySelector lambda: input -> ((Object[])input)[0]
Expression keySelector =
Expressions.lambda(
Expressions.arrayIndex(Expressions.convert_(input, Object[].class),
Expressions.constant(0)), input);

// valueSelector lambda: input -> ((Object[])input)[1]
Expression valueSelector =
Expressions.lambda(
Expressions.arrayIndex(Expressions.convert_(input, Object[].class),
Expressions.constant(1)), input);

collectionExpr =
builder.append("map",
Expressions.call(conv_,
BuiltInMethod.ENUMERABLE_TO_MAP.method, keySelector, valueSelector));
break;
default:
throw new IllegalArgumentException("unknown collection type " + collectionType);
}

builder.add(
Expressions.return_(null,
Expressions.call(
BuiltInMethod.SINGLETON_ENUMERABLE.method, list_)));
BuiltInMethod.SINGLETON_ENUMERABLE.method, collectionExpr)));

return implementor.result(physType, builder.toBlock());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ public static Collect create(RelNode input,
RelDataType rowType;
switch (sqlKind) {
case ARRAY_QUERY_CONSTRUCTOR:
case MAP_QUERY_CONSTRUCTOR:
case MULTISET_QUERY_CONSTRUCTOR:
rowType =
deriveRowType(input.getCluster().getTypeFactory(), collectionType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,20 @@
*/
package org.apache.calcite.sql.fun;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeTransforms;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.List;

import static org.apache.calcite.util.Static.RESOURCE;

/**
* Definition of the MAP query constructor, <code>
Expand All @@ -29,6 +41,43 @@ public class SqlMapQueryConstructor extends SqlMultisetQueryConstructor {
//~ Constructors -----------------------------------------------------------

public SqlMapQueryConstructor() {
super("MAP", SqlKind.MAP_QUERY_CONSTRUCTOR, SqlTypeTransforms.TO_MAP);
super("MAP", SqlKind.MAP_QUERY_CONSTRUCTOR, SqlTypeTransforms.TO_MAP_QUERY);
}

@Override public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
final List<RelDataType> argTypes = SqlTypeUtil.deriveType(callBinding, callBinding.operands());
if (argTypes.isEmpty()) {
throw callBinding.newValidationError(RESOURCE.mapRequiresTwoOrMoreArgs());
}
if (argTypes.size() % 2 != 0) {
throw callBinding.newValidationError(RESOURCE.mapRequiresEvenArgCount());
}
final Pair<@Nullable RelDataType, @Nullable RelDataType> componentType =
getComponentTypes(
callBinding.getTypeFactory(), argTypes);
if (null == componentType.left || null == componentType.right) {
if (throwOnFailure) {
throw callBinding.newValidationError(RESOURCE.needSameTypeParameter());
}
return false;
}
return true;
}

/**
* Extract the key type and value type of arg types.
*/
private static Pair<@Nullable RelDataType, @Nullable RelDataType> getComponentTypes(
RelDataTypeFactory typeFactory,
List<RelDataType> argTypes) {
// Util.quotientList(argTypes, 2, 0):
// This extracts all elements at even indices from argTypes.
// It represents the types of keys in the map as they are placed at even positions
// e.g. 0, 2, 4, etc.
// Symmetrically, Util.quotientList(argTypes, 2, 1) represents odd-indexed elements.
// details please see Util.quotientList.
return Pair.of(
typeFactory.leastRestrictive(Util.quotientList(argTypes, 2, 0)),
typeFactory.leastRestrictive(Util.quotientList(argTypes, 2, 1)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,15 @@ private SqlTypeName toVar(RelDataType type) {
SqlTypeUtil.createMapTypeFromRecord(opBinding.getTypeFactory(),
typeToTransform);

/**
* Parameter type-inference transform strategy that wraps a given type in a map or
* wraps a field of the given type in a map. It is used when a map input is a sub-query.
*/
public static final SqlTypeTransform TO_MAP_QUERY =
(opBinding, typeToTransform) ->
TO_MAP.transformType(opBinding,
SqlTypeUtil.deriveCollectionQueryComponentType(SqlTypeName.MAP, typeToTransform));

/**
* Parameter type-inference transform strategy that converts a type to a MAP type,
* which key and value type is same.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ public enum BuiltInMethod {
AS_ENUMERABLE(Linq4j.class, "asEnumerable", Object[].class),
AS_ENUMERABLE2(Linq4j.class, "asEnumerable", Iterable.class),
ENUMERABLE_TO_LIST(ExtendedEnumerable.class, "toList"),
ENUMERABLE_TO_MAP(ExtendedEnumerable.class, "toMap", Function1.class, Function1.class),
AS_LIST(Primitive.class, "asList", Object.class),
MEMORY_GET0(MemoryFactory.Memory.class, "get"),
MEMORY_GET1(MemoryFactory.Memory.class, "get", int.class),
Expand Down
11 changes: 11 additions & 0 deletions core/src/test/resources/sql/sub-query.iq
Original file line number Diff line number Diff line change
Expand Up @@ -3685,4 +3685,15 @@ FROM dept d1;

!ok

# [CALCITE-6041] MAP sub-query gives NullPointerException
SELECT map(SELECT empno, deptno from emp where deptno < 20);
+-----------------------------+
| EXPR$0 |
+-----------------------------+
| {7782=10, 7839=10, 7934=10} |
+-----------------------------+
(1 row)

!ok

# End sub-query.iq
Original file line number Diff line number Diff line change
Expand Up @@ -6146,6 +6146,27 @@ private static Matcher<SqlNode> isCharLiteral(String s) {
.ok("(MAP[])");
}

@Test void testMapQueryConstructor() {
// parser allows odd elements; validator will reject it
sql("SELECT map(SELECT 1)")
.ok("SELECT (MAP ((SELECT 1)))");
sql("SELECT map(SELECT 1, 2)")
.ok("SELECT (MAP ((SELECT 1, 2)))");
// with upper case
sql("SELECT MAP(SELECT 1, 2)")
.ok("SELECT (MAP ((SELECT 1, 2)))");
// with space
sql("SELECT map (SELECT 1, 2)")
.ok("SELECT (MAP ((SELECT 1, 2)))");
sql("SELECT map(SELECT T.x, T.y FROM (VALUES(1, 2)) AS T(x, y))")
.ok("SELECT (MAP ((SELECT `T`.`X`, `T`.`Y`\n"
+ "FROM (VALUES (ROW(1, 2))) AS `T` (`X`, `Y`))))");
sql("SELECT map(^1^, SELECT x FROM (VALUES(1)) x)")
.fails("(?s)Non-query expression encountered in illegal context");
sql("SELECT map(SELECT x FROM (VALUES(1)) x, ^SELECT^ x FROM (VALUES(1)) x)")
.fails("(?s)Incorrect syntax near the keyword 'SELECT' at.*");
}

@Test void testVisitSqlInsertWithSqlShuttle() {
final String sql = "insert into emps select * from emps";
final SqlNode sqlNode = sql(sql).node();
Expand Down
34 changes: 34 additions & 0 deletions testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -10446,6 +10446,40 @@ private static void checkArrayConcatAggFuncFails(SqlOperatorFixture t) {
"{1=1, 2=2}", "(BIGINT NOT NULL, SMALLINT NOT NULL) MAP NOT NULL");
}

@Test void testMapQueryConstructor() {
final SqlOperatorFixture f = Fixtures.forOperators(true);
f.setFor(SqlStdOperatorTable.MAP_QUERY, VmName.EXPAND);
// must be 2 fields
f.checkFails("map(select 1)", "MAP requires exactly two fields, got 1; "
+ "row type RecordType\\(INTEGER EXPR\\$0\\)", false);
f.checkFails("map(select 1, 2, 3)", "MAP requires exactly two fields, got 3; "
+ "row type RecordType\\(INTEGER EXPR\\$0, INTEGER EXPR\\$1, "
+ "INTEGER EXPR\\$2\\)", false);
f.checkFails("map(select 1, 'x', 2, 'x')", "MAP requires exactly two fields, got 4; "
+ "row type RecordType\\(INTEGER EXPR\\$0, CHAR\\(1\\) EXPR\\$1, INTEGER EXPR\\$2, "
+ "CHAR\\(1\\) EXPR\\$3\\)", false);
f.checkScalar("map(select 1, 2)", "{1=2}",
"(INTEGER NOT NULL, INTEGER NOT NULL) MAP NOT NULL");
f.checkScalar("map(select 1, 2.0)", "{1=2.0}",
"(INTEGER NOT NULL, DECIMAL(2, 1) NOT NULL) MAP NOT NULL");
f.checkScalar("map(select 1, true)", "{1=true}",
"(INTEGER NOT NULL, BOOLEAN NOT NULL) MAP NOT NULL");
f.checkScalar("map(select 'x', 1)", "{x=1}",
"(CHAR(1) NOT NULL, INTEGER NOT NULL) MAP NOT NULL");
// element cast
f.checkScalar("map(select cast(1 as bigint), 2)", "{1=2}",
"(BIGINT NOT NULL, INTEGER NOT NULL) MAP NOT NULL");
f.checkScalar("map(select 1, cast(2 as varchar))", "{1=2}",
"(INTEGER NOT NULL, VARCHAR NOT NULL) MAP NOT NULL");
// null key or value
f.checkScalar("map(select 1, null)", "{1=null}",
"(INTEGER NOT NULL, NULL) MAP NOT NULL");
f.checkScalar("map(select null, 1)", "{null=1}",
"(NULL, INTEGER NOT NULL) MAP NOT NULL");
f.checkScalar("map(select null, null)", "{null=null}",
"(NULL, NULL) MAP NOT NULL");
}

@Test void testCeilFunc() {
final SqlOperatorFixture f = fixture();
f.setFor(SqlStdOperatorTable.CEIL, VM_FENNEL);
Expand Down

0 comments on commit d894e59

Please sign in to comment.