diff --git a/include/circt/Dialect/Synth/Transforms/SynthPasses.td b/include/circt/Dialect/Synth/Transforms/SynthPasses.td index c105d8108720..72289ca84193 100644 --- a/include/circt/Dialect/Synth/Transforms/SynthPasses.td +++ b/include/circt/Dialect/Synth/Transforms/SynthPasses.td @@ -288,7 +288,7 @@ def SOPBalancing : CutRewriterPassBase<"synth-sop-balancing", "hw::HWModuleOp"> let dependentDialects = ["synth::SynthDialect", "comb::CombDialect"]; } -def FunctionalReduction : Pass<"synth-functional-reduction", "hw::HWModuleOp"> { +def FunctionalReduction : Pass<"synth-functional-reduction"> { let summary = "Functional reduction by simulation-guided candidate merging"; let description = [{ This pass performs functional reduction for single-bit `synth` logic. It diff --git a/lib/Dialect/Synth/Transforms/FunctionalReduction.cpp b/lib/Dialect/Synth/Transforms/FunctionalReduction.cpp index 26e32c8e59d4..d7cf96b26b28 100644 --- a/lib/Dialect/Synth/Transforms/FunctionalReduction.cpp +++ b/lib/Dialect/Synth/Transforms/FunctionalReduction.cpp @@ -294,7 +294,7 @@ void FunctionalReductionSATBuilder::encodeValue(Value value) { class FunctionalReductionSolver { public: - FunctionalReductionSolver(hw::HWModuleOp module, unsigned numPatterns, + FunctionalReductionSolver(Operation *module, unsigned numPatterns, unsigned seed, bool testTransformation, std::unique_ptr satSolver) : module(module), numPatterns(numPatterns), seed(seed), @@ -334,8 +334,8 @@ class FunctionalReductionSolver { static bool matchesTestEquivClass(Value lhs, Value rhs); EquivResult verifyEquivalence(Value lhs, Value rhs, bool inverted); - // Module being processed - hw::HWModuleOp module; + // Operation being processed + Operation *module; // Configuration unsigned numPatterns; @@ -425,18 +425,23 @@ void FunctionalReductionSolver::initializeSATState() { //===----------------------------------------------------------------------===// void FunctionalReductionSolver::collectValues() { - // Collect block arguments (primary inputs) that are i1 - for (auto arg : module.getBodyBlock()->getArguments()) { - if (arg.getType().isInteger(1)) { - primaryInputs.push_back(arg); - allValues.push_back(arg); + // Collect block arguments of the direct regions of `module` as primary + // inputs. These are the "ports" of the logic network being analyzed. + for (auto ®ion : module->getRegions()) { + for (auto &block : region) { + for (auto arg : block.getArguments()) { + if (arg.getType().isInteger(1)) { + primaryInputs.push_back(arg); + allValues.push_back(arg); + } + } } } // Walk operations and collect i1 results - // - AIG operations: add to allValues for simulation + // - AIG/comb operations: add to allValues for simulation // - Unknown operations: treat as inputs (assign random patterns) - module.walk([&](Operation *op) { + module->walk([&](Operation *op) { for (auto result : op->getResults()) { if (!result.getType().isInteger(1)) continue; @@ -450,6 +455,23 @@ void FunctionalReductionSolver::collectValues() { } }); + // Collect any i1 operands of simulatable ops that are defined outside the + // scope of `module` (e.g. block arguments of a nested op when the solver is + // run on a container op). Treat them as additional primary inputs so that + // simulation patterns are assigned to them before propagation. + llvm::DenseSet allValuesSet(allValues.begin(), allValues.end()); + for (auto value : allValues) { + Operation *op = value.getDefiningOp(); + if (!op || !isFunctionalReductionSimulatableOp(op)) + continue; + for (auto operand : op->getOperands()) { + if (operand.getType().isInteger(1) && !allValuesSet.count(operand)) { + primaryInputs.push_back(operand); + allValuesSet.insert(operand); + } + } + } + LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Collected " << primaryInputs.size() << " primary inputs (including unknown ops) and " @@ -614,7 +636,7 @@ void FunctionalReductionSolver::mergeEquivalentNodes() { if (provenEquivalences.empty()) return; - mlir::OpBuilder builder(module.getContext()); + mlir::OpBuilder builder(module->getContext()); for (auto &provenEquivSet : provenEquivalences) { auto &[representative, members] = provenEquivSet; if (members.empty()) @@ -751,18 +773,17 @@ struct FunctionalReductionPass } void runOnOperation() override { - auto module = getOperation(); + Operation *op = getOperation(); LLVM_DEBUG(llvm::dbgs() << "Running FunctionalReduction pass on " - << module.getName() << "\n"); + << op->getName() << "\n"); if (numRandomPatterns == 0 || (numRandomPatterns & 63U) != 0) { - module.emitError() + op->emitError() << "'num-random-patterns' must be a positive multiple of 64"; return signalPassFailure(); } if (conflictLimit < -1) { - module.emitError() - << "'conflict-limit' must be greater than or equal to -1"; + op->emitError() << "'conflict-limit' must be greater than or equal to -1"; return signalPassFailure(); } @@ -770,17 +791,16 @@ struct FunctionalReductionPass if (!testTransformation) { satSolver = createFunctionalReductionSATSolver(this->satSolver); if (!satSolver) { - module.emitError() << "unsupported or unavailable SAT solver '" - << this->satSolver - << "' (expected auto, z3, or cadical)"; + op->emitError() << "unsupported or unavailable SAT solver '" + << this->satSolver + << "' (expected auto, z3, or cadical)"; return signalPassFailure(); } satSolver->setConflictLimit(static_cast(conflictLimit)); } - FunctionalReductionSolver fcSolver(module, numRandomPatterns, seed, - testTransformation, - std::move(satSolver)); + FunctionalReductionSolver fcSolver( + op, numRandomPatterns, seed, testTransformation, std::move(satSolver)); auto stats = fcSolver.run(); if (failed(stats)) return signalPassFailure(); diff --git a/test/Dialect/Synth/functional-reduction.mlir b/test/Dialect/Synth/functional-reduction.mlir index aed8085817d6..c84245582ff0 100644 --- a/test/Dialect/Synth/functional-reduction.mlir +++ b/test/Dialect/Synth/functional-reduction.mlir @@ -42,6 +42,19 @@ hw.module @test_supported_ops(in %a: i1, in %b: i1, in %c: i1, hw.output %0, %1, %2, %3, %4, %5 : i1, i1, i1, i1, i1, i1 } +// Test that the pass works on non-hw ops (func.func). +// CHECK-LABEL: func @test_func +func.func @test_func(%a: i1, %b: i1, %c: i1, %d: i1) -> (i1, i1) { + // CHECK: %[[R0:.+]] = synth.aig.and_inv %arg0, not %arg1, %arg2, not %arg3 + // CHECK-NEXT: %[[CHOICE:.+]] = synth.choice + // CHECK: return %[[CHOICE]], %[[CHOICE]] + %0 = synth.aig.and_inv %a, not %b : i1 + %1 = synth.aig.and_inv %c, not %d : i1 + %2 = synth.aig.and_inv %0, %1 {synth.test.fc_equiv_class = 20} : i1 + %3 = synth.aig.and_inv %a, not %b, %c, not %d {synth.test.fc_equiv_class = 20} : i1 + func.return %2, %3 : i1, i1 +} + // CHECK-LABEL: hw.module @test_inversion_equiv hw.module @test_inversion_equiv(in %a: i1, in %b: i1, out out0: i1, out out1: i1) { // CHECK: %[[AND:.+]] = synth.aig.and_inv not %a, not %b diff --git a/test/Dialect/Synth/maximum-and-cover.mlir b/test/Dialect/Synth/maximum-and-cover.mlir index c90e7357db61..aad42b21bb41 100644 --- a/test/Dialect/Synth/maximum-and-cover.mlir +++ b/test/Dialect/Synth/maximum-and-cover.mlir @@ -29,6 +29,15 @@ hw.module @InvertedNoCollapse(in %a: i1, in %b: i1, in %c: i1, out o1: i1) { hw.output %1 : i1 } +// CHECK-LABEL: @func +func.func @func(%a: i1, %b: i1, %c: i1, %d: i1) -> (i1) { + // CHECK-NEXT: %[[AND:.+]] = synth.aig.and_inv %arg0, %arg1, not %arg2, %arg3 : i1 + // CHECK-NEXT: return %[[AND]] : i1 + %0 = synth.aig.and_inv %a, %b : i1 + %1 = synth.aig.and_inv %0, not %c, %d : i1 + func.return %1 : i1 +} + // CHECK-LABEL: @ComplexTree hw.module @ComplexTree(in %a: i1, in %b: i1, in %c: i1, in %d: i1, in %e: i1, in %f: i1, in %g: i1, out o1: i1) { // CHECK-NEXT: %[[AND0:.+]] = synth.aig.and_inv %d, not %e : i1