Skip to content

Commit

Permalink
[#FLINK-33596] Fold expression before transfer function to RexNode.
Browse files Browse the repository at this point in the history
  • Loading branch information
yunfan123 committed Nov 28, 2023
1 parent 15a3372 commit a76a034
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.flink.util.StringUtils;

import org.apache.hadoop.hive.ql.exec.FunctionInfo;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFFirstValue;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -138,6 +139,7 @@ public Set<String> listFunctions() {
functionNames.add(GenericUDFLegacyGroupingID.NAME);
functionNames.add(HiveGenericUDFArrayAccessStructField.NAME);
functionNames.add(HiveGenericUDFToDecimal.NAME);
functionNames.add("first");
}
return functionNames;
}
Expand All @@ -160,6 +162,11 @@ public Optional<FunctionDefinition> getFunctionDefinition(String name) {
factory.createFunctionDefinitionFromHiveFunction(
name, HiveGenericUDFGrouping.class.getName(), context));
}
if (name.equalsIgnoreCase("first")) {
return Optional.of(
factory.createFunctionDefinitionFromHiveFunction(
name, GenericUDAFFirstValue.class.getName(), context));
}

// this function is used to generate legacy GROUPING__ID value for old hive versions
if (name.equalsIgnoreCase(GenericUDFLegacyGroupingID.NAME)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.processors.HiveCommand;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFFirstValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -236,6 +238,7 @@ public List<Operation> parse(String statement) {
HiveSessionState.startSessionState(hiveConfCopy, catalogRegistry);
// We override Hive's grouping function. Refer to the implementation for more details.
hiveShim.registerTemporaryFunction("grouping", HiveGenericUDFGrouping.class);
hiveShim.registerTemporaryFunction("first", GenericUDAFFirstValue.class);
return processCmd(statement, hiveConfCopy, hiveShim, currentCatalog);
} finally {
HiveSessionState.clearSessionState();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import org.apache.hadoop.hive.common.type.HiveVarchar;
import org.apache.hadoop.hive.ql.ErrorMsg;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.optimizer.ConstantPropagateProcFactory;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
Expand Down Expand Up @@ -211,7 +212,7 @@ public static RexNode convert(

public RexNode convert(ExprNodeDesc expr) throws SemanticException {
if (expr instanceof ExprNodeGenericFuncDesc) {
return convertGenericFunc((ExprNodeGenericFuncDesc) expr);
return convertGenericFunc((ExprNodeGenericFuncDesc) expr, cluster);
} else if (expr instanceof ExprNodeConstantDesc) {
return convertConstant((ExprNodeConstantDesc) expr, cluster);
} else if (expr instanceof ExprNodeColumnDesc) {
Expand Down Expand Up @@ -518,13 +519,18 @@ public static RexNode convertConstant(ExprNodeConstantDesc literal, RelOptCluste
return calciteLiteral;
}

private RexNode convertGenericFunc(ExprNodeGenericFuncDesc func) throws SemanticException {
private RexNode convertGenericFunc(ExprNodeGenericFuncDesc func, RelOptCluster cluster)
throws SemanticException {
ExprNodeDesc tmpExprNode;
RexNode tmpRN;

List<RexNode> childRexNodeLst = new ArrayList<>();
List<RelDataType> argTypes = new ArrayList<>();

ExprNodeDesc afterFoldDesc = ConstantPropagateProcFactory.foldExpr(func);
if (afterFoldDesc instanceof ExprNodeConstantDesc) {
return convertConstant((ExprNodeConstantDesc) afterFoldDesc, cluster);
}
// TODO: 1) Expand to other functions as needed 2) What about types other than primitive.
TypeInfo tgtDT = null;
GenericUDF tgtUdf = func.getGenericUDF();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ public static void setup() throws Exception {
// create functions
tableEnv.executeSql(
"create function hiveudf as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs'");
tableEnv.executeSql(
"create function hiveudaf as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList'");
tableEnv.executeSql(
"create function hiveudtf as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDTFExplode'");
tableEnv.executeSql("create function myudtf as '" + MyUDTF.class.getName() + "'");
Expand All @@ -180,6 +182,13 @@ public void testQueries() throws Exception {
}
}

@Test
public void testCommonTest() throws Exception {
tableEnv.executeSql("select first_value(id, true) over (partition by name order by dep) from employee")
.collect().forEachRemaining(
r -> System.out.println(r.toString()));
}

@Test
public void testAdditionalQueries() throws Exception {
List<String> toRun =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,7 @@ select dep,count(1) from employee where salary<5000 and age>=38 and dep='Sales'
select x,null as n from foo group by x,'a',null;

[+I[1, null], +I[2, null], +I[3, null], +I[4, null], +I[5, null]]

select dep, sum(salary) from employee group by dep, UNIX_TIMESTAMP();

[+I[Management, 12900], +I[Production, 18600], +I[Sales, 8400], +I[Service, 4100]]

0 comments on commit a76a034

Please sign in to comment.