Skip to content

Commit 3eaa9cf

Browse files
committed
refactor importLiteral into importAttribute
1 parent ac24aac commit 3eaa9cf

File tree

2 files changed

+40
-34
lines changed

2 files changed

+40
-34
lines changed

lib/Target/SubstraitPB/Import.cpp

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -640,83 +640,84 @@ static mlir::FailureOr<JoinOp> importJoinRel(ImplicitLocOpBuilder builder,
640640
return joinOp;
641641
}
642642

643-
static mlir::FailureOr<LiteralOp>
644-
importLiteral(ImplicitLocOpBuilder builder,
645-
const Expression::Literal &message) {
643+
static mlir::FailureOr<mlir::Attribute>
644+
importAttribute(ImplicitLocOpBuilder builder,
645+
const Expression::Literal &message) {
646646
MLIRContext *context = builder.getContext();
647647
Location loc = builder.getLoc();
648648

649649
Expression::Literal::LiteralTypeCase literalType =
650650
message.literal_type_case();
651+
651652
switch (literalType) {
652653
case Expression::Literal::LiteralTypeCase::kBoolean: {
653654
auto attr = IntegerAttr::get(
654655
IntegerType::get(context, 1, IntegerType::Signed), message.boolean());
655-
return builder.create<LiteralOp>(attr);
656+
return attr;
656657
}
657658
case Expression::Literal::LiteralTypeCase::kI8: {
658659
auto attr = IntegerAttr::get(
659660
IntegerType::get(context, 8, IntegerType::Signed), message.i8());
660-
return builder.create<LiteralOp>(attr);
661+
return attr;
661662
}
662663
case Expression::Literal::LiteralTypeCase::kI16: {
663664
auto attr = IntegerAttr::get(
664665
IntegerType::get(context, 16, IntegerType::Signed), message.i16());
665-
return builder.create<LiteralOp>(attr);
666+
return attr;
666667
}
667668
case Expression::Literal::LiteralTypeCase::kI32: {
668669
auto attr = IntegerAttr::get(
669670
IntegerType::get(context, 32, IntegerType::Signed), message.i32());
670-
return builder.create<LiteralOp>(attr);
671+
return attr;
671672
}
672673
case Expression::Literal::LiteralTypeCase::kI64: {
673674
auto attr = IntegerAttr::get(
674675
IntegerType::get(context, 64, IntegerType::Signed), message.i64());
675-
return builder.create<LiteralOp>(attr);
676+
return attr;
676677
}
677678
case Expression::Literal::LiteralTypeCase::kFp32: {
678679
auto attr = FloatAttr::get(Float32Type::get(context), message.fp32());
679-
return builder.create<LiteralOp>(attr);
680+
return attr;
680681
}
681682
case Expression::Literal::LiteralTypeCase::kFp64: {
682683
auto attr = FloatAttr::get(Float64Type::get(context), message.fp64());
683-
return builder.create<LiteralOp>(attr);
684+
return attr;
684685
}
685686
case Expression::Literal::LiteralTypeCase::kString: {
686687
auto attr = StringAttr::get(message.string(), StringType::get(context));
687-
return builder.create<LiteralOp>(attr);
688+
return attr;
688689
}
689690
case Expression::Literal::LiteralTypeCase::kBinary: {
690691
auto attr = StringAttr::get(message.binary(), BinaryType::get(context));
691-
return builder.create<LiteralOp>(attr);
692+
return attr;
692693
}
693694
case Expression::Literal::LiteralTypeCase::kTimestamp: {
694695
auto attr = TimestampAttr::get(context, message.timestamp());
695-
return builder.create<LiteralOp>(attr);
696+
return attr;
696697
}
697698
case Expression::Literal::LiteralTypeCase::kTimestampTz: {
698699
auto attr = TimestampTzAttr::get(context, message.timestamp_tz());
699-
return builder.create<LiteralOp>(attr);
700+
return attr;
700701
}
701702
case Expression::Literal::LiteralTypeCase::kDate: {
702703
auto attr = DateAttr::get(context, message.date());
703-
return builder.create<LiteralOp>(attr);
704+
return attr;
704705
}
705706
case Expression::Literal::LiteralTypeCase::kTime: {
706707
auto attr = TimeAttr::get(context, message.time());
707-
return builder.create<LiteralOp>(attr);
708+
return attr;
708709
}
709710
case Expression::Literal::LiteralTypeCase::kIntervalYearToMonth: {
710711
auto attr = IntervalYearMonthAttr::get(
711712
context, message.interval_year_to_month().years(),
712713
message.interval_year_to_month().months());
713-
return builder.create<LiteralOp>(attr);
714+
return attr;
714715
}
715716
case Expression::Literal::LiteralTypeCase::kIntervalDayToSecond: {
716717
auto attr = IntervalDaySecondAttr::get(
717718
context, message.interval_day_to_second().days(),
718719
message.interval_day_to_second().seconds());
719-
return builder.create<LiteralOp>(attr);
720+
return attr;
720721
}
721722
case Expression::Literal::LiteralTypeCase::kUuid: {
722723
APInt var(128, 0);
@@ -725,29 +726,29 @@ importLiteral(ImplicitLocOpBuilder builder,
725726
IntegerAttr integer_attr =
726727
IntegerAttr::get(IntegerType::get(context, 128), var);
727728
auto attr = UUIDAttr::get(context, integer_attr);
728-
return builder.create<LiteralOp>(attr);
729+
return attr;
729730
}
730731
case Expression::Literal::LiteralTypeCase::kFixedChar: {
731732
StringAttr stringAttr = StringAttr::get(context, message.fixed_char());
732733
FixedCharType fixedCharType =
733734
FixedCharType::get(context, message.fixed_char().size());
734735
auto attr = FixedCharAttr::get(context, stringAttr, fixedCharType);
735-
return builder.create<LiteralOp>(attr);
736+
return attr;
736737
}
737738
case Expression::Literal::LiteralTypeCase::kVarChar: {
738739
StringAttr stringAttr =
739740
StringAttr::get(context, message.var_char().value());
740741
VarCharType varCharType =
741742
VarCharType::get(context, message.var_char().value().size());
742743
auto attr = VarCharAttr::get(context, stringAttr, varCharType);
743-
return builder.create<LiteralOp>(attr);
744+
return attr;
744745
}
745746
case Expression::Literal::LiteralTypeCase::kFixedBinary: {
746747
StringAttr stringAttr = StringAttr::get(context, message.fixed_binary());
747748
FixedBinaryType fixedBinaryType =
748749
FixedBinaryType::get(context, message.fixed_binary().size());
749750
auto attr = FixedBinaryAttr::get(context, stringAttr, fixedBinaryType);
750-
return builder.create<LiteralOp>(attr);
751+
return attr;
751752
}
752753
case Expression::Literal::LiteralTypeCase::kDecimal: {
753754
APInt var(128, 0);
@@ -759,19 +760,18 @@ importLiteral(ImplicitLocOpBuilder builder,
759760
message.decimal().scale());
760761
IntegerAttr value = IntegerAttr::get(IntegerType::get(context, 128), var);
761762
auto attr = DecimalAttr::get(context, type, value);
762-
return builder.create<LiteralOp>(attr);
763+
return attr;
763764
}
764765
case Expression::Literal::LiteralTypeCase::kList: {
765766
const Expression::Literal::List &listType = message.list();
766767
llvm::SmallVector<Attribute> listElements;
767768
listElements.reserve(listType.values_size());
768769
for (const Expression_Literal &element : listType.values()) {
769-
// TODO: Create importAttribute function to avoid creating redundant
770-
// 'LiteralOp's, as seen in test cases.
771-
FailureOr<LiteralOp> elementOp = importLiteral(builder, element);
772-
if (failed(elementOp))
770+
mlir::FailureOr<mlir::Attribute> elementAttr =
771+
importAttribute(builder, element);
772+
if (failed(elementAttr))
773773
return failure();
774-
listElements.push_back(elementOp->getValue());
774+
listElements.push_back(*elementAttr);
775775
}
776776
// Get the list type from the first element.
777777
auto listElementType =
@@ -780,7 +780,7 @@ importLiteral(ImplicitLocOpBuilder builder,
780780
auto listAttr = ArrayAttr::get(context, listElements);
781781
// TODO: Adjust when nullability is implemented. (List could be empty.)
782782
auto attr = ListAttr::get(context, listAttr, type);
783-
return builder.create<LiteralOp>(attr);
783+
return attr;
784784
}
785785

786786
// TODO(ingomueller): Support more types.
@@ -792,6 +792,15 @@ importLiteral(ImplicitLocOpBuilder builder,
792792
}
793793
}
794794

795+
static mlir::FailureOr<LiteralOp>
796+
importLiteral(ImplicitLocOpBuilder builder,
797+
const Expression::Literal &message) {
798+
mlir::FailureOr<mlir::Attribute> attr = importAttribute(builder, message);
799+
if (failed(attr))
800+
return failure();
801+
return builder.create<LiteralOp>(*attr);
802+
}
803+
795804
static mlir::FailureOr<FetchOp> importFetchRel(ImplicitLocOpBuilder builder,
796805
const Rel &message) {
797806
const FetchRel &fetchRel = message.fetch();

test/Target/SubstraitPB/Import/literal.textpb

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@
1515
# CHECK-NEXT: %[[V0:.*]] = named_table
1616
# CHECK-NEXT: %[[V1:.*]] = project %[[V0]] : tuple<si1> -> tuple<si1, !substrait.list<!substrait.fixed_binary<4>>> {
1717
# CHECK-NEXT: ^[[BB0:.*]](%[[ARG0:.*]]: tuple<si1>):
18-
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.fixed_binary<"8181"> : !substrait.fixed_binary<4>
19-
# CHECK-NEXT: %[[V3:.*]] = literal #substrait.fixed_binary<"8181"> : !substrait.fixed_binary<4>
20-
# CHECK-NEXT: %[[V4:.*]] = literal #substrait.fixed_binary<"8181"> : !substrait.fixed_binary<4>
21-
# CHECK-NEXT: %[[V5:.*]] = literal #substrait.list<[#substrait.fixed_binary<"8181"> : !substrait.fixed_binary<4>, #substrait.fixed_binary<"8181"> : !substrait.fixed_binary<4>, #substrait.fixed_binary<"8181"> : !substrait.fixed_binary<4>], <!substrait.fixed_binary<4>>> : !substrait.list<!substrait.fixed_binary<4>>
22-
# CHECK-NEXT: yield %[[V5]] : !substrait.list<!substrait.fixed_binary<4>>
18+
# CHECK-NEXT: %[[V2:.*]] = literal #substrait.list<[#substrait.fixed_binary<"8181"> : !substrait.fixed_binary<4>, #substrait.fixed_binary<"8181"> : !substrait.fixed_binary<4>, #substrait.fixed_binary<"8181"> : !substrait.fixed_binary<4>], <!substrait.fixed_binary<4>>> : !substrait.list<!substrait.fixed_binary<4>>
19+
# CHECK-NEXT: yield %[[V2]] : !substrait.list<!substrait.fixed_binary<4>>
2320
# CHECK-NEXT: }
2421
# CHECK-NEXT: yield %[[V1]] : tuple<si1, !substrait.list<!substrait.fixed_binary<4>>>
2522

0 commit comments

Comments
 (0)