Skip to content

Commit 2f8f79a

Browse files
committed
feat(isthmus): udf support for substrait<->calcite
1 parent d61d8c4 commit 2f8f79a

19 files changed

+632
-33
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package io.substrait.isthmus;
2+
3+
import io.substrait.extension.SimpleExtension;
4+
import io.substrait.isthmus.calcite.SubstraitOperatorTable;
5+
import java.util.List;
6+
import java.util.Locale;
7+
import java.util.Set;
8+
import java.util.stream.Collectors;
9+
10+
public class ExtensionUtils {
11+
12+
public static SimpleExtension.ExtensionCollection getCustomExtensions(
13+
SimpleExtension.ExtensionCollection extensions) {
14+
Set<String> knownFunctionNames =
15+
SubstraitOperatorTable.INSTANCE.getOperatorList().stream()
16+
.map(op -> op.getName().toLowerCase(Locale.ROOT))
17+
.collect(Collectors.toSet());
18+
19+
List<SimpleExtension.ScalarFunctionVariant> customFunctions =
20+
extensions.scalarFunctions().stream()
21+
.filter(f -> !knownFunctionNames.contains(f.name().toLowerCase(Locale.ROOT)))
22+
.collect(Collectors.toList());
23+
24+
return SimpleExtension.ExtensionCollection.builder()
25+
.scalarFunctions(customFunctions)
26+
// TODO: handle aggregates and other functions
27+
.build();
28+
}
29+
}

isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
import org.apache.calcite.sql2rel.SqlToRelConverter;
2020

