Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ public O visit(Expression.StructLiteral expr, C context) throws E {
return visitFallback(expr, context);
}

/**
 * Visits a user-defined literal whose value is encoded as {@code google.protobuf.Any} by
 * delegating to {@code visitFallback}.
 *
 * @param expr the user-defined literal being visited
 * @param context the visitation context, forwarded unchanged
 * @return whatever {@code visitFallback} returns for this expression
 * @throws E if {@code visitFallback} throws
 */
@Override
public O visit(Expression.UserDefinedAny expr, C context) throws E {
return visitFallback(expr, context);
}

/**
 * Visits a user-defined literal whose value is encoded as a struct of literals by delegating to
 * {@code visitFallback}.
 *
 * @param expr the user-defined literal being visited
 * @param context the visitation context, forwarded unchanged
 * @return whatever {@code visitFallback} returns for this expression
 * @throws E if {@code visitFallback} throws
 */
@Override
public O visit(Expression.UserDefinedStruct expr, C context) throws E {
return visitFallback(expr, context);
}

@Override
public O visit(Expression.Switch expr, C context) throws E {
return visitFallback(expr, context);
Expand Down
87 changes: 81 additions & 6 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -662,21 +662,96 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}
}

/**
 * Base interface for user-defined literals.
 *
 * <p>User-defined literals can be encoded in one of two ways as per the Substrait spec:
 *
 * <ul>
 *   <li>As {@code google.protobuf.Any} - see {@link UserDefinedAny}
 *   <li>As {@code Literal.Struct} - see {@link UserDefinedStruct}
 * </ul>
 *
 * @see UserDefinedAny
 * @see UserDefinedStruct
 */
interface UserDefinedLiteral extends Literal {
/**
 * Returns the URN of the user-defined type this literal belongs to.
 *
 * @return the type's URN
 */
String urn();

/**
 * Returns the name of the user-defined type this literal belongs to.
 *
 * @return the type's name
 */
String name();

/**
 * Returns the type parameters of the user-defined type.
 *
 * @return the type parameters; may be an empty list
 */
List<io.substrait.proto.Type.Parameter> typeParameters();
}

