Skip to content

Commit f2527e1

Browse files
committed
chore(isthmus): handle nullability and EnumArgument in SimplExtensionToSqlOperator
1 parent f3a5f4d commit f2527e1

File tree

6 files changed

+94
-36
lines changed

6 files changed

+94
-36
lines changed

isthmus/src/main/java/io/substrait/isthmus/SimpleExtensionToSqlOperator.java

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import io.substrait.function.TypeExpression;
77
import io.substrait.type.Type;
88
import io.substrait.type.TypeExpressionEvaluator;
9+
import java.util.ArrayList;
910
import java.util.List;
1011
import java.util.stream.Collectors;
1112
import java.util.stream.Stream;
@@ -46,17 +47,19 @@ private static SqlFunction toSqlFunction(
4647
SimpleExtension.Function function,
4748
RelDataTypeFactory typeFactory,
4849
TypeConverter typeConverter) {
49-
List<SimpleExtension.Argument> requiredArgs =
50-
function.args().stream()
51-
.filter(SimpleExtension.Argument::required)
52-
.filter(t -> t instanceof SimpleExtension.ValueArgument || t instanceof SimpleExtension.EnumArgument)
53-
.map(t -> (SimpleExtension.Argument) t)
54-
.collect(Collectors.toList());
55-
56-
List<SqlTypeFamily> argFamilies =
57-
requiredArgs.stream()
58-
.map(arg -> arg.value().accept(new CalciteTypeVisitor()).getFamily())
59-
.collect(Collectors.toList());
50+
51+
List<SqlTypeFamily> argFamilies = new ArrayList<>();
52+
53+
for (SimpleExtension.Argument arg : function.requiredArguments()) {
54+
if (arg instanceof SimpleExtension.ValueArgument) {
55+
SimpleExtension.ValueArgument valueArg = (SimpleExtension.ValueArgument) arg;
56+
SqlTypeName typeName = valueArg.value().accept(new CalciteTypeVisitor());
57+
argFamilies.add(typeName.getFamily());
58+
} else if (arg instanceof SimpleExtension.EnumArgument) {
59+
// Treat an EnumArgument as a required string literal.
60+
argFamilies.add(SqlTypeFamily.STRING);
61+
}
62+
}
6063

6164
SqlReturnTypeInference returnTypeInference =
6265
new SubstraitReturnTypeInference(function, typeFactory, typeConverter);
@@ -97,7 +100,27 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
97100
TypeExpressionEvaluator.evaluateExpression(
98101
returnExpression, function.args(), substraitArgTypes);
99102

100-
return typeConverter.toCalcite(typeFactory, resolvedSubstraitType);
103+
boolean finalIsNullable;
104+
switch (function.nullability()) {
105+
case MIRROR:
106+
// If any input is nullable, the output is nullable.
107+
finalIsNullable =
108+
opBinding.collectOperandTypes().stream().anyMatch(RelDataType::isNullable);
109+
break;
110+
case DISCRETE:
111+
// The function can return null even if inputs are not null.
112+
finalIsNullable = true;
113+
break;
114+
case DECLARED_OUTPUT:
115+
default:
116+
// Use the nullability declared on the resolved Substrait type.
117+
finalIsNullable = resolvedSubstraitType.nullable();
118+
break;
119+
}
120+
121+
RelDataType baseCalciteType = typeConverter.toCalcite(typeFactory, resolvedSubstraitType);
122+
123+
return typeFactory.createTypeWithNullability(baseCalciteType, finalIsNullable);
101124
}
102125
}
103126

isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ public SqlToSubstrait(SimpleExtension.ExtensionCollection extensions, FeatureBoa
4747

4848
this.operatorTable =
4949
SqlOperatorTables.chain(
50-
SqlOperatorTables.of(generatedDynamicOperators), SubstraitOperatorTable.INSTANCE);
50+
SubstraitOperatorTable.INSTANCE, SqlOperatorTables.of(generatedDynamicOperators));
5151
}
5252