2121
class SqlConverterBase {
22-
protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION =
23-
SimpleExtension.loadDefaults();
24-
22+
protected final SimpleExtension.ExtensionCollection extensionCollection;
2523
final RelDataTypeFactory factory;
2624
final RelOptCluster relOptCluster;
2725
final CalciteConnectionConfig config;
@@ -32,7 +30,8 @@ class SqlConverterBase {
3230
protected static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build();
3331
final FeatureBoard featureBoard;
3432

35-
protected SqlConverterBase(FeatureBoard features) {
33+
protected SqlConverterBase(
34+
FeatureBoard features, SimpleExtension.ExtensionCollection extensionCollection) {
3635
this.factory = new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM);
3736
this.config =
3837
CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false");
@@ -51,5 +50,11 @@ protected SqlConverterBase(FeatureBoard features) {
5150
.withUnquotedCasing(featureBoard.unquotedCasing())
5251
.withParserFactory(SqlDdlParserImpl.FACTORY)
5352
.withConformance(SqlConformanceEnum.LENIENT);
53+
54+
this.extensionCollection = extensionCollection;
55+
}
56+
57+
/**
 * Convenience constructor: delegates to the two-argument constructor, passing the collection
 * returned by {@link SimpleExtension#loadDefaults()} as the extension collection.
 *
 * @param features feature board forwarded unchanged to the two-argument constructor
 */
protected SqlConverterBase(FeatureBoard features) {
58+
this(features, SimpleExtension.loadDefaults());
5459
}
5560
}

isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ public class SqlExpressionToSubstrait extends SqlConverterBase {
3434
protected final RexExpressionConverter rexConverter;
3535

3636
public SqlExpressionToSubstrait() {
37-
this(FEATURES_DEFAULT, EXTENSION_COLLECTION);
37+
this(FEATURES_DEFAULT, SimpleExtension.loadDefaults());
3838
}
3939

4040
public SqlExpressionToSubstrait(
4141
FeatureBoard features, SimpleExtension.ExtensionCollection extensions) {
42-
super(features);
42+
super(features, extensions);
4343
ScalarFunctionConverter scalarFunctionConverter =
4444
new ScalarFunctionConverter(extensions.scalarFunctions(), factory);
4545
this.rexConverter = new RexExpressionConverter(scalarFunctionConverter);

isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package io.substrait.isthmus;
22

33
import com.google.common.annotations.VisibleForTesting;
4+
import io.substrait.extension.SimpleExtension;
5+
import io.substrait.isthmus.calcite.SubstraitOperatorTable;
46
import io.substrait.isthmus.sql.SubstraitSqlValidator;
57
import io.substrait.plan.ImmutablePlan.Builder;
68
import io.substrait.plan.Plan.Version;
@@ -14,31 +16,48 @@
1416
import org.apache.calcite.rel.rules.CoreRules;
1517
import org.apache.calcite.sql.SqlNode;
1618
import org.apache.calcite.sql.SqlNodeList;
19+
import org.apache.calcite.sql.SqlOperator;
20+
import org.apache.calcite.sql.SqlOperatorTable;
1721
import org.apache.calcite.sql.parser.SqlParseException;
1822
import org.apache.calcite.sql.parser.SqlParser;
23+
import org.apache.calcite.sql.util.SqlOperatorTables;
1924
import org.apache.calcite.sql.validate.SqlValidator;
2025
import org.apache.calcite.sql2rel.SqlToRelConverter;
2126
import org.apache.calcite.sql2rel.StandardConvertletTable;
2227

2328
/** Take a SQL statement and a set of table definitions and return a substrait plan. */
2429
public class SqlToSubstrait extends SqlConverterBase {
30+
private final SqlOperatorTable operatorTable;
2531

2632
public SqlToSubstrait() {
27-
this(null);
33+
this(SimpleExtension.loadDefaults(), null);
2834
}
2935

3036
public SqlToSubstrait(FeatureBoard features) {
31-
super(features);
37+
this(SimpleExtension.loadDefaults(), features);
38+
}
39+
40+
public SqlToSubstrait(SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
41+
super(features, extensions);
42+
43+
SimpleExtension.ExtensionCollection customExtensionCollection =
44+
ExtensionUtils.getCustomExtensions(extensions);
45+
List<SqlOperator> generatedCustomOperators =
46+
YamlToSqlOperator.from(customExtensionCollection, this.factory);
47+
48+
this.operatorTable =
49+
SqlOperatorTables.chain(
50+
SqlOperatorTables.of(generatedCustomOperators), SubstraitOperatorTable.INSTANCE);
3251
}
3352

3453
public Plan execute(String sql, Prepare.CatalogReader catalogReader) throws SqlParseException {
35-
SqlValidator validator = new SubstraitSqlValidator(catalogReader);
54+
SqlValidator validator = new SubstraitSqlValidator(this.operatorTable, catalogReader);
3655
return executeInner(sql, validator, catalogReader);
3756
}
3857

3958
List<RelRoot> sqlToRelNode(String sql, Prepare.CatalogReader catalogReader)
4059
throws SqlParseException {
41-
SqlValidator validator = new SubstraitSqlValidator(catalogReader);
60+
SqlValidator validator = new SubstraitSqlValidator(this.operatorTable, catalogReader);
4261
return sqlToRelNode(sql, validator, catalogReader);
4362
}
4463

@@ -49,7 +68,7 @@ private Plan executeInner(String sql, SqlValidator validator, Prepare.CatalogRea
4968

5069
// TODO: consider case in which one sql passes conversion while others don't
5170
sqlToRelNode(sql, validator, catalogReader).stream()
52-
.map(root -> SubstraitRelVisitor.convert(root, EXTENSION_COLLECTION, featureBoard))
71+
.map(root -> SubstraitRelVisitor.convert(root, extensionCollection, featureBoard))
5372
.forEach(root -> builder.addRoots(root));
5473

5574
PlanProtoConverter planToProto = new PlanProtoConverter();

isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package io.substrait.isthmus;
22

3-
import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
4-
53
import com.google.common.collect.ImmutableList;
64
import com.google.common.collect.Range;
75
import com.google.common.collect.RangeMap;
@@ -12,6 +10,7 @@
1210
import io.substrait.extension.SimpleExtension;
1311
import io.substrait.isthmus.expression.AggregateFunctionConverter;
1412
import io.substrait.isthmus.expression.ExpressionRexConverter;
13+
import io.substrait.isthmus.expression.FunctionMappings;
1514
import io.substrait.isthmus.expression.ScalarFunctionConverter;
1615
import io.substrait.isthmus.expression.WindowFunctionConverter;
1716
import io.substrait.relation.AbstractRelVisitor;
@@ -97,7 +96,7 @@ public SubstraitRelNodeConverter(
9796
this(
9897
typeFactory,
9998
relBuilder,
100-
new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory),
99+
createScalarFunctionConverter(extensions, typeFactory),
101100
new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory),
102101
new WindowFunctionConverter(extensions.windowFunctions(), typeFactory),
103102
TypeConverter.DEFAULT);
@@ -139,11 +138,45 @@ public SubstraitRelNodeConverter(
139138
this.expressionRexConverter.setRelNodeConverter(this);
140139
}
141140

141+
private static ScalarFunctionConverter createScalarFunctionConverter(
142+
SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) {
143+
144+
java.util.Set<String> knownFunctionNames =
145+
FunctionMappings.SCALAR_SIGS.stream()
146+
.map(FunctionMappings.Sig::name)
147+
.collect(Collectors.toSet());
148+
149+
List<SimpleExtension.ScalarFunctionVariant> customFunctions =
150+
extensions.scalarFunctions().stream()
151+
.filter(f -> !knownFunctionNames.contains(f.name().toLowerCase()))
152+
.collect(Collectors.toList());
153+
154+
List<FunctionMappings.Sig> additionalSignatures;
155+
if (customFunctions.isEmpty()) {
156+
additionalSignatures = Collections.emptyList();
157+
} else {
158+
SimpleExtension.ExtensionCollection customExtensionCollection =
159+
SimpleExtension.ExtensionCollection.builder().scalarFunctions(customFunctions).build();
160+
161+
List<SqlOperator> customOperators =
162+
YamlToSqlOperator.from(customExtensionCollection, typeFactory);
163+
164+
additionalSignatures =
165+
customOperators.stream()
166+
.map(op -> FunctionMappings.s(op, op.getName()))
167+
.collect(Collectors.toList());
168+
}
169+
170+
return new ScalarFunctionConverter(
171+
extensions.scalarFunctions(), additionalSignatures, typeFactory, TypeConverter.DEFAULT);
172+
}
173+
142174
public static RelNode convert(
143175
Rel relRoot,
144176
RelOptCluster relOptCluster,
145177
Prepare.CatalogReader catalogReader,
146-
SqlParser.Config parserConfig) {
178+
SqlParser.Config parserConfig,
179+
SimpleExtension.ExtensionCollection extensions) {
147180
RelBuilder relBuilder =
148181
RelBuilder.create(
149182
Frameworks.newConfigBuilder()
@@ -154,8 +187,7 @@ public static RelNode convert(
154187
.build());
155188

156189
return relRoot.accept(
157-
new SubstraitRelNodeConverter(
158-
EXTENSION_COLLECTION, relOptCluster.getTypeFactory(), relBuilder),
190+
new SubstraitRelNodeConverter(extensions, relOptCluster.getTypeFactory(), relBuilder),
159191
Context.newContext());
160192
}
161193

isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import io.substrait.extension.SimpleExtension;
99
import io.substrait.isthmus.expression.AggregateFunctionConverter;
1010
import io.substrait.isthmus.expression.CallConverters;
11+
import io.substrait.isthmus.expression.FunctionMappings;
1112
import io.substrait.isthmus.expression.LiteralConverter;
1213
import io.substrait.isthmus.expression.RexExpressionConverter;
1314
import io.substrait.isthmus.expression.ScalarFunctionConverter;
@@ -53,6 +54,7 @@
5354
import org.apache.calcite.rex.RexBuilder;
5455
import org.apache.calcite.rex.RexFieldAccess;
5556
import org.apache.calcite.rex.RexNode;
57+
import org.apache.calcite.sql.SqlOperator;
5658
import org.apache.calcite.util.ImmutableBitSet;
5759
import org.immutables.value.Value;
5860

@@ -71,17 +73,35 @@ public class SubstraitRelVisitor extends RelNodeVisitor<Rel, RuntimeException> {
7173

7274
public SubstraitRelVisitor(
7375
RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) {
74-
this(typeFactory, extensions, FEATURES_DEFAULT);
76+
this(
77+
typeFactory,
78+
extensions,
79+
FEATURES_DEFAULT);
7580
}
7681

7782
public SubstraitRelVisitor(
7883
RelDataTypeFactory typeFactory,
7984
SimpleExtension.ExtensionCollection extensions,
8085
FeatureBoard features) {
86+
87+
SimpleExtension.ExtensionCollection customExtensionCollection =
88+
ExtensionUtils.getCustomExtensions(extensions);
89+
List<SqlOperator> customOperators =
90+
YamlToSqlOperator.from(customExtensionCollection, typeFactory);
91+
92+
List<FunctionMappings.Sig> additionalSignatures =
93+
customOperators.stream()
94+
.map(op -> FunctionMappings.s(op, op.getName()))
95+
.collect(Collectors.toList());
8196
this.typeConverter = TypeConverter.DEFAULT;
82-
ArrayList<CallConverter> converters = new ArrayList<CallConverter>();
97+
ArrayList<CallConverter> converters = new ArrayList<>();
8398
converters.addAll(CallConverters.defaults(typeConverter));
84-
converters.add(new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory));
99+
converters.add(
100+
new ScalarFunctionConverter(
101+
extensions.scalarFunctions(),
102+
additionalSignatures,
103+
typeFactory,
104+
TypeConverter.DEFAULT));
85105
converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory)));
86106
this.aggregateFunctionConverter =
87107
new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory);

isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package io.substrait.isthmus;
22

3+
import io.substrait.extension.SimpleExtension;
34
import io.substrait.relation.Rel;
5+
import java.util.List;
46
import org.apache.calcite.prepare.Prepare;
57
import org.apache.calcite.rel.RelNode;
68

@@ -10,7 +12,21 @@ public SubstraitToSql() {
1012
super(FEATURES_DEFAULT);
1113
}
1214

15+
/**
 * Creates a converter with the default feature board and the extension collection produced by
 * {@code loadExtensions(yamlFunctionFiles)}.
 *
 * @param yamlFunctionFiles identifiers handed to {@code loadExtensions}; presumably locations
 *     of YAML function definitions (TODO confirm against {@code SimpleExtension.load})
 */
public SubstraitToSql(List<String> yamlFunctionFiles) {
16+
super(FEATURES_DEFAULT, loadExtensions(yamlFunctionFiles));
17+
}
18+
19+
private static SimpleExtension.ExtensionCollection loadExtensions(
20+
List<String> yamlFunctionFiles) {
21+
SimpleExtension.ExtensionCollection allExtensions = SimpleExtension.loadDefaults();
22+
if (yamlFunctionFiles != null && !yamlFunctionFiles.isEmpty()) {
23+
allExtensions = allExtensions.merge(SimpleExtension.load(yamlFunctionFiles));
24+
}
25+
return allExtensions;
26+
}
27+
1328
public RelNode substraitRelToCalciteRel(Rel relRoot, Prepare.CatalogReader catalog) {
14-
return SubstraitRelNodeConverter.convert(relRoot, relOptCluster, catalog, parserConfig);
29+
return SubstraitRelNodeConverter.convert(
30+
relRoot, relOptCluster, catalog, parserConfig, extensionCollection);
1531
}
1632
}

0 commit comments

Comments
 (0)