Skip to content

Commit 839ba80

Browse files
committed
[OM] Add a folder
1 parent 3f17954 commit 839ba80

4 files changed

Lines changed: 96 additions & 24 deletions

File tree

include/circt/Dialect/OM/OMOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ class IntegerBinaryArithmeticOp<string mnemonic, list<Trait> traits = []> :
460460
let results = (outs OMIntegerType:$result);
461461

462462
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
463+
let hasFolder = 1;
463464
}
464465

465466
def IntegerAddOp : IntegerBinaryArithmeticOp<"integer.add", [Commutative]> {

lib/Dialect/OM/Evaluator/Evaluator.cpp

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -491,28 +491,15 @@ circt::om::Evaluator::evaluateIntegerBinaryArithmetic(
491491
assert(lhs && rhs &&
492492
"expected om::IntegerAttr for IntegerBinaryArithmeticOp operands");
493493

494-
// Extend values if necessary to match bitwidth. Most interesting arithmetic
495-
// on APSInt asserts that both operands are the same bitwidth, but the
496-
// IntegerAttrs we are working with may have used the smallest necessary
497-
// bitwidth to represent the number they hold, and won't necessarily match.
498-
APSInt lhsVal = lhs.getValue().getAPSInt();
499-
APSInt rhsVal = rhs.getValue().getAPSInt();
500-
if (lhsVal.getBitWidth() > rhsVal.getBitWidth())
501-
rhsVal = rhsVal.extend(lhsVal.getBitWidth());
502-
else if (rhsVal.getBitWidth() > lhsVal.getBitWidth())
503-
lhsVal = lhsVal.extend(rhsVal.getBitWidth());
504-
505-
// Perform arbitrary precision signed integer binary arithmetic.
506-
FailureOr<APSInt> result = op.evaluateIntegerOperation(lhsVal, rhsVal);
507-
508-
if (failed(result))
494+
std::array<Attribute, 2> operandAttrs = {lhs, rhs};
495+
SmallVector<mlir::OpFoldResult, 1> results;
496+
om::IntegerAttr resultAttr;
497+
auto foldResult = op->fold(operandAttrs, results);
498+
if (failed(foldResult) || results.size() != 1 ||
499+
!(resultAttr = llvm::dyn_cast_or_null<om::IntegerAttr>(
500+
results[0].dyn_cast<Attribute>())))
509501
return op->emitError("failed to evaluate integer operation");
510502

511-
// Package the result as a new om::IntegerAttr.
512-
MLIRContext *ctx = op->getContext();
513-
auto resultAttr =
514-
om::IntegerAttr::get(ctx, mlir::IntegerAttr::get(ctx, result.value()));
515-
516503
// Finalize the op result value.
517504
auto *handleValue = cast<evaluator::AttributeValue>(handle.value().get());
518505
auto resultStatus = handleValue->setAttr(resultAttr);

lib/Dialect/OM/OMOps.cpp

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "circt/Dialect/OM/OMOps.h"
1414
#include "circt/Dialect/HW/HWOps.h"
15+
#include "circt/Dialect/OM/OMOpInterfaces.h"
1516
#include "circt/Dialect/OM/OMUtils.h"
1617
#include "mlir/IR/Builders.h"
1718
#include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -639,6 +640,34 @@ PathCreateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
639640
return success();
640641
}
641642

643+
//===----------------------------------------------------------------------===//
644+
// IntegerBinaryArithmeticOp
645+
//===----------------------------------------------------------------------===//
646+
647+
static OpFoldResult foldIntegerBinaryArithmetic(IntegerBinaryArithmeticOp op,
648+
Attribute lhsAttr,
649+
Attribute rhsAttr) {
650+
auto lhs = dyn_cast_or_null<circt::om::IntegerAttr>(lhsAttr);
651+
auto rhs = dyn_cast_or_null<circt::om::IntegerAttr>(rhsAttr);
652+
if (!lhs || !rhs)
653+
return {};
654+
655+
APSInt lhsVal = lhs.getValue().getAPSInt();
656+
APSInt rhsVal = rhs.getValue().getAPSInt();
657+
if (lhsVal.getBitWidth() > rhsVal.getBitWidth())
658+
rhsVal = rhsVal.extend(lhsVal.getBitWidth());
659+
else if (rhsVal.getBitWidth() > lhsVal.getBitWidth())
660+
lhsVal = lhsVal.extend(rhsVal.getBitWidth());
661+
662+
auto result = op.evaluateIntegerOperation(lhsVal, rhsVal);
663+
if (failed(result))
664+
return {};
665+
666+
auto *ctx = op.getContext();
667+
return circt::om::IntegerAttr::get(
668+
ctx, mlir::IntegerAttr::get(ctx, result.value()));
669+
}
670+
642671
//===----------------------------------------------------------------------===//
643672
// IntegerAddOp
644673
//===----------------------------------------------------------------------===//
@@ -649,6 +678,10 @@ IntegerAddOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
649678
return success(lhs + rhs);
650679
}
651680