/**
* User-defined literal with value encoded as {@code google.protobuf.Any}.
*
* <p>This encoding allows for arbitrary binary data to be stored in the literal value.
*/
@Value.Immutable
abstract class UserDefinedLiteral implements Literal {
public abstract ByteString value();
abstract class UserDefinedAny implements UserDefinedLiteral {
@Override
public abstract String urn();

@Override
public abstract String name();

@Override
public abstract List<io.substrait.proto.Type.Parameter> typeParameters();

public abstract com.google.protobuf.Any value();

@Override
public Type.UserDefined getType() {
return Type.UserDefined.builder()
.nullable(nullable())
.urn(urn())
.name(name())
.typeParameters(typeParameters())
.build();
}

public static ImmutableExpression.UserDefinedAny.Builder builder() {
return ImmutableExpression.UserDefinedAny.builder();
}

@Override
public <R, C extends VisitationContext, E extends Throwable> R accept(
ExpressionVisitor<R, C, E> visitor, C context) throws E {
return visitor.visit(this, context);
}
}

/**
* User-defined literal with value encoded as {@code Literal.Struct}.
*
* <p>This encoding uses a structured list of fields to represent the literal value.
*/
@Value.Immutable
abstract class UserDefinedStruct implements UserDefinedLiteral {
@Override
public abstract String urn();

@Override
public abstract String name();

@Override
public Type getType() {
return Type.withNullability(nullable()).userDefined(urn(), name());
public abstract List<io.substrait.proto.Type.Parameter> typeParameters();

/** Returns the literal's value as a list of field literals. */
public abstract List<Literal> fields();

/**
 * Builds the {@link Type.UserDefined} describing this literal from its nullability, URN, name
 * and type parameters.
 *
 * @return the user-defined type of this literal
 */
@Override
public Type.UserDefined getType() {
  final boolean isNullable = nullable();
  final String typeUrn = urn();
  final String typeName = name();
  final List<io.substrait.proto.Type.Parameter> parameters = typeParameters();

  return Type.UserDefined.builder()
      .nullable(isNullable)
      .urn(typeUrn)
      .name(typeName)
      .typeParameters(parameters)
      .build();
}

public static ImmutableExpression.UserDefinedLiteral.Builder builder() {
return ImmutableExpression.UserDefinedLiteral.builder();
public static ImmutableExpression.UserDefinedStruct.Builder builder() {
return ImmutableExpression.UserDefinedStruct.builder();
}

@Override
Expand Down
46 changes: 42 additions & 4 deletions core/src/main/java/io/substrait/expression/ExpressionCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -286,13 +286,51 @@ public static Expression.StructLiteral struct(
return Expression.StructLiteral.builder().nullable(nullable).addAllFields(values).build();
}

public static Expression.UserDefinedLiteral userDefinedLiteral(
boolean nullable, String urn, String name, Any value) {
return Expression.UserDefinedLiteral.builder()
/**
* Create a UserDefinedAny with google.protobuf.Any representation.
*
* @param nullable whether the literal is nullable
* @param urn the URN of the user-defined type
* @param name the name of the user-defined type
* @param typeParameters the type parameters for the user-defined type (can be empty list)
* @param value the value, encoded as google.protobuf.Any
*/
public static Expression.UserDefinedAny userDefinedLiteralAny(
boolean nullable,
String urn,
String name,
java.util.List<io.substrait.proto.Type.Parameter> typeParameters,
Any value) {
return Expression.UserDefinedAny.builder()
.nullable(nullable)
.urn(urn)
.name(name)
.addAllTypeParameters(typeParameters)
.value(value)
.build();
}

/**
 * Create a UserDefinedStruct with Struct representation.
 *
 * @param nullable whether the literal is nullable
 * @param urn the URN of the user-defined type
 * @param name the name of the user-defined type
 * @param typeParameters the type parameters for the user-defined type (can be empty list)
 * @param fields the fields, as a list of Literal values
 * @return the constructed literal
 */
public static Expression.UserDefinedStruct userDefinedLiteralStruct(
    boolean nullable,
    String urn,
    String name,
    java.util.List<io.substrait.proto.Type.Parameter> typeParameters,
    java.util.List<Expression.Literal> fields) {
  return Expression.UserDefinedStruct.builder()
      .nullable(nullable)
      .urn(urn)
      .name(name)
      .addAllTypeParameters(typeParameters)
      .addAllFields(fields)
      .build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr

R visit(Expression.StructLiteral expr, C context) throws E;

/**
 * Visits a user-defined literal whose value is encoded as {@code google.protobuf.Any}.
 *
 * @param expr the literal being visited
 * @param context the visitation context
 * @return the visitor's result for this literal
 * @throws E if visiting fails
 */
R visit(Expression.UserDefinedAny expr, C context) throws E;

/**
 * Visits a user-defined literal whose value is encoded as a struct of literals.
 *
 * @param expr the literal being visited
 * @param context the visitation context
 * @return the visitor's result for this literal
 * @throws E if visiting fails
 */
R visit(Expression.UserDefinedStruct expr, C context) throws E;

R visit(Expression.Switch expr, C context) throws E;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package io.substrait.expression.proto;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import io.substrait.expression.ExpressionVisitor;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
Expand Down Expand Up @@ -359,21 +357,40 @@ public Expression visit(

@Override
public Expression visit(
io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) {
io.substrait.expression.Expression.UserDefinedAny expr, EmptyVisitationContext context) {
int typeReference =
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
return lit(
bldr -> {
try {
bldr.setNullable(expr.nullable())
.setUserDefined(
Expression.Literal.UserDefined.newBuilder()
.setTypeReference(typeReference)
.setValue(Any.parseFrom(expr.value())))
.build();
} catch (InvalidProtocolBufferException e) {
throw new IllegalStateException(e);
Expression.Literal.UserDefined.Builder userDefinedBuilder =
Expression.Literal.UserDefined.newBuilder()
.setTypeReference(typeReference)
.addAllTypeParameters(expr.typeParameters())
.setValue(expr.value());

bldr.setNullable(expr.nullable()).setUserDefined(userDefinedBuilder).build();
});
}

/**
 * Converts a user-defined literal carrying a struct of field literals into its proto form.
 *
 * @param expr the literal to convert
 * @param context the (empty) visitation context
 * @return the proto expression wrapping the user-defined literal
 */
@Override
public Expression visit(
    io.substrait.expression.Expression.UserDefinedStruct expr, EmptyVisitationContext context) {
  // Register the type with the extension collector and use the returned reference in the proto.
  int typeReference =
      extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
  return lit(
      literal -> {
        Expression.Literal.UserDefined.Builder payload =
            Expression.Literal.UserDefined.newBuilder();
        payload.setTypeReference(typeReference);
        payload.addAllTypeParameters(expr.typeParameters());

        // Convert each field literal and collect them into the struct payload.
        Expression.Literal.Struct.Builder struct = Expression.Literal.Struct.newBuilder();
        for (io.substrait.expression.Expression.Literal fieldLiteral : expr.fields()) {
          struct.addFields(toLiteral(fieldLiteral));
        }
        payload.setStruct(struct.build());

        literal.setNullable(expr.nullable());
        literal.setUserDefined(payload);
        literal.build();
      });
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,10 +492,36 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
{
  io.substrait.proto.Expression.Literal.UserDefined userDefinedLiteral =
      literal.getUserDefined();

  // Resolve the proto type reference back to the extension type's URN and name.
  SimpleExtension.Type type =
      lookup.getType(userDefinedLiteral.getTypeReference(), extensions);
  String urn = type.urn();
  String name = type.name();

  // The payload is either an opaque Any ('value') or a struct of literals ('struct').
  switch (userDefinedLiteral.getValCase()) {
    case VALUE:
      return ExpressionCreator.userDefinedLiteralAny(
          literal.getNullable(),
          urn,
          name,
          userDefinedLiteral.getTypeParametersList(),
          userDefinedLiteral.getValue());
    case STRUCT:
      return ExpressionCreator.userDefinedLiteralStruct(
          literal.getNullable(),
          urn,
          name,
          userDefinedLiteral.getTypeParametersList(),
          userDefinedLiteral.getStruct().getFieldsList().stream()
              .map(this::from)
              .collect(Collectors.toList()));
    case VAL_NOT_SET:
      throw new IllegalStateException(
          "UserDefined literal has no value (neither 'value' nor 'struct' is set)");
    default:
      throw new IllegalStateException(
          "Unknown UserDefined literal value case: " + userDefinedLiteral.getValCase());
  }
}
default:
throw new IllegalStateException("Unexpected value: " + literal.getLiteralTypeCase());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class DefaultExtensionCatalog {
"extension:io.substrait:functions_rounding_decimal";
public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set";
public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string";
public static final String EXTENSION_TYPES = "extension:io.substrait:extension_types";

public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION =
loadDefaultCollection();
Expand All @@ -44,6 +45,8 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() {
.map(c -> String.format("/functions_%s.yaml", c))
.collect(Collectors.toList());

defaultFiles.add("/extension_types.yaml");

return SimpleExtension.load(defaultFiles);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,15 @@ public Optional<Expression> visit(Expression.StructLiteral expr, EmptyVisitation
return visitLiteral(expr);
}

/**
 * Handles a user-defined literal encoded as {@code google.protobuf.Any} by delegating to {@code
 * visitLiteral}.
 *
 * @param expr the literal being visited
 * @param context the (empty) visitation context; not used here
 * @return whatever {@code visitLiteral} returns for this literal
 * @throws E if {@code visitLiteral} throws
 */
@Override
public Optional<Expression> visit(Expression.UserDefinedAny expr, EmptyVisitationContext context)
throws E {
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(
Expression.UserDefinedLiteral expr, EmptyVisitationContext context) throws E {
Expression.UserDefinedStruct expr, EmptyVisitationContext context) throws E {
return visitLiteral(expr);
}

Expand Down
17 changes: 17 additions & 0 deletions core/src/main/java/io/substrait/type/Type.java
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,23 @@ abstract class UserDefined implements Type {

public abstract String name();

/**
 * Returns the type parameters for this user-defined type.
 *
 * <p>Type parameters are used to represent parameterized/generic types, such as {@code
 * List<i32>} or {@code Map<String, i64>}. Each parameter in the list represents a type argument
 * that specializes the generic user-defined type.
 *
 * <p>For example, a user-defined type {@code MyList} parameterized by {@code i32} would have
 * one type parameter containing the {@code i32} type definition.
 *
 * <p>The attribute is annotated {@code @Value.Default}; the method body supplies the default
 * value, an empty list.
 *
 * @return a list of type parameters, or an empty list if this type is not parameterized
 */
@Value.Default
public java.util.List<io.substrait.proto.Type.Parameter> typeParameters() {
return java.util.Collections.emptyList();
}

/**
 * Creates a builder for {@code UserDefined} type instances.
 *
 * @return a new builder obtained from {@code ImmutableType.UserDefined}
 */
public static ImmutableType.UserDefined.Builder builder() {
return ImmutableType.UserDefined.builder();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,6 @@ public final T visit(final Type.Map expr) {
/**
 * Converts a user-defined type, including its type parameters, to the target representation.
 *
 * @param expr the user-defined type to convert
 * @return the converted type produced by the type container for {@code expr}
 */
public final T visit(final Type.UserDefined expr) {
  // Register the type with the extension collector and use the returned reference.
  int ref =
      extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
  return typeContainer(expr).userDefined(ref, expr.typeParameters());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ public final T struct(T... types) {

/**
 * Creates a user-defined type from a type reference.
 *
 * @param ref the reference identifying the user-defined type
 * @return the user-defined type in the target representation
 */
public abstract T userDefined(int ref);

/**
 * Creates a user-defined type from a type reference and its type parameters.
 *
 * @param ref the reference identifying the user-defined type
 * @param typeParameters the type parameters of the user-defined type
 * @return the user-defined type in the target representation
 */
public abstract T userDefined(
int ref, java.util.List<io.substrait.proto.Type.Parameter> typeParameters);

protected abstract T wrap(Object o);

protected abstract I i(int integerValue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,13 @@ public ParameterizedType userDefined(int ref) {
"User defined types are not supported in Parameterized Types for now");
}

/**
 * Not supported: user-defined types cannot currently be used in parameterized types.
 *
 * @param ref the reference identifying the user-defined type (unused)
 * @param typeParameters the type parameters of the user-defined type (unused)
 * @return never returns normally
 * @throws UnsupportedOperationException always
 */
@Override
public ParameterizedType userDefined(
int ref, java.util.List<io.substrait.proto.Type.Parameter> typeParameters) {
throw new UnsupportedOperationException(
"User defined types are not supported in Parameterized Types for now");
}

@Override
protected ParameterizedType wrap(final Object o) {
ParameterizedType.Builder bldr = ParameterizedType.newBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,13 @@ public Type from(io.substrait.proto.Type type) {
{
  io.substrait.proto.Type.UserDefined userDefined = type.getUserDefined();
  // Resolve the proto type reference back to the extension type's URN and name.
  SimpleExtension.Type t = lookup.getType(userDefined.getTypeReference(), extensions);
  boolean nullable = isNullable(userDefined.getNullability());
  // Build the type directly so the proto's type parameters are carried over.
  return io.substrait.type.Type.UserDefined.builder()
      .nullable(nullable)
      .urn(t.urn())
      .name(t.name())
      .typeParameters(userDefined.getTypeParametersList())
      .build();
}
case USER_DEFINED_TYPE_REFERENCE:
throw new UnsupportedOperationException("Unsupported user defined reference: " + type);
Expand Down
Loading
Loading