From a76a034db1939a11fb1074c081d43f8c0256a378 Mon Sep 17 00:00:00 2001 From: zhangyunfan Date: Mon, 20 Nov 2023 18:00:43 +0800 Subject: [PATCH] [#FLINK-33596] Fold constant expressions before converting a function to RexNode. --- .../org/apache/flink/table/module/hive/HiveModule.java | 7 +++++++ .../table/planner/delegation/hive/HiveParser.java | 3 +++ .../delegation/hive/HiveParserRexNodeConverter.java | 10 ++++++++-- .../flink/connectors/hive/HiveDialectQueryITCase.java | 9 +++++++++ .../src/test/resources/query-test/group_by.q | 4 ++++ 5 files changed, 31 insertions(+), 2 deletions(-) diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java index a45739a6cd80d..90e517c43ec72 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java @@ -40,6 +40,7 @@ import org.apache.flink.util.StringUtils; import org.apache.hadoop.hive.ql.exec.FunctionInfo; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFFirstValue; import java.util.Arrays; import java.util.Collections; @@ -138,6 +139,7 @@ public Set listFunctions() { functionNames.add(GenericUDFLegacyGroupingID.NAME); functionNames.add(HiveGenericUDFArrayAccessStructField.NAME); functionNames.add(HiveGenericUDFToDecimal.NAME); + functionNames.add("first"); } return functionNames; } @@ -160,6 +162,11 @@ public Optional getFunctionDefinition(String name) { factory.createFunctionDefinitionFromHiveFunction( name, HiveGenericUDFGrouping.class.getName(), context)); } + if (name.equalsIgnoreCase("first")) { + return Optional.of( + factory.createFunctionDefinitionFromHiveFunction( + name, GenericUDAFFirstValue.class.getName(), context)); + } // this function is used to generate legacy GROUPING__ID value for old 
hive versions if (name.equalsIgnoreCase(GenericUDFLegacyGroupingID.NAME)) { diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParser.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParser.java index 349ecef7f0c4e..dde503f91bc9d 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParser.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParser.java @@ -70,6 +70,8 @@ import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.processors.HiveCommand; import org.apache.hadoop.hive.ql.session.SessionState; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFFirstValue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -236,6 +238,7 @@ public List parse(String statement) { HiveSessionState.startSessionState(hiveConfCopy, catalogRegistry); // We override Hive's grouping function. Refer to the implementation for more details. 
hiveShim.registerTemporaryFunction("grouping", HiveGenericUDFGrouping.class); + hiveShim.registerTemporaryFunction("first", GenericUDAFFirstValue.class); return processCmd(statement, hiveConfCopy, hiveShim, currentCatalog); } finally { HiveSessionState.clearSessionState(); diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserRexNodeConverter.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserRexNodeConverter.java index 5c5f67aa68e1f..41ff4b9b07bc7 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserRexNodeConverter.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParserRexNodeConverter.java @@ -65,6 +65,7 @@ import org.apache.hadoop.hive.common.type.HiveVarchar; import org.apache.hadoop.hive.ql.ErrorMsg; import org.apache.hadoop.hive.ql.exec.FunctionRegistry; +import org.apache.hadoop.hive.ql.optimizer.ConstantPropagateProcFactory; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc; import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc; @@ -211,7 +212,7 @@ public static RexNode convert( public RexNode convert(ExprNodeDesc expr) throws SemanticException { if (expr instanceof ExprNodeGenericFuncDesc) { - return convertGenericFunc((ExprNodeGenericFuncDesc) expr); + return convertGenericFunc((ExprNodeGenericFuncDesc) expr, cluster); } else if (expr instanceof ExprNodeConstantDesc) { return convertConstant((ExprNodeConstantDesc) expr, cluster); } else if (expr instanceof ExprNodeColumnDesc) { @@ -518,13 +519,18 @@ public static RexNode convertConstant(ExprNodeConstantDesc literal, RelOptCluste return calciteLiteral; } - private RexNode convertGenericFunc(ExprNodeGenericFuncDesc func) throws SemanticException { + private RexNode 
convertGenericFunc(ExprNodeGenericFuncDesc func, RelOptCluster cluster) + throws SemanticException { ExprNodeDesc tmpExprNode; RexNode tmpRN; List childRexNodeLst = new ArrayList<>(); List argTypes = new ArrayList<>(); + ExprNodeDesc afterFoldDesc = ConstantPropagateProcFactory.foldExpr(func); + if (afterFoldDesc instanceof ExprNodeConstantDesc) { + return convertConstant((ExprNodeConstantDesc) afterFoldDesc, cluster); + } // TODO: 1) Expand to other functions as needed 2) What about types other than primitive. TypeInfo tgtDT = null; GenericUDF tgtUdf = func.getGenericUDF(); diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryITCase.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryITCase.java index 2f568d0d00f3d..5720bf538c818 100644 --- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryITCase.java +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryITCase.java @@ -163,6 +163,8 @@ public static void setup() throws Exception { // create functions tableEnv.executeSql( "create function hiveudf as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs'"); + tableEnv.executeSql( + "create function hiveudaf as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList'"); tableEnv.executeSql( "create function hiveudtf as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDTFExplode'"); tableEnv.executeSql("create function myudtf as '" + MyUDTF.class.getName() + "'"); @@ -180,6 +182,13 @@ public void testQueries() throws Exception { } } + @Test + public void testCommonTest() throws Exception { + tableEnv.executeSql("select first_value(id, true) over (partition by name order by dep) from employee") + .collect().forEachRemaining( + r -> System.out.println(r.toString())); + } + @Test public void testAdditionalQueries() throws Exception { List toRun = 
diff --git a/flink-connectors/flink-connector-hive/src/test/resources/query-test/group_by.q b/flink-connectors/flink-connector-hive/src/test/resources/query-test/group_by.q index 28f2bc7785a2e..77e97e85f95f5 100644 --- a/flink-connectors/flink-connector-hive/src/test/resources/query-test/group_by.q +++ b/flink-connectors/flink-connector-hive/src/test/resources/query-test/group_by.q @@ -31,3 +31,7 @@ select dep,count(1) from employee where salary<5000 and age>=38 and dep='Sales' select x,null as n from foo group by x,'a',null; [+I[1, null], +I[2, null], +I[3, null], +I[4, null], +I[5, null]] + +select dep, sum(salary) from employee group by dep, UNIX_TIMESTAMP(); + +[+I[Management, 12900], +I[Production, 18600], +I[Sales, 8400], +I[Service, 4100]]