Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/circt/Dialect/Synth/Transforms/SynthPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def SOPBalancing : CutRewriterPassBase<"synth-sop-balancing", "hw::HWModuleOp">
let dependentDialects = ["synth::SynthDialect", "comb::CombDialect"];
}

def FunctionalReduction : Pass<"synth-functional-reduction", "hw::HWModuleOp"> {
def FunctionalReduction : Pass<"synth-functional-reduction"> {
let summary = "Functional reduction by simulation-guided candidate merging";
let description = [{
This pass performs functional reduction for single-bit `synth` logic. It
Expand Down
64 changes: 42 additions & 22 deletions lib/Dialect/Synth/Transforms/FunctionalReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ void FunctionalReductionSATBuilder::encodeValue(Value value) {

class FunctionalReductionSolver {
public:
FunctionalReductionSolver(hw::HWModuleOp module, unsigned numPatterns,
FunctionalReductionSolver(Operation *module, unsigned numPatterns,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change this to pass Block* block instead of Operaton* module, since it's very tedious to run this transformation across different blocks.

unsigned seed, bool testTransformation,
std::unique_ptr<IncrementalSATSolver> satSolver)
: module(module), numPatterns(numPatterns), seed(seed),
Expand Down Expand Up @@ -334,8 +334,8 @@ class FunctionalReductionSolver {
static bool matchesTestEquivClass(Value lhs, Value rhs);
EquivResult verifyEquivalence(Value lhs, Value rhs, bool inverted);

// Module being processed
hw::HWModuleOp module;
// Operation being processed
Operation *module;

// Configuration
unsigned numPatterns;
Expand Down Expand Up @@ -425,18 +425,23 @@ void FunctionalReductionSolver::initializeSATState() {
//===----------------------------------------------------------------------===//

void FunctionalReductionSolver::collectValues() {
// Collect block arguments (primary inputs) that are i1
for (auto arg : module.getBodyBlock()->getArguments()) {
if (arg.getType().isInteger(1)) {
primaryInputs.push_back(arg);
allValues.push_back(arg);
// Collect block arguments of the direct regions of `module` as primary
// inputs. These are the "ports" of the logic network being analyzed.
for (auto &region : module->getRegions()) {
for (auto &block : region) {
for (auto arg : block.getArguments()) {
if (arg.getType().isInteger(1)) {
primaryInputs.push_back(arg);
allValues.push_back(arg);
}
}
}
}

// Walk operations and collect i1 results
// - AIG operations: add to allValues for simulation
// - AIG/comb operations: add to allValues for simulation
// - Unknown operations: treat as inputs (assign random patterns)
module.walk([&](Operation *op) {
module->walk([&](Operation *op) {
for (auto result : op->getResults()) {
if (!result.getType().isInteger(1))
continue;
Expand All @@ -450,6 +455,23 @@ void FunctionalReductionSolver::collectValues() {
}
});

// Collect any i1 operands of simulatable ops that are defined outside the
// scope of `module` (e.g. block arguments of a nested op when the solver is
// run on a container op). Treat them as additional primary inputs so that
// simulation patterns are assigned to them before propagation.
llvm::DenseSet<Value> allValuesSet(allValues.begin(), allValues.end());
for (auto value : allValues) {
Operation *op = value.getDefiningOp();
if (!op || !isFunctionalReductionSimulatableOp(op))
continue;
for (auto operand : op->getOperands()) {
if (operand.getType().isInteger(1) && !allValuesSet.count(operand)) {
primaryInputs.push_back(operand);
allValuesSet.insert(operand);
}
}
}

LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Collected "
<< primaryInputs.size()
<< " primary inputs (including unknown ops) and "
Expand Down Expand Up @@ -614,7 +636,7 @@ void FunctionalReductionSolver::mergeEquivalentNodes() {
if (provenEquivalences.empty())
return;

mlir::OpBuilder builder(module.getContext());
mlir::OpBuilder builder(module->getContext());
for (auto &provenEquivSet : provenEquivalences) {
auto &[representative, members] = provenEquivSet;
if (members.empty())
Expand Down Expand Up @@ -751,36 +773,34 @@ struct FunctionalReductionPass
}

void runOnOperation() override {
auto module = getOperation();
Operation *op = getOperation();
LLVM_DEBUG(llvm::dbgs() << "Running FunctionalReduction pass on "
<< module.getName() << "\n");
<< op->getName() << "\n");

if (numRandomPatterns == 0 || (numRandomPatterns & 63U) != 0) {
module.emitError()
op->emitError()
<< "'num-random-patterns' must be a positive multiple of 64";
return signalPassFailure();
}
if (conflictLimit < -1) {
module.emitError()
<< "'conflict-limit' must be greater than or equal to -1";
op->emitError() << "'conflict-limit' must be greater than or equal to -1";
return signalPassFailure();
}

std::unique_ptr<IncrementalSATSolver> satSolver;
if (!testTransformation) {
satSolver = createFunctionalReductionSATSolver(this->satSolver);
if (!satSolver) {
module.emitError() << "unsupported or unavailable SAT solver '"
<< this->satSolver
<< "' (expected auto, z3, or cadical)";
op->emitError() << "unsupported or unavailable SAT solver '"
<< this->satSolver
<< "' (expected auto, z3, or cadical)";
return signalPassFailure();
}
satSolver->setConflictLimit(static_cast<int>(conflictLimit));
}

FunctionalReductionSolver fcSolver(module, numRandomPatterns, seed,
testTransformation,
std::move(satSolver));
FunctionalReductionSolver fcSolver(
op, numRandomPatterns, seed, testTransformation, std::move(satSolver));
auto stats = fcSolver.run();
if (failed(stats))
return signalPassFailure();
Expand Down
13 changes: 13 additions & 0 deletions test/Dialect/Synth/functional-reduction.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ hw.module @test_supported_ops(in %a: i1, in %b: i1, in %c: i1,
hw.output %0, %1, %2, %3, %4, %5 : i1, i1, i1, i1, i1, i1
}

// Test that the pass works on non-hw ops (func.func).
// CHECK-LABEL: func @test_func
func.func @test_func(%a: i1, %b: i1, %c: i1, %d: i1) -> (i1, i1) {
// CHECK: %[[R0:.+]] = synth.aig.and_inv %arg0, not %arg1, %arg2, not %arg3
// CHECK-NEXT: %[[CHOICE:.+]] = synth.choice
// CHECK: return %[[CHOICE]], %[[CHOICE]]
%0 = synth.aig.and_inv %a, not %b : i1
%1 = synth.aig.and_inv %c, not %d : i1
%2 = synth.aig.and_inv %0, %1 {synth.test.fc_equiv_class = 20} : i1
%3 = synth.aig.and_inv %a, not %b, %c, not %d {synth.test.fc_equiv_class = 20} : i1
func.return %2, %3 : i1, i1
}

// CHECK-LABEL: hw.module @test_inversion_equiv
hw.module @test_inversion_equiv(in %a: i1, in %b: i1, out out0: i1, out out1: i1) {
// CHECK: %[[AND:.+]] = synth.aig.and_inv not %a, not %b
Expand Down
9 changes: 9 additions & 0 deletions test/Dialect/Synth/maximum-and-cover.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ hw.module @InvertedNoCollapse(in %a: i1, in %b: i1, in %c: i1, out o1: i1) {
hw.output %1 : i1
}

// CHECK-LABEL: @func
func.func @func(%a: i1, %b: i1, %c: i1, %d: i1) -> (i1) {
// CHECK-NEXT: %[[AND:.+]] = synth.aig.and_inv %arg0, %arg1, not %arg2, %arg3 : i1
// CHECK-NEXT: return %[[AND]] : i1
%0 = synth.aig.and_inv %a, %b : i1
%1 = synth.aig.and_inv %0, not %c, %d : i1
func.return %1 : i1
}

// CHECK-LABEL: @ComplexTree
hw.module @ComplexTree(in %a: i1, in %b: i1, in %c: i1, in %d: i1, in %e: i1, in %f: i1, in %g: i1, out o1: i1) {
// CHECK-NEXT: %[[AND0:.+]] = synth.aig.and_inv %d, not %e : i1
Expand Down
Loading