diff --git a/lib/Dialect/Comb/Transforms/AssumeTwoValued.cpp b/lib/Dialect/Comb/Transforms/AssumeTwoValued.cpp index c94a55faaa34..ed6a29cbd3d8 100644 --- a/lib/Dialect/Comb/Transforms/AssumeTwoValued.cpp +++ b/lib/Dialect/Comb/Transforms/AssumeTwoValued.cpp @@ -40,10 +40,24 @@ struct ICmpOpConversion : OpRewritePattern { return failure(); } rewriter.replaceOpWithNewOp(op, newPredicate, op.getLhs(), - op.getRhs()); + op.getRhs(), /*twoState=*/true); return success(); } }; + +template +struct AddTwoStateFlag : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + if (op.getTwoState()) + return failure(); + rewriter.modifyOpInPlace(op, [&] { op.setTwoState(true); }); + return success(); + } +}; + } // namespace namespace { @@ -55,10 +69,18 @@ class AssumeTwoValued : public impl::AssumeTwoValuedBase { }; } // namespace +// Alias for brevity. +template +using TS = AddTwoStateFlag; + void AssumeTwoValued::runOnOperation() { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - patterns.add(ctx); + + patterns + .add, TS, TS, TS, + TS, TS, TS, TS, TS, TS, + TS, TS, TS, TS, TS>(ctx); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); diff --git a/test/Dialect/Comb/assume-two-valued.mlir b/test/Dialect/Comb/assume-two-valued.mlir index fbe5499d57dd..5793040be8dc 100644 --- a/test/Dialect/Comb/assume-two-valued.mlir +++ b/test/Dialect/Comb/assume-two-valued.mlir @@ -1,7 +1,7 @@ // RUN: circt-opt %s --comb-assume-two-valued | FileCheck %s // CHECK-LABEL: hw.module @ceq -// CHECK-NEXT: [[EQ:%.+]] = comb.icmp eq %a, %b : i1 +// CHECK-NEXT: [[EQ:%.+]] = comb.icmp bin eq %a, %b : i1 // CHECK-NEXT: hw.output [[EQ]] : i1 hw.module @ceq(in %a: i1, in %b: i1, out x: i1) { %0 = comb.icmp ceq %a, %b : i1 @@ -9,7 +9,7 @@ hw.module @ceq(in %a: i1, in %b: i1, out x: i1) { } // CHECK-LABEL: hw.module @weq -// CHECK-NEXT: [[EQ:%.+]] = comb.icmp eq %a, %b : i1 +// CHECK-NEXT: [[EQ:%.+]] = comb.icmp bin eq %a, %b : i1 // CHECK-NEXT: hw.output [[EQ]] : i1 hw.module @weq(in %a: i1, in %b: i1, out x: i1) { %0 = comb.icmp weq %a, %b : i1 @@ -17,7 +17,7 @@ hw.module @weq(in %a: i1, in %b: i1, out x: i1) { } // CHECK-LABEL: hw.module @cne -// CHECK-NEXT: [[EQ:%.+]] = comb.icmp ne %a, %b : i1 +// CHECK-NEXT: [[EQ:%.+]] = comb.icmp bin ne %a, %b : i1 // CHECK-NEXT: hw.output [[EQ]] : i1 hw.module @cne(in %a: i1, in %b: i1, out x: i1) { %0 = comb.icmp cne %a, %b : i1 @@ -25,9 +25,44 @@ hw.module @cne(in %a: i1, in %b: i1, out x: i1) { } // CHECK-LABEL: hw.module @wne -// CHECK-NEXT: [[EQ:%.+]] = comb.icmp ne %a, %b : i1 +// CHECK-NEXT: [[EQ:%.+]] = comb.icmp bin ne %a, %b : i1 // CHECK-NEXT: hw.output [[EQ]] : i1 hw.module @wne(in %a: i1, in %b: i1, out x: i1) { %0 = comb.icmp wne %a, %b : i1 hw.output %0 : i1 } + +// CHECK-LABEL: hw.module @and +// CHECK-NEXT: %[[X:.*]] = comb.and bin %a, %b +hw.module @and(in %a: i1, in %b: i1, out x: i1) { + %0 = comb.and %a, %b : i1 + hw.output %0 : i1 +} + +// CHECK-LABEL: hw.module @add +// CHECK-NEXT: %[[X:.*]] = comb.add bin %a, %b +hw.module @add(in %a: i1, in %b: i1, out x: i1) { + %0 = comb.add %a, %b : i1 + hw.output %0 : i1 +} + +// CHECK-LABEL: hw.module @or +// CHECK-NEXT: %[[X:.*]] = comb.or bin %a, %b +hw.module @or(in %a: i1, in %b: i1, out x: i1) { + %0 = comb.or %a, %b : i1 + hw.output %0 : i1 +} + +// CHECK-LABEL: hw.module @sub +// CHECK-NEXT: %[[X:.*]] = comb.sub bin %a, %b +hw.module @sub(in %a: i1, in %b: i1, out x: i1) { + %0 = comb.sub %a, %b : i1 + hw.output %0 : i1 +} + +// CHECK-LABEL: hw.module @xor +// CHECK-NEXT: %[[X:.*]] = comb.xor bin %a, %b +hw.module @xor(in %a: i1, in %b: i1, out x: i1) { + %0 = comb.xor %a, %b : i1 + hw.output %0 : i1 +}