diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableCollect.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableCollect.java index 64ead3316708..3d5608e9dea7 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableCollect.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableCollect.java @@ -19,6 +19,7 @@ import org.apache.calcite.linq4j.tree.BlockBuilder; import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.linq4j.tree.ParameterExpression; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelNode; @@ -88,39 +89,71 @@ public static Collect create(RelNode input, RelDataType rowType) { getRowType(), JavaRowFormat.LIST); + final SqlTypeName collectionType = getCollectionType(); + // final Enumerable child = <>; // final Enumerable converted = child.select(<>); - // final List list = converted.toList(); + // if collectionType is ARRAY or MULTISET: final List list = converted.toList(); + // if collectionType is MAP: final Map map = converted.toMap(); Expression child_ = builder.append( "child", result.block); - RelDataType collectionComponentType = - requireNonNull(rowType().getFieldList().get(0).getType().getComponentType()); - RelDataType childRecordType = result.physType.getRowType().getFieldList().get(0).getType(); - Expression conv_ = child_; - if (!SqlTypeUtil.sameNamedType(collectionComponentType, childRecordType)) { - // In the internal representation of multisets , every element must be a record. In case the - // result above is a scalar type we have to wrap it around a physical type capable of - // representing records. For this reason the following conversion is necessary. 
- // REVIEW zabetak January 7, 2019: If we can ensure that the input to this operator - // has the correct physical type (e.g., respecting the Prefer.ARRAY above) - // then this conversion can be removed. - conv_ = - builder.append( - "converted", result.physType.convertTo(child_, JavaRowFormat.ARRAY)); - } + Expression collectionExpr; + switch (collectionType) { + case ARRAY: + case MULTISET: + RelDataType collectionComponentType = + requireNonNull(rowType().getFieldList().get(0).getType().getComponentType()); + RelDataType childRecordType = result.physType.getRowType().getFieldList().get(0).getType(); + + if (!SqlTypeUtil.sameNamedType(collectionComponentType, childRecordType)) { + // In the internal representation of multisets, every element must be a record. In case the + // result above is a scalar type we have to wrap it around a physical type capable of + // representing records. For this reason the following conversion is necessary. + // REVIEW zabetak January 7, 2019: If we can ensure that the input to this operator + // has the correct physical type (e.g., respecting the Prefer.ARRAY above) + // then this conversion can be removed. + conv_ = + builder.append( + "converted", result.physType.convertTo(child_, JavaRowFormat.ARRAY)); + } - Expression list_ = - builder.append("list", - Expressions.call(conv_, - BuiltInMethod.ENUMERABLE_TO_LIST.method)); + collectionExpr = + builder.append("list", + Expressions.call(conv_, + BuiltInMethod.ENUMERABLE_TO_LIST.method)); + break; + case MAP: + // Convert input 'Object[]' to MAP data; we don't specify a comparator, so we just + // keep the original order. 
+ ParameterExpression input = Expressions.parameter(Object.class, "input"); + + // keySelector lambda: input -> ((Object[])input)[0] + Expression keySelector = + Expressions.lambda( + Expressions.arrayIndex(Expressions.convert_(input, Object[].class), + Expressions.constant(0)), input); + // valueSelector lambda: input -> ((Object[])input)[1] + Expression valueSelector = + Expressions.lambda( + Expressions.arrayIndex(Expressions.convert_(input, Object[].class), + Expressions.constant(1)), input); + collectionExpr = + builder.append("map", + Expressions.call(conv_, + BuiltInMethod.ENUMERABLE_TO_MAP.method, keySelector, valueSelector)); + break; + default: + throw new IllegalArgumentException("unknown collection type " + collectionType); + } builder.add( Expressions.return_(null, Expressions.call( - BuiltInMethod.SINGLETON_ENUMERABLE.method, list_))); + BuiltInMethod.SINGLETON_ENUMERABLE.method, collectionExpr))); + return implementor.result(physType, builder.toBlock()); } } diff --git a/core/src/main/java/org/apache/calcite/rel/core/Collect.java b/core/src/main/java/org/apache/calcite/rel/core/Collect.java index 6d5d06a2cf14..88e2cc7927e9 100644 --- a/core/src/main/java/org/apache/calcite/rel/core/Collect.java +++ b/core/src/main/java/org/apache/calcite/rel/core/Collect.java @@ -148,6 +148,7 @@ public static Collect create(RelNode input, RelDataType rowType; switch (sqlKind) { case ARRAY_QUERY_CONSTRUCTOR: + case MAP_QUERY_CONSTRUCTOR: case MULTISET_QUERY_CONSTRUCTOR: rowType = deriveRowType(input.getCluster().getTypeFactory(), collectionType, diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlMapQueryConstructor.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlMapQueryConstructor.java index bb8840889d11..97bc5b280682 100644 --- a/core/src/main/java/org/apache/calcite/sql/fun/SqlMapQueryConstructor.java +++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlMapQueryConstructor.java @@ -16,8 +16,20 @@ */ package org.apache.calcite.sql.fun; +import 
org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlCallBinding; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.SqlTypeTransforms; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.Util; + +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.util.List; + +import static org.apache.calcite.util.Static.RESOURCE; /** * Definition of the MAP query constructor, @@ -29,6 +41,43 @@ public class SqlMapQueryConstructor extends SqlMultisetQueryConstructor { //~ Constructors ----------------------------------------------------------- public SqlMapQueryConstructor() { - super("MAP", SqlKind.MAP_QUERY_CONSTRUCTOR, SqlTypeTransforms.TO_MAP); + super("MAP", SqlKind.MAP_QUERY_CONSTRUCTOR, SqlTypeTransforms.TO_MAP_QUERY); + } + + @Override public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { + final List argTypes = SqlTypeUtil.deriveType(callBinding, callBinding.operands()); + if (argTypes.isEmpty()) { + throw callBinding.newValidationError(RESOURCE.mapRequiresTwoOrMoreArgs()); + } + if (argTypes.size() % 2 != 0) { + throw callBinding.newValidationError(RESOURCE.mapRequiresEvenArgCount()); + } + final Pair<@Nullable RelDataType, @Nullable RelDataType> componentType = + getComponentTypes( + callBinding.getTypeFactory(), argTypes); + if (null == componentType.left || null == componentType.right) { + if (throwOnFailure) { + throw callBinding.newValidationError(RESOURCE.needSameTypeParameter()); + } + return false; + } + return true; + } + + /** + * Extract the key type and value type of arg types. + */ + private static Pair<@Nullable RelDataType, @Nullable RelDataType> getComponentTypes( + RelDataTypeFactory typeFactory, + List argTypes) { + // Util.quotientList(argTypes, 2, 0): + // This extracts all elements at even indices from argTypes. 
+ // It represents the types of keys in the map, as they are placed at even positions + // (e.g. 0, 2, 4, etc.). + // Symmetrically, Util.quotientList(argTypes, 2, 1) represents odd-indexed elements. + // For details, see Util.quotientList. + return Pair.of( + typeFactory.leastRestrictive(Util.quotientList(argTypes, 2, 0)), + typeFactory.leastRestrictive(Util.quotientList(argTypes, 2, 1))); } } diff --git a/core/src/main/java/org/apache/calcite/sql/type/SqlTypeTransforms.java b/core/src/main/java/org/apache/calcite/sql/type/SqlTypeTransforms.java index ad91c54b99c6..d13945a3a8f8 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/SqlTypeTransforms.java +++ b/core/src/main/java/org/apache/calcite/sql/type/SqlTypeTransforms.java @@ -273,6 +273,15 @@ private SqlTypeName toVar(RelDataType type) { SqlTypeUtil.createMapTypeFromRecord(opBinding.getTypeFactory(), typeToTransform); + /** + * Parameter type-inference transform strategy that wraps a given type in a map or + * wraps a field of the given type in a map. It is used when a map input is a sub-query. + */ + public static final SqlTypeTransform TO_MAP_QUERY = + (opBinding, typeToTransform) -> + TO_MAP.transformType(opBinding, + SqlTypeUtil.deriveCollectionQueryComponentType(SqlTypeName.MAP, typeToTransform)); + /** * Parameter type-inference transform strategy that converts a type to a MAP type, * which key and value type is same. 
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java index f0df06185201..0d464d70d7c6 100644 --- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java +++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java @@ -288,6 +288,7 @@ public enum BuiltInMethod { AS_ENUMERABLE(Linq4j.class, "asEnumerable", Object[].class), AS_ENUMERABLE2(Linq4j.class, "asEnumerable", Iterable.class), ENUMERABLE_TO_LIST(ExtendedEnumerable.class, "toList"), + ENUMERABLE_TO_MAP(ExtendedEnumerable.class, "toMap", Function1.class, Function1.class), AS_LIST(Primitive.class, "asList", Object.class), MEMORY_GET0(MemoryFactory.Memory.class, "get"), MEMORY_GET1(MemoryFactory.Memory.class, "get", int.class), diff --git a/core/src/test/resources/sql/sub-query.iq b/core/src/test/resources/sql/sub-query.iq index 42bc10884ad1..aa43c54d923f 100644 --- a/core/src/test/resources/sql/sub-query.iq +++ b/core/src/test/resources/sql/sub-query.iq @@ -3685,4 +3685,15 @@ FROM dept d1; !ok +# [CALCITE-6041] MAP sub-query gives NullPointerException +SELECT map(SELECT empno, deptno from emp where deptno < 20); ++-----------------------------+ +| EXPR$0 | ++-----------------------------+ +| {7782=10, 7839=10, 7934=10} | ++-----------------------------+ +(1 row) + +!ok + # End sub-query.iq diff --git a/testkit/src/main/java/org/apache/calcite/sql/parser/SqlParserTest.java b/testkit/src/main/java/org/apache/calcite/sql/parser/SqlParserTest.java index 2086a3000282..82fd43ec08bc 100644 --- a/testkit/src/main/java/org/apache/calcite/sql/parser/SqlParserTest.java +++ b/testkit/src/main/java/org/apache/calcite/sql/parser/SqlParserTest.java @@ -6146,6 +6146,27 @@ private static Matcher isCharLiteral(String s) { .ok("(MAP[])"); } + @Test void testMapQueryConstructor() { + // parser allows odd elements; validator will reject it + sql("SELECT map(SELECT 1)") + .ok("SELECT (MAP ((SELECT 1)))"); + sql("SELECT map(SELECT 1, 
2)") + .ok("SELECT (MAP ((SELECT 1, 2)))"); + // with upper case + sql("SELECT MAP(SELECT 1, 2)") + .ok("SELECT (MAP ((SELECT 1, 2)))"); + // with space + sql("SELECT map (SELECT 1, 2)") + .ok("SELECT (MAP ((SELECT 1, 2)))"); + sql("SELECT map(SELECT T.x, T.y FROM (VALUES(1, 2)) AS T(x, y))") + .ok("SELECT (MAP ((SELECT `T`.`X`, `T`.`Y`\n" + + "FROM (VALUES (ROW(1, 2))) AS `T` (`X`, `Y`))))"); + sql("SELECT map(^1^, SELECT x FROM (VALUES(1)) x)") + .fails("(?s)Non-query expression encountered in illegal context"); + sql("SELECT map(SELECT x FROM (VALUES(1)) x, ^SELECT^ x FROM (VALUES(1)) x)") + .fails("(?s)Incorrect syntax near the keyword 'SELECT' at.*"); + } + @Test void testVisitSqlInsertWithSqlShuttle() { final String sql = "insert into emps select * from emps"; final SqlNode sqlNode = sql(sql).node(); diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index 746c3307203e..1ffd3164ac15 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -10446,6 +10446,40 @@ private static void checkArrayConcatAggFuncFails(SqlOperatorFixture t) { "{1=1, 2=2}", "(BIGINT NOT NULL, SMALLINT NOT NULL) MAP NOT NULL"); } + @Test void testMapQueryConstructor() { + final SqlOperatorFixture f = Fixtures.forOperators(true); + f.setFor(SqlStdOperatorTable.MAP_QUERY, VmName.EXPAND); + // must be 2 fields + f.checkFails("map(select 1)", "MAP requires exactly two fields, got 1; " + + "row type RecordType\\(INTEGER EXPR\\$0\\)", false); + f.checkFails("map(select 1, 2, 3)", "MAP requires exactly two fields, got 3; " + + "row type RecordType\\(INTEGER EXPR\\$0, INTEGER EXPR\\$1, " + + "INTEGER EXPR\\$2\\)", false); + f.checkFails("map(select 1, 'x', 2, 'x')", "MAP requires exactly two fields, got 4; " + + "row type RecordType\\(INTEGER EXPR\\$0, CHAR\\(1\\) EXPR\\$1, INTEGER EXPR\\$2, " + + 
"CHAR\\(1\\) EXPR\\$3\\)", false); + f.checkScalar("map(select 1, 2)", "{1=2}", + "(INTEGER NOT NULL, INTEGER NOT NULL) MAP NOT NULL"); + f.checkScalar("map(select 1, 2.0)", "{1=2.0}", + "(INTEGER NOT NULL, DECIMAL(2, 1) NOT NULL) MAP NOT NULL"); + f.checkScalar("map(select 1, true)", "{1=true}", + "(INTEGER NOT NULL, BOOLEAN NOT NULL) MAP NOT NULL"); + f.checkScalar("map(select 'x', 1)", "{x=1}", + "(CHAR(1) NOT NULL, INTEGER NOT NULL) MAP NOT NULL"); + // element cast + f.checkScalar("map(select cast(1 as bigint), 2)", "{1=2}", + "(BIGINT NOT NULL, INTEGER NOT NULL) MAP NOT NULL"); + f.checkScalar("map(select 1, cast(2 as varchar))", "{1=2}", + "(INTEGER NOT NULL, VARCHAR NOT NULL) MAP NOT NULL"); + // null key or value + f.checkScalar("map(select 1, null)", "{1=null}", + "(INTEGER NOT NULL, NULL) MAP NOT NULL"); + f.checkScalar("map(select null, 1)", "{null=1}", + "(NULL, INTEGER NOT NULL) MAP NOT NULL"); + f.checkScalar("map(select null, null)", "{null=null}", + "(NULL, NULL) MAP NOT NULL"); + } + @Test void testCeilFunc() { final SqlOperatorFixture f = fixture(); f.setFor(SqlStdOperatorTable.CEIL, VM_FENNEL);