Skip to content

Commit b4fbb07

Browse files
committed
feat: add structured UDT literal support with dual encoding
Support both opaque (google.protobuf.Any) and structured (Literal.Struct) encodings for user-defined type (UDT) literals, as defined by the Substrait spec.
- Split UserDefinedLiteral into UserDefinedAny and UserDefinedStruct
- Move type parameters to the interface level to support parameterized types
- Add the ExtensionCollector.getExtensionCollection() method
- Add full Calcite integration, using the REINTERPRET pattern for Any-based UDTs
- Add Scala visitor methods and comprehensive documentation
- Add comprehensive test coverage, including roundtrip tests
1 parent 22448d1 commit b4fbb07

28 files changed

+1245
-77
lines changed

core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,16 @@ public O visit(Expression.StructLiteral expr, C context) throws E {
151151
return visitFallback(expr, context);
152152
}
153153

154+
/**
 * Default handling for a {@code UserDefinedAny} literal: delegates to {@code visitFallback} with
 * the same expression and context.
 */
@Override
155+
public O visit(Expression.UserDefinedAny expr, C context) throws E {
156+
return visitFallback(expr, context);
157+
}
158+
159+
/**
 * Default handling for a {@code UserDefinedStruct} literal: delegates to {@code visitFallback}
 * with the same expression and context.
 */
@Override
160+
public O visit(Expression.UserDefinedStruct expr, C context) throws E {
161+
return visitFallback(expr, context);
162+
}
163+
154164
@Override
155165
public O visit(Expression.Switch expr, C context) throws E {
156166
return visitFallback(expr, context);

core/src/main/java/io/substrait/expression/Expression.java

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -662,21 +662,96 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
662662
}
663663
}
664664