681+
OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
682+
return foldIntegerBinaryArithmetic(*this, adaptor.getLhs(), adaptor.getRhs());
683+
}
684+
652685
//===----------------------------------------------------------------------===//
653686
// IntegerMulOp
654687
//===----------------------------------------------------------------------===//
@@ -659,6 +692,10 @@ IntegerMulOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
659692
return success(lhs * rhs);
660693
}
661694

695+
OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
696+
return foldIntegerBinaryArithmetic(*this, adaptor.getLhs(), adaptor.getRhs());
697+
}
698+
662699
//===----------------------------------------------------------------------===//
663700
// IntegerShrOp
664701
//===----------------------------------------------------------------------===//
@@ -675,6 +712,10 @@ IntegerShrOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
675712
return success(lhs >> rhs.getExtValue());
676713
}
677714

715+
OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
716+
return foldIntegerBinaryArithmetic(*this, adaptor.getLhs(), adaptor.getRhs());
717+
}
718+
678719
//===----------------------------------------------------------------------===//
679720
// IntegerShlOp
680721
//===----------------------------------------------------------------------===//
@@ -691,14 +732,22 @@ IntegerShlOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
691732
return success(lhs << rhs.getExtValue());
692733
}
693734

735+
OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
736+
return foldIntegerBinaryArithmetic(*this, adaptor.getLhs(), adaptor.getRhs());
737+
}
738+
694739
//===----------------------------------------------------------------------===//
695740
// StringConcatOp
696741
//===----------------------------------------------------------------------===//
697742

698743
OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
699744
// Fold single-operand concat to just the operand.
700-
if (getStrings().size() == 1)
745+
if (getStrings().size() == 1) {
746+
if (auto strAttr = adaptor.getStrings()[0])
747+
return strAttr;
748+
701749
return getStrings()[0];
750+
}
702751

