Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions lib/Dialect/Comb/Transforms/AssumeTwoValued.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,40 @@ struct ICmpOpConversion : OpRewritePattern<ICmpOp> {
return failure();
}
rewriter.replaceOpWithNewOp<ICmpOp>(op, newPredicate, op.getLhs(),
op.getRhs());
op.getRhs(), /*twoState=*/true);
return success();
}
};

template <typename OpTy>
struct AddTwoStateFlag : OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;

LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
if (op.getTwoState())
return failure();
rewriter.modifyOpInPlace(op, [&] { op.setTwoState(true); });
return success();
}
};

using AddOpTwoStateFlag = AddTwoStateFlag<AddOp>;
using AndOpTwoStateFlag = AddTwoStateFlag<AndOp>;
using DivSOpTwoStateFlag = AddTwoStateFlag<DivSOp>;
using DivUOpTwoStateFlag = AddTwoStateFlag<DivUOp>;
using ModSOpTwoStateFlag = AddTwoStateFlag<ModSOp>;
using ModUOpTwoStateFlag = AddTwoStateFlag<ModUOp>;
using MulOpTwoStateFlag = AddTwoStateFlag<MulOp>;
using MuxOpTwoStateFlag = AddTwoStateFlag<MuxOp>;
using OrOpTwoStateFlag = AddTwoStateFlag<OrOp>;
using ParityOpTwoStateFlag = AddTwoStateFlag<ParityOp>;
using ShlOpTwoStateFlag = AddTwoStateFlag<ShlOp>;
using ShrSOpTwoStateFlag = AddTwoStateFlag<ShrSOp>;
using ShrUOpTwoStateFlag = AddTwoStateFlag<ShrUOp>;
using SubOpTwoStateFlag = AddTwoStateFlag<SubOp>;
using XorOpTwoStateFlag = AddTwoStateFlag<XorOp>;
Comment thread
jmolloy marked this conversation as resolved.
Outdated

} // namespace

namespace {
Expand All @@ -58,7 +88,12 @@ class AssumeTwoValued : public impl::AssumeTwoValuedBase<AssumeTwoValued> {
void AssumeTwoValued::runOnOperation() {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
patterns.add<ICmpOpConversion>(ctx);
patterns.add<ICmpOpConversion, AddOpTwoStateFlag, AndOpTwoStateFlag,
DivSOpTwoStateFlag, DivUOpTwoStateFlag, ModSOpTwoStateFlag,
ModUOpTwoStateFlag, MulOpTwoStateFlag, MuxOpTwoStateFlag,
OrOpTwoStateFlag, ParityOpTwoStateFlag, ShlOpTwoStateFlag,
ShrSOpTwoStateFlag, ShrUOpTwoStateFlag, SubOpTwoStateFlag,
XorOpTwoStateFlag>(ctx);

if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
Expand Down
43 changes: 39 additions & 4 deletions test/Dialect/Comb/assume-two-valued.mlir
Original file line number Diff line number Diff line change
@@ -1,33 +1,68 @@
// RUN: circt-opt %s --comb-assume-two-valued | FileCheck %s

// CHECK-LABEL: hw.module @ceq
// CHECK-NEXT: [[EQ:%.+]] = comb.icmp eq %a, %b : i1
// CHECK-NEXT: [[EQ:%.+]] = comb.icmp bin eq %a, %b : i1
// CHECK-NEXT: hw.output [[EQ]] : i1
hw.module @ceq(in %a: i1, in %b: i1, out x: i1) {
%0 = comb.icmp ceq %a, %b : i1
hw.output %0 : i1
}

// CHECK-LABEL: hw.module @weq
// CHECK-NEXT: [[EQ:%.+]] = comb.icmp eq %a, %b : i1
// CHECK-NEXT: [[EQ:%.+]] = comb.icmp bin eq %a, %b : i1
// CHECK-NEXT: hw.output [[EQ]] : i1
hw.module @weq(in %a: i1, in %b: i1, out x: i1) {
%0 = comb.icmp weq %a, %b : i1
hw.output %0 : i1
}

// CHECK-LABEL: hw.module @cne
// CHECK-NEXT: [[EQ:%.+]] = comb.icmp ne %a, %b : i1
// CHECK-NEXT: [[EQ:%.+]] = comb.icmp bin ne %a, %b : i1
// CHECK-NEXT: hw.output [[EQ]] : i1
hw.module @cne(in %a: i1, in %b: i1, out x: i1) {
%0 = comb.icmp cne %a, %b : i1
hw.output %0 : i1
}

// CHECK-LABEL: hw.module @wne
// CHECK-NEXT: [[EQ:%.+]] = comb.icmp ne %a, %b : i1
// CHECK-NEXT: [[EQ:%.+]] = comb.icmp bin ne %a, %b : i1
// CHECK-NEXT: hw.output [[EQ]] : i1
hw.module @wne(in %a: i1, in %b: i1, out x: i1) {
%0 = comb.icmp wne %a, %b : i1
hw.output %0 : i1
}

// CHECK-LABEL: hw.module @and
// CHECK-NEXT: %[[X:.*]] = comb.and bin %a, %b
hw.module @and(in %a: i1, in %b: i1, out x: i1) {
%0 = comb.and %a, %b : i1
hw.output %0 : i1
}

// CHECK-LABEL: hw.module @add
// CHECK-NEXT: %[[X:.*]] = comb.add bin %a, %b
hw.module @add(in %a: i1, in %b: i1, out x: i1) {
%0 = comb.add %a, %b : i1
hw.output %0 : i1
}

// CHECK-LABEL: hw.module @or
// CHECK-NEXT: %[[X:.*]] = comb.or bin %a, %b
hw.module @or(in %a: i1, in %b: i1, out x: i1) {
%0 = comb.or %a, %b : i1
hw.output %0 : i1
}

// CHECK-LABEL: hw.module @sub
// CHECK-NEXT: %[[X:.*]] = comb.sub bin %a, %b
hw.module @sub(in %a: i1, in %b: i1, out x: i1) {
%0 = comb.sub %a, %b : i1
hw.output %0 : i1
}

// CHECK-LABEL: hw.module @xor
// CHECK-NEXT: %[[X:.*]] = comb.xor bin %a, %b
hw.module @xor(in %a: i1, in %b: i1, out x: i1) {
%0 = comb.xor %a, %b : i1
hw.output %0 : i1
}
Loading