5353
/**
@@ -94,7 +94,7 @@ public Plan convert(String sql, Prepare.CatalogReader catalogReader) throws SqlP
9494
@VisibleForTesting
9595
List<RelRoot> sqlToRelNode(String sql, Prepare.CatalogReader catalogReader)
9696
throws SqlParseException {
97-
SqlValidator validator = new SubstraitSqlValidator(catalogReader);
97+
SqlValidator validator = new SubstraitSqlValidator(catalogReader, operatorTable);
9898
SqlParser parser = SqlParser.create(sql, parserConfig);
9999
SqlNodeList parsedList = parser.parseStmtList();
100100
SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader);

isthmus/src/test/java/io/substrait/isthmus/SimpleExtensionToSqlOperatorTest.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public class SimpleExtensionToSqlOperatorTest {
1616

1717
@Test
1818
void test() throws IOException {
19-
String customFunctionPath = "/extensions/functions_string_custom.yaml";
19+
String customFunctionPath = "/extensions/scalar_functions_custom.yaml";
2020

2121
SimpleExtension.ExtensionCollection customExtensions =
2222
SimpleExtension.load(
@@ -25,8 +25,6 @@ void test() throws IOException {
2525

2626
List<SqlOperator> operators = SimpleExtensionToSqlOperator.from(customExtensions);
2727

28-
assertEquals(1, operators.size(), "Should generate one operator from the YAML file.");
29-
3028
Optional<SqlOperator> function =
3129
operators.stream()
3230
.filter(op -> op.getName().equalsIgnoreCase("REGEXP_EXTRACT"))

isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
public class UdfSqlSubstraitTest extends PlanTestBase {
1010

11-
private static final String CUSTOM_FUNCTION_PATH = "/extensions/functions_string_custom.yaml";
11+
private static final String CUSTOM_FUNCTION_PATH = "/extensions/scalar_functions_custom.yaml";
1212

1313
UdfSqlSubstraitTest() {
1414
super(loadExtensions(List.of(CUSTOM_FUNCTION_PATH)));
@@ -17,14 +17,18 @@ public class UdfSqlSubstraitTest extends PlanTestBase {
1717
@Test
1818
public void customUdfTest() throws Exception {
1919

20-
final String[] sql = {
21-
"CREATE TABLE t(x VARCHAR NOT NULL)", "SELECT regexp_extract(x, 'ab') from t"
22-
};
23-
2420
final Prepare.CatalogReader catalogReader =
25-
SubstraitCreateStatementParser.processCreateStatementsToCatalog(sql[0]);
26-
27-
assertSqlSubstraitRelRoundTripWorkaroundOptimizer(sql[1], catalogReader);
21+
SubstraitCreateStatementParser.processCreateStatementsToCatalog(
22+
"CREATE TABLE t(x VARCHAR NOT NULL)");
23+
24+
assertSqlSubstraitRelRoundTripWorkaroundOptimizer(
25+
"SELECT regexp_extract(x, 'ab') from t", catalogReader);
26+
assertSqlSubstraitRelRoundTripWorkaroundOptimizer(
27+
"SELECT format_text('UPPER', x) FROM t", catalogReader);
28+
assertSqlSubstraitRelRoundTripWorkaroundOptimizer(
29+
"SELECT system_property_get(x) FROM t", catalogReader);
30+
assertSqlSubstraitRelRoundTripWorkaroundOptimizer(
31+
"SELECT safe_divide(10,0) FROM t", catalogReader);
2832
}
2933

3034
private static SimpleExtension.ExtensionCollection loadExtensions(

isthmus/src/test/resources/extensions/functions_string_custom.yaml

Lines changed: 0 additions & 11 deletions
This file was deleted.
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
%YAML 1.2
2+
---
3+
scalar_functions:
4+
- name: "regexp_extract"
5+
impls:
6+
- args:
7+
- name: "text"
8+
value: string
9+
- name: "pattern"
10+
value: string
11+
return: string
12+
13+
- name: "format_text"
14+
description: "Formats text based on a mode. The output is nullable if the input is."
15+
impls:
16+
- args:
17+
- name: "mode"
18+
# options: ["UPPER", "LOWER"]
19+
value: string
20+
- name: "input_text"
21+
# options: ["UPPER", "LOWER"]
22+
value: string
23+
return: string
24+
nullability: MIRROR
25+
26+
- name: "system_property_get"
27+
description: "Safely gets a system property. Always returns a nullable string."
28+
impls:
29+
- args:
30+
- name: "property_name"
31+
value: string
32+
return: string?
33+
nullability: DECLARED_OUTPUT
34+
35+
- name: "safe_divide"
36+
description: "Performs division, returning NULL if the denominator is zero."
37+
impls:
38+
- args:
39+
- name: "numerator"
40+
value: i32
41+
- name: "denominator"
42+
value: i32
43+
return: fp32?
44+
nullability: DISCRETE

0 commit comments

Comments
 (0)