Skip to content

Commit

Permalink
[CALCITE-5918] Add MAP function (enabled in Spark library)
Browse files Browse the repository at this point in the history
  • Loading branch information
chucheng92 committed Oct 26, 2023
1 parent 5151168 commit 73ea5c7
Show file tree
Hide file tree
Showing 10 changed files with 287 additions and 10 deletions.
21 changes: 14 additions & 7 deletions core/src/main/codegen/templates/Parser.jj
Original file line number Diff line number Diff line change
Expand Up @@ -4886,14 +4886,21 @@ SqlNode MapConstructor() :
{
<MAP> { s = span(); }
(
LOOKAHEAD(1)
<LPAREN>
// by sub query "MAP (SELECT empno, deptno FROM emp)"
e = LeafQueryOrExpr(ExprContext.ACCEPT_QUERY)
<RPAREN>
(
// empty map function call: "map()"
LOOKAHEAD(2)
<LPAREN> <RPAREN> { args = SqlNodeList.EMPTY; }
|
args = ParenthesizedQueryOrCommaList(ExprContext.ACCEPT_ALL)
)
{
return SqlStdOperatorTable.MAP_QUERY.createCall(
s.end(this), e);
if (args.size() == 1 && args.get(0).isA(SqlKind.QUERY)) {
// MAP query constructor e.g. "MAP (SELECT empno, deptno FROM emps)"
return SqlStdOperatorTable.MAP_QUERY.createCall(s.end(this), args.get(0));
} else {
// MAP function e.g. "MAP(1, 2)" equivalent to standard "MAP[1, 2]"
return SqlLibraryOperators.MAP.createCall(s.end(this), args.getList());
}
}
|
// by enumeration "MAP[k0, v0, ..., kN, vN]"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_AND;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LOGICAL_OR;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.LPAD;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP_CONCAT;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP_ENTRIES;
import static org.apache.calcite.sql.fun.SqlLibraryOperators.MAP_FROM_ARRAYS;
Expand Down Expand Up @@ -880,6 +881,7 @@ Builder populate2() {
map.put(MAP_VALUE_CONSTRUCTOR, value);
map.put(ARRAY_VALUE_CONSTRUCTOR, value);
defineMethod(ARRAY, BuiltInMethod.ARRAYS_AS_LIST.method, NullPolicy.NONE);
defineMethod(MAP, BuiltInMethod.MAP.method, NullPolicy.NONE);

// ITEM operator
map.put(ITEM, new ItemImplementor());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public <T> RelNode translate(Queryable<T> queryable) {
public RelNode translate(Expression expression) {
if (expression instanceof MethodCallExpression) {
final MethodCallExpression call = (MethodCallExpression) expression;
BuiltInMethod method = BuiltInMethod.MAP.get(call.method);
BuiltInMethod method = BuiltInMethod.FUNCTIONS_MAPS.get(call.method);
if (method == null) {
throw new UnsupportedOperationException(
"unknown method " + call.method);
Expand Down
14 changes: 14 additions & 0 deletions core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -5323,6 +5323,20 @@ public static Map mapFromArrays(List keysArray, List valuesArray) {
return map;
}

/** Support the MAP function.
 *
 * <p>The arguments are alternating keys and values: using 0-based positions,
 * the element at each even index is a key and the element immediately after
 * it is the corresponding value. Entries keep the order in which the keys
 * first appear; if a key occurs more than once, the last value is kept.
 *
 * @param args alternating keys and values; may be empty
 * @return a map holding one entry per key/value pair
 * @throws IllegalArgumentException if the number of arguments is odd
 */
public static Map map(Object... args) {
  // Fail with a clear message rather than running off the end of the array
  // when the final key has no matching value.
  if (args.length % 2 != 0) {
    throw new IllegalArgumentException(
        "MAP requires an even number of arguments, but got " + args.length);
  }
  final Map<Object, Object> map = new LinkedHashMap<>();
  for (int i = 0; i < args.length; i += 2) {
    map.put(args[i], args[i + 1]);
  }
  return map;
}

/** Support the STR_TO_MAP function. */
public static Map strToMap(String string, String stringDelimiter, String keyValueDelimiter) {
final Map map = new LinkedHashMap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.util.Litmus;
import org.apache.calcite.util.Optionality;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Static;
import org.apache.calcite.util.Util;

import com.google.common.collect.ImmutableList;

Expand Down Expand Up @@ -1082,6 +1084,44 @@ private static RelDataType arrayReturnType(SqlOperatorBinding opBinding) {
SqlLibraryOperators::arrayReturnType,
OperandTypes.SAME_VARIADIC);

/** Derives the return type of the MAP function from its operand types.
 *
 * <p>Key and value component types are obtained from
 * {@code getComponentTypes}; both must be non-null, otherwise
 * {@code requireNonNull} fails. */
private static RelDataType mapReturnType(SqlOperatorBinding opBinding) {
  final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
  final List<RelDataType> operandTypes = opBinding.collectOperandTypes();
  final Pair<@Nullable RelDataType, @Nullable RelDataType> keyAndValue =
      getComponentTypes(typeFactory, operandTypes);
  final RelDataType keyType =
      requireNonNull(keyAndValue.left, "inferred key type");
  final RelDataType valueType =
      requireNonNull(keyAndValue.right, "inferred value type");
  // NOTE(review): the trailing 'false' is presumably the nullability of the
  // resulting map type -- confirm against SqlTypeUtil.createMapType.
  return SqlTypeUtil.createMapType(typeFactory, keyType, valueType, false);
}

/** Computes the key type and value type for a list of alternating
 * key/value operand types.
 *
 * <p>An empty operand list (the empty map) yields a pair of unknown types.
 * Otherwise the key type is the least restrictive type of the operands at
 * even positions (0, 2, 4, ...) and the value type is the least restrictive
 * type of the operands at odd positions; either side is whatever
 * {@code leastRestrictive} returns, which may be null. */
private static Pair<@Nullable RelDataType, @Nullable RelDataType> getComponentTypes(
    RelDataTypeFactory typeFactory,
    List<RelDataType> argTypes) {
  if (argTypes.isEmpty()) {
    // Special case: MAP() has no operands to infer from.
    final RelDataType unknownKeyType = typeFactory.createUnknownType();
    final RelDataType unknownValueType = typeFactory.createUnknownType();
    return Pair.of(unknownKeyType, unknownValueType);
  }
  // Util.quotientList(list, 2, r) selects the elements whose index has
  // remainder r when divided by 2: r = 0 gives the keys, r = 1 the values.
  final List<RelDataType> keyTypes = Util.quotientList(argTypes, 2, 0);
  final @Nullable RelDataType keyType = typeFactory.leastRestrictive(keyTypes);
  final List<RelDataType> valueTypes = Util.quotientList(argTypes, 2, 1);
  final @Nullable RelDataType valueType = typeFactory.leastRestrictive(valueTypes);
  return Pair.of(keyType, valueType);
}

/** The "MAP(key, value, ...)" function (Spark);
 * compare with the standard map value constructor, "MAP[key, value, ...]".
 *
 * <p>The return type is derived by {@code mapReturnType} and the operands are
 * checked by {@link OperandTypes#MAP_FUNCTION}. */
@LibraryOperator(libraries = {SPARK})
public static final SqlFunction MAP =
SqlBasicFunction.create("MAP",
SqlLibraryOperators::mapReturnType,
OperandTypes.MAP_FUNCTION,
SqlFunctionCategory.SYSTEM);

@SuppressWarnings("argument.type.incompatible")
private static RelDataType arrayAppendPrependReturnType(SqlOperatorBinding opBinding) {
final RelDataType arrayType = opBinding.collectOperandTypes().get(0);
Expand Down
60 changes: 60 additions & 0 deletions core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -568,6 +569,9 @@ public static SqlOperandTypeChecker variadic(
public static final SqlSingleOperandTypeChecker MAP_FROM_ENTRIES =
new MapFromEntriesOperandTypeChecker();

/**
 * Operand type-checking strategy for the MAP function; implemented by
 * {@code MapFunctionOperandTypeChecker}.
 */
public static final SqlSingleOperandTypeChecker MAP_FUNCTION =
new MapFunctionOperandTypeChecker();

/**
* Operand type-checking strategy where type must be a literal or NULL.
*/
Expand Down Expand Up @@ -1221,6 +1225,62 @@ private static class MapFromEntriesOperandTypeChecker
}
}

/**
 * Operand type-checking strategy for the MAP function.
 *
 * <p>Accepts an empty operand list (the empty map). Otherwise the number of
 * operands must be even, and both the key operands (even positions) and the
 * value operands (odd positions) must have a least restrictive type.
 */
private static class MapFunctionOperandTypeChecker
    extends SameOperandTypeChecker {

  MapFunctionOperandTypeChecker() {
    // MAP takes a variable number of operands, so pass -1 rather than a
    // fixed count; the operand count is then taken from the actual call
    // (see SameOperandTypeChecker#getOperandList).
    super(-1);
  }

  /**
   * Checks the operands of a MAP call.
   *
   * @param callBinding the call being validated
   * @param throwOnFailure whether to throw a validation error on failure
   *     instead of returning {@code false}
   * @return {@code true} if the operands are acceptable; {@code false} if
   *     they are not and {@code throwOnFailure} is {@code false}
   */
  @Override public boolean checkOperandTypes(final SqlCallBinding callBinding,
      final boolean throwOnFailure) {
    final List<RelDataType> argTypes =
        SqlTypeUtil.deriveType(callBinding, callBinding.operands());
    // An empty map, "MAP()", is allowed.
    if (argTypes.isEmpty()) {
      return true;
    }
    // Keys and values come in pairs, so the operand count must be even.
    // Honor throwOnFailure here too, consistent with the type check below.
    if (argTypes.size() % 2 != 0) {
      if (throwOnFailure) {
        throw callBinding.newValidationError(RESOURCE.mapRequiresEvenArgCount());
      }
      return false;
    }
    final Pair<@Nullable RelDataType, @Nullable RelDataType> componentType =
        getComponentTypes(callBinding.getTypeFactory(), argTypes);
    // A null key or value type is treated as a failure: the keys (or the
    // values) have no least restrictive type.
    if (null == componentType.left || null == componentType.right) {
      if (throwOnFailure) {
        throw callBinding.newValidationError(RESOURCE.needSameTypeParameter());
      }
      return false;
    }
    return true;
  }

  /**
   * Extracts the key type and the value type from a list of alternating
   * key/value operand types.
   *
   * <p>{@code Util.quotientList(argTypes, 2, 0)} selects the operands at
   * even positions (0, 2, 4, ...), which are the keys;
   * {@code Util.quotientList(argTypes, 2, 1)} selects those at odd
   * positions, which are the values. Each side is reduced with
   * {@code leastRestrictive}, whose result may be null.
   */
  private static Pair<@Nullable RelDataType, @Nullable RelDataType> getComponentTypes(
      RelDataTypeFactory typeFactory,
      List<RelDataType> argTypes) {
    return Pair.of(
        typeFactory.leastRestrictive(Util.quotientList(argTypes, 2, 0)),
        typeFactory.leastRestrictive(Util.quotientList(argTypes, 2, 1)));
  }
}

/** Operand type-checker that accepts period types. Examples:
*
* <ul>
Expand Down
5 changes: 3 additions & 2 deletions core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,7 @@ public enum BuiltInMethod {
ARRAYS_OVERLAP(SqlFunctions.class, "arraysOverlap", List.class, List.class),
ARRAYS_ZIP(SqlFunctions.class, "arraysZip", List.class, List.class),
SORT_ARRAY(SqlFunctions.class, "sortArray", List.class, boolean.class),
MAP(SqlFunctions.class, "map", Object[].class),
MAP_CONCAT(SqlFunctions.class, "mapConcat", Map[].class),
MAP_ENTRIES(SqlFunctions.class, "mapEntries", Map.class),
MAP_KEYS(SqlFunctions.class, "mapKeys", Map.class),
Expand Down Expand Up @@ -850,7 +851,7 @@ public enum BuiltInMethod {
@SuppressWarnings("ImmutableEnumChecker")
public final Field field;

public static final ImmutableMap<Method, BuiltInMethod> MAP;
public static final ImmutableMap<Method, BuiltInMethod> FUNCTIONS_MAPS;

static {
final ImmutableMap.Builder<Method, BuiltInMethod> builder =
Expand All @@ -860,7 +861,7 @@ public enum BuiltInMethod {
builder.put(value.method, value);
}
}
MAP = builder.build();
FUNCTIONS_MAPS = builder.build();
}

BuiltInMethod(@Nullable Method method, @Nullable Constructor constructor, @Nullable Field field) {
Expand Down
2 changes: 2 additions & 0 deletions site/_docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -2770,6 +2770,8 @@ BigQuery's type system uses confusingly different names for types and functions:
| b | TO_HEX(binary) | Converts *binary* into a hexadecimal varchar
| b | FROM_HEX(varchar) | Converts a hexadecimal-encoded *varchar* into bytes
| b o | LTRIM(string) | Returns *string* with all blanks removed from the start
| s | MAP() | Returns an empty map
| s | MAP(key, value [, key, value]*) | Returns a map with the given *key*/*value* pairs
| s | MAP_CONCAT(map [, map]*) | Concatenates one or more maps. If any input argument is `NULL` the function returns `NULL`. Note that calcite is using the LAST_WIN strategy
| s | MAP_ENTRIES(map) | Returns the entries of the *map* as an array, the order of the entries is not defined
| s | MAP_KEYS(map) | Returns the keys of the *map* as an array, the order of the entries is not defined
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6146,6 +6146,41 @@ private static Matcher<SqlNode> isCharLiteral(String s) {
.ok("(MAP[])");
}

/** Tests parsing of the MAP function call syntax, "MAP(k, v, ...)". */
@Test void testMapFunction() {
  final String unparsedPairs = "MAP(1, 'x', 2, 'y')";

  // empty argument list, in lower and upper case
  expr("map()").ok("MAP()");
  expr("MAP()").same();

  // parser allows odd elements; validator will reject it
  expr("map(1)").ok("MAP(1)");

  // two key/value pairs: lower case, upper case, and with a space
  // between the function name and the parenthesis
  expr("map(1, 'x', 2, 'y')").ok(unparsedPairs);
  expr(unparsedPairs).same();
  expr("map (1, 'x', 2, 'y')").ok(unparsedPairs);
}

/** Tests parsing of the MAP query constructor, "MAP (SELECT ...)". */
@Test void testMapQueryConstructor() {
  final String unparsedTwoColumns = "SELECT (MAP ((SELECT 1, 2)))";

  // the parser accepts a single-column sub-query here
  sql("SELECT map(SELECT 1)")
      .ok("SELECT (MAP ((SELECT 1)))");

  // two-column sub-query: lower case, upper case, and with a space
  sql("SELECT map(SELECT 1, 2)").ok(unparsedTwoColumns);
  sql("SELECT MAP(SELECT 1, 2)").ok(unparsedTwoColumns);
  sql("SELECT map (SELECT 1, 2)").ok(unparsedTwoColumns);

  // sub-query with a FROM clause
  final String unparsedFromValues = "SELECT (MAP ((SELECT `T`.`X`, `T`.`Y`\n"
      + "FROM (VALUES (ROW(1, 2))) AS `T` (`X`, `Y`))))";
  sql("SELECT map(SELECT T.x, T.y FROM (VALUES(1, 2)) AS T(x, y))")
      .ok(unparsedFromValues);

  // a query may not appear alongside other arguments
  sql("SELECT map(1, ^SELECT^ x FROM (VALUES(1)) x)")
      .fails("(?s)Incorrect syntax near the keyword 'SELECT'.*");
  sql("SELECT map(SELECT x FROM (VALUES(1)) x, ^SELECT^ x FROM (VALUES(1)) x)")
      .fails("(?s)Incorrect syntax near the keyword 'SELECT' at.*");
}

@Test void testVisitSqlInsertWithSqlShuttle() {
final String sql = "insert into emps select * from emps";
final SqlNode sqlNode = sql(sql).node();
Expand Down
Loading

0 comments on commit 73ea5c7

Please sign in to comment.