703752
// Check if all operands are constant strings before accumulating.
704753
if (!llvm::all_of(adaptor.getStrings(), [](Attribute operand) {

test/Dialect/OM/canonicalizers.mlir

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func.func @ObjectsMustDCE() {
2323
return
2424
}
2525

26-
om.class @StringConcatCanonicalization(%str1: !om.string, %str2: !om.string) -> (out1: !om.string, out2: !om.string, out3: !om.string, out4: !om.string, out5: !om.string, out6: !om.string, out7: !om.string) {
26+
om.class @StringConcatCanonicalization(%str1: !om.string, %str2: !om.string) -> (out1: !om.string, out2: !om.string, out3: !om.string, out4: !om.string, out5: !om.string, out6: !om.string, out7: !om.string, out8: !om.string) {
2727
%s1 = om.constant "Hello" : !om.string
2828
%s2 = om.constant "World" : !om.string
2929
%s3 = om.constant "!" : !om.string
@@ -43,6 +43,9 @@ om.class @StringConcatCanonicalization(%str1: !om.string, %str2: !om.string) ->
4343
// Single operand replaced with operand
4444
%2 = om.string.concat %s1 : !om.string
4545

46+
// Single constant operand folds to the attribute.
47+
%singleConst = om.string.concat %s3 : !om.string
48+
4649
// Empty concat
4750
%3 = om.string.concat %empty, %empty : !om.string
4851

@@ -57,8 +60,40 @@ om.class @StringConcatCanonicalization(%str1: !om.string, %str2: !om.string) ->
5760
%nested = om.string.concat %str1, %str2 : !om.string
5861
%concat1 = om.string.concat %nested, %s3 : !om.string
5962

60-
// CHECK: om.class.fields [[HELLOWORLD]], [[HELLO]], [[HELLO]], [[EMPTY]], [[HELLOWORLD]], [[CONCAT1]], [[NESTED]]
61-
om.class.fields %0, %1, %2, %3, %5, %concat1, %nested : !om.string, !om.string, !om.string, !om.string, !om.string, !om.string, !om.string
63+
// CHECK: om.class.fields [[HELLOWORLD]], [[HELLO]], [[HELLO]], [[CONST]], [[EMPTY]], [[HELLOWORLD]], [[CONCAT1]], [[NESTED]]
64+
om.class.fields %0, %1, %2, %singleConst, %3, %5, %concat1, %nested : !om.string, !om.string, !om.string, !om.string, !om.string, !om.string, !om.string, !om.string
65+
}
66+
67+
// CHECK-LABEL: @IntegerBinaryArithmeticFold
68+
om.class @IntegerBinaryArithmeticFold(%x: !om.integer) -> (out1: !om.integer, out2: !om.integer,
69+
out3: !om.integer, out4: !om.integer,
70+
out5: !om.integer, out6: !om.integer) {
71+
%i3 = om.constant #om.integer<3 : si4> : !om.integer
72+
%i4 = om.constant #om.integer<4 : si4> : !om.integer
73+
%i2 = om.constant #om.integer<2 : si4> : !om.integer
74+
%neg1 = om.constant #om.integer<-1 : si4> : !om.integer
75+
%i1 = om.constant #om.integer<1 : si4> : !om.integer
76+
%wide = om.constant #om.integer<7 : si6> : !om.integer
77+
78+
// CHECK-DAG: [[ADD:%.+]] = om.constant #om.integer<7 : si4> : !om.integer
79+
// CHECK-DAG: [[MUL:%.+]] = om.constant #om.integer<-4 : si4> : !om.integer
80+
// CHECK-DAG: [[SHR:%.+]] = om.constant #om.integer<1 : si4> : !om.integer
81+
// CHECK-DAG: [[SHL:%.+]] = om.constant #om.integer<-2 : si4> : !om.integer
82+
// CHECK-DAG: [[WIDEADD:%.+]] = om.constant #om.integer<9 : si6> : !om.integer
83+
// CHECK: [[DYN:%.+]] = om.integer.add %x, %{{.+}} : !om.integer
84+
%0 = om.integer.add %i3, %i4 : !om.integer
85+
%1 = om.integer.mul %i3, %i4 : !om.integer
86+
%2 = om.integer.shr %i4, %i2 : !om.integer
87+
%3 = om.integer.shl %neg1, %i1 : !om.integer
88+
89+
// Mixed bit widths should still fold after extending operands.
90+
%4 = om.integer.add %i2, %wide : !om.integer
91+
92+
// Non-constant operands should remain.
93+
%5 = om.integer.add %x, %i1 : !om.integer
94+
95+
// CHECK: om.class.fields [[ADD]], [[MUL]], [[SHR]], [[SHL]], [[WIDEADD]], [[DYN]]
96+
om.class.fields %0, %1, %2, %3, %4, %5 : !om.integer, !om.integer, !om.integer, !om.integer, !om.integer, !om.integer
6297
}
6398

6499
// CHECK-LABEL: @PropEqFold

0 commit comments

Comments
 (0)