665+
/**
666+
* Base interface for user-defined literals.
667+
*
668+
* <p>User-defined literals can be encoded in one of two ways as per the Substrait spec:
669+
*
670+
* <ul>
671+
* <li>As {@code google.protobuf.Any} - see {@link UserDefinedAny}
672+
* <li>As {@code Literal.Struct} - see {@link UserDefinedStruct}
673+
* </ul>
674+
*
675+
* @see UserDefinedAny
676+
* @see UserDefinedStruct
677+
*/
678+
interface UserDefinedLiteral extends Literal {
679+
/** The URN of the user-defined type this literal belongs to. */
String urn();
680+
681+
/** The name of the user-defined type this literal belongs to. */
String name();
682+
683+
/**
 * The type parameters of the user-defined type. Both implementations pass this list to the
 * {@code Type.UserDefined} builder in {@code getType()}.
 */
List<io.substrait.proto.Type.Parameter> typeParameters();
684+
}
685+
686+
/**
687+
* User-defined literal with value encoded as {@code google.protobuf.Any}.
688+
*
689+
* <p>This encoding allows for arbitrary binary data to be stored in the literal value.
690+
*/
665691
@Value.Immutable
666-
abstract class UserDefinedLiteral implements Literal {
667-
public abstract ByteString value();
692+
abstract class UserDefinedAny implements UserDefinedLiteral {
693+
@Override
694+
public abstract String urn();
695+
696+
@Override
697+
public abstract String name();
698+
699+
@Override
700+
public abstract List<io.substrait.proto.Type.Parameter> typeParameters();
701+
702+
public abstract com.google.protobuf.Any value();
703+
704+
@Override
705+
public Type.UserDefined getType() {
706+
return Type.UserDefined.builder()
707+
.nullable(nullable())
708+
.urn(urn())
709+
.name(name())
710+
.typeParameters(typeParameters())
711+
.build();
712+
}
713+
714+
public static ImmutableExpression.UserDefinedAny.Builder builder() {
715+
return ImmutableExpression.UserDefinedAny.builder();
716+
}
668717

718+
@Override
719+
public <R, C extends VisitationContext, E extends Throwable> R accept(
720+
ExpressionVisitor<R, C, E> visitor, C context) throws E {
721+
return visitor.visit(this, context);
722+
}
723+
}
724+
725+
/**
726+
* User-defined literal with value encoded as {@code Literal.Struct}.
727+
*
728+
* <p>This encoding uses a structured list of fields to represent the literal value.
729+
*/
730+
@Value.Immutable
731+
abstract class UserDefinedStruct implements UserDefinedLiteral {
732+
@Override
669733
public abstract String urn();
670734

735+
@Override
671736
public abstract String name();
672737

673738
@Override
674-
public Type getType() {
675-
return Type.withNullability(nullable()).userDefined(urn(), name());
739+
public abstract List<io.substrait.proto.Type.Parameter> typeParameters();
740+
741+
public abstract List<Literal> fields();
742+
743+
@Override
744+
public Type.UserDefined getType() {
745+
return Type.UserDefined.builder()
746+
.nullable(nullable())
747+
.urn(urn())
748+
.name(name())
749+
.typeParameters(typeParameters())
750+
.build();
676751
}
677752

678-
public static ImmutableExpression.UserDefinedLiteral.Builder builder() {
679-
return ImmutableExpression.UserDefinedLiteral.builder();
753+
public static ImmutableExpression.UserDefinedStruct.Builder builder() {
754+
return ImmutableExpression.UserDefinedStruct.builder();
680755
}
681756

682757
@Override

core/src/main/java/io/substrait/expression/ExpressionCreator.java

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,13 +286,87 @@ public static Expression.StructLiteral struct(
286286
return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build();
287287
}
288288

289-
public static Expression.UserDefinedLiteral userDefinedLiteral(
289+
/**
290+
* Create a UserDefinedAny with google.protobuf.Any representation.
291+
*
292+
* @param nullable whether the literal is nullable
293+
* @param urn the URN of the user-defined type
294+
* @param name the name of the user-defined type
295+
* @param value the value, encoded as google.protobuf.Any
296+
*/
297+
public static Expression.UserDefinedAny userDefinedLiteralAny(
290298
boolean nullable, String urn, String name, Any value) {
291-
return Expression.UserDefinedLiteral.builder()
299+
return Expression.UserDefinedAny.builder()
300+
.nullable(nullable)
301+
.urn(urn)
302+
.name(name)
303+
.value(value)
304+
.build();
305+
}
306+
307+
/**
308+
* Create a UserDefinedAny with google.protobuf.Any representation and type parameters.
309+
*
310+
* @param nullable whether the literal is nullable
311+
* @param urn the URN of the user-defined type
312+
* @param name the name of the user-defined type
313+
* @param typeParameters the type parameters for the user-defined type
314+
* @param value the value, encoded as google.protobuf.Any
* @return a {@code UserDefinedAny} built from the given nullability, URN, name, type parameters
*     and value
315+
*/
316+
public static Expression.UserDefinedAny userDefinedLiteralAny(
317+
boolean nullable,
318+
String urn,
319+
String name,
320+
java.util.List<io.substrait.proto.Type.Parameter> typeParameters,
321+
Any value) {
322+
return Expression.UserDefinedAny.builder()
323+
.nullable(nullable)
324+
.urn(urn)
325+
.name(name)
326+
.addAllTypeParameters(typeParameters)
327+
.value(value)
328+
.build();
329+
}
330+
331+
/**
332+
* Create a UserDefinedStruct with Struct representation.
333+
*
* <p>This overload does not set any type parameters on the builder.
*
334+
* @param nullable whether the literal is nullable
335+
* @param urn the URN of the user-defined type
336+
* @param name the name of the user-defined type
337+
* @param fields the fields, as a list of Literal values
* @return a {@code UserDefinedStruct} built from the given nullability, URN, name and fields
338+
*/
339+
public static Expression.UserDefinedStruct userDefinedLiteralStruct(
340+
boolean nullable, String urn, String name, java.util.List<Expression.Literal> fields) {
341+
return Expression.UserDefinedStruct.builder()
342+
.nullable(nullable)
343+
.urn(urn)
344+
.name(name)
345+
.addAllFields(fields)
346+
.build();
347+
}
348+
349+
/**
350+
* Create a UserDefinedStruct with Struct representation and type parameters.
351+
*
352+
* @param nullable whether the literal is nullable
353+
* @param urn the URN of the user-defined type
354+
* @param name the name of the user-defined type
355+
* @param typeParameters the type parameters for the user-defined type
356+
* @param fields the fields, as a list of Literal values
357+
*/
358+
public static Expression.UserDefinedStruct userDefinedLiteralStruct(
359+
boolean nullable,
360+
String urn,
361+
String name,
362+
java.util.List<io.substrait.proto.Type.Parameter> typeParameters,
363+
java.util.List<Expression.Literal> fields) {
364+
return Expression.UserDefinedStruct.builder()
292365
.nullable(nullable)
293366
.urn(urn)
294367
.name(name)
295-
.value(value.toByteString())
368+
.addAllTypeParameters(typeParameters)
369+
.addAllFields(fields)
296370
.build();
297371
}
298372

core/src/main/java/io/substrait/expression/ExpressionVisitor.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr
6262

6363
R visit(Expression.StructLiteral expr, C context) throws E;
6464

65-
R visit(Expression.UserDefinedLiteral expr, C context) throws E;
65+
R visit(Expression.UserDefinedAny expr, C context) throws E;
66+
67+
R visit(Expression.UserDefinedStruct expr, C context) throws E;
6668

6769
R visit(Expression.Switch expr, C context) throws E;
6870

core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package io.substrait.expression.proto;
22

3-
import com.google.protobuf.Any;
4-
import com.google.protobuf.InvalidProtocolBufferException;
53
import io.substrait.expression.ExpressionVisitor;
64
import io.substrait.expression.FieldReference;
75
import io.substrait.expression.FunctionArg;
@@ -359,21 +357,40 @@ public Expression visit(
359357

360358
@Override
361359
public Expression visit(
362-
io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) {
360+
io.substrait.expression.Expression.UserDefinedAny expr, EmptyVisitationContext context) {
363361
int typeReference =
364362
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
365363
return lit(
366364
bldr -> {
367-
try {
368-
bldr.setNullable(expr.nullable())
369-
.setUserDefined(
370-
Expression.Literal.UserDefined.newBuilder()
371-
.setTypeReference(typeReference)
372-
.setValue(Any.parseFrom(expr.value())))
373-
.build();
374-
} catch (InvalidProtocolBufferException e) {
375-
throw new IllegalStateException(e);
365+
Expression.Literal.UserDefined.Builder userDefinedBuilder =
366+
Expression.Literal.UserDefined.newBuilder()
367+
.setTypeReference(typeReference)
368+
.addAllTypeParameters(expr.typeParameters())
369+
.setValue(expr.value());
370+
371+
bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build();
372+
});
373+
}
374+
375+
@Override
376+
public Expression visit(
377+
io.substrait.expression.Expression.UserDefinedStruct expr, EmptyVisitationContext context) {
378+
int typeReference =
379+
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
380+
return lit(
381+
bldr -> {
382+
Expression.Literal.Struct.Builder structBuilder = Expression.Literal.Struct.newBuilder();
383+
for (io.substrait.expression.Expression.Literal field : expr.fields()) {
384+
structBuilder.addFields(toLiteral(field));
376385
}
386+
387+
Expression.Literal.UserDefined.Builder userDefinedBuilder =
388+
Expression.Literal.UserDefined.newBuilder()
389+
.setTypeReference(typeReference)
390+
.addAllTypeParameters(expr.typeParameters())
391+
.setStruct(structBuilder.build());
392+
393+
bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build();
377394
});
378395
}
379396

core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,10 +492,36 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
492492
{
493493
io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral =
494494
literal.getUserDefined();
495+
495496
SimpleExtension.Type type =
496497
lookup.getType(userDefinedLiteral.getTypeReference(), extensions);
497-
return ExpressionCreator.userDefinedLiteral(
498-
literal.getNullable(), type.urn(), type.name(), userDefinedLiteral.getValue());
498+
String urn = type.urn();
499+
String name = type.name();
500+
501+
switch (userDefinedLiteral.getValCase()) {
502+
case VALUE:
503+
return ExpressionCreator.userDefinedLiteralAny(
504+
literal.getNullable(),
505+
urn,
506+
name,
507+
userDefinedLiteral.getTypeParametersList(),
508+
userDefinedLiteral.getValue());
509+
case STRUCT:
510+
return ExpressionCreator.userDefinedLiteralStruct(
511+
literal.getNullable(),
512+
urn,
513+
name,
514+
userDefinedLiteral.getTypeParametersList(),
515+
userDefinedLiteral.getStruct().getFieldsList().stream()
516+
.map(this::from)
517+
.collect(Collectors.toList()));
518+
case VAL_NOT_SET:
519+
throw new IllegalStateException(
520+
"UserDefined literal has no value (neither 'value' nor 'struct' is set)");
521+
default:
522+
throw new IllegalStateException(
523+
"Unknown UserDefined literal value case: " + userDefinedLiteral.getValCase());
524+
}
499525
}
500526
default:
501527
throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase());

core/src/main/java/io/substrait/extension/ExtensionCollector.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,77 @@ public int getTypeReference(SimpleExtension.TypeAnchor typeAnchor) {
6363
return counter;
6464
}
6565

66+
/**
67+
* Returns an ExtensionCollection containing only the types and functions that have been tracked
68+
* by this collector. This provides a minimal collection with exactly what was used during
69+
* serialization.
70+
*
71+
* <p>This collection contains:
72+
*
73+
* <ul>
74+
* <li>Only the types that were referenced via {@link #getTypeReference}
75+
* <li>Only the functions that were referenced via {@link #getFunctionReference}
76+
* <li>URI/URN mappings for only the used extension URNs
77+
* </ul>
78+
*
79+
* <p>Types from the catalog are resolved, while custom UserDefined types (not in the catalog) are
80+
* created via {@link SimpleExtension.Type#of(String, String)}.
81+
*
82+
* @return an ExtensionCollection with only the used types, functions, and URI/URN mappings
83+
*/
84+
public SimpleExtension.ExtensionCollection getExtensionCollection() {
85+
// Accumulators for the entries that end up in the returned collection.
java.util.List<SimpleExtension.Type> types = new ArrayList<>();
86+
java.util.List<SimpleExtension.ScalarFunctionVariant> scalarFunctions = new ArrayList<>();
87+
java.util.List<SimpleExtension.AggregateFunctionVariant> aggregateFunctions = new ArrayList<>();
88+
java.util.List<SimpleExtension.WindowFunctionVariant> windowFunctions = new ArrayList<>();
89+
90+
// URNs referenced by any tracked type or function; used at the end to restrict the URI/URN map.
java.util.Set<String> usedUrns = new java.util.HashSet<>();
91+
92+
// Types: take the catalog definition when the catalog has the anchor, otherwise synthesize a
// Type from the anchor's URN and key.
for (Map.Entry<Integer, SimpleExtension.TypeAnchor> entry : typeMap.forwardEntrySet()) {
93+
SimpleExtension.TypeAnchor anchor = entry.getValue();
94+
usedUrns.add(anchor.urn());
95+
if (extensionCollection.hasType(anchor)) {
96+
types.add(extensionCollection.getType(anchor));
97+
} else {
98+
types.add(SimpleExtension.Type.of(anchor.urn(), anchor.key()));
99+
}
100+
}
101+
102+
// Functions: each tracked anchor is looked up as scalar, then aggregate, then window. Because
// this is an else-if chain, an anchor is recorded under the first kind that matches only.
for (Map.Entry<Integer, SimpleExtension.FunctionAnchor> entry : funcMap.forwardEntrySet()) {
103+
SimpleExtension.FunctionAnchor anchor = entry.getValue();
104+
usedUrns.add(anchor.urn());
105+
106+
if (extensionCollection.hasScalarFunction(anchor)) {
107+
scalarFunctions.add(extensionCollection.getScalarFunction(anchor));
108+
} else if (extensionCollection.hasAggregateFunction(anchor)) {
109+
aggregateFunctions.add(extensionCollection.getAggregateFunction(anchor));
110+
} else if (extensionCollection.hasWindowFunction(anchor)) {
111+
windowFunctions.add(extensionCollection.getWindowFunction(anchor));
112+
} else {
// Unlike types, a tracked function that the catalog does not know is not synthesized; it is
// reported as an error.
113+
throw new IllegalArgumentException(
114+
String.format(
115+
"Function %s::%s was tracked but not found in catalog as scalar, aggregate, or window function",
116+
anchor.urn(), anchor.key()));
117+
}
118+
}
119+
120+
// Copy URI/URN mappings for the used URNs only; URNs for which getUriFromUrn returns null are
// left out of the map.
BidiMap<String, String> uriUrnMap = new BidiMap<>();
121+
for (String urn : usedUrns) {
122+
String uri = extensionCollection.getUriFromUrn(urn);
123+
if (uri != null) {
124+
uriUrnMap.put(uri, urn);
125+
}
126+
}
127+
128+
return SimpleExtension.ExtensionCollection.builder()
129+
.addAllTypes(types)
130+
.addAllScalarFunctions(scalarFunctions)
131+
.addAllAggregateFunctions(aggregateFunctions)
132+
.addAllWindowFunctions(windowFunctions)
133+
.uriUrnMap(uriUrnMap)
134+
.build();
135+
}
136+
66137
public void addExtensionsToPlan(Plan.Builder builder) {
67138
SimpleExtensions simpleExtensions = getExtensions();
68139

0 commit comments

Comments
 (0)