diff --git a/include/circt/Support/SparseOpSCC.h b/include/circt/Support/SparseOpSCC.h new file mode 100644 index 000000000000..c16d0b2e70df --- /dev/null +++ b/include/circt/Support/SparseOpSCC.h @@ -0,0 +1,524 @@ +//===- SparseOpSCCs.h - SCC analysis on sparse op subgraphs ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Collect strongly connected components (SCCs) in the (filtered) def-use graph +// of MLIR operations, starting from a sparse set of seed operations. +// +// Graph model +// ----------- +// Each operation is a node. A directed edge runs from op A to op B if B uses +// one of A's results. The traversal direction is configurable: +// - OpSCCDirection::Forward -- follow edges from defining ops to uses. +// - OpSCCDirection::Backward -- follow edges from uses to defining op. +// +// SCC classification +// ------------------ +// An SCC is either: +// - Trivial: a single op with no self-loop. Represented as +// mlir::Operation * inside an OpSCC value. +// - Cyclic: a group of mutually-reachable ops, or a single op with a +// self-loop. Represented as a CyclicOpSCC inside an OpSCC value. +// +// Filtering +// --------- +// An optional OpSCCFilter predicate can be supplied to the constructor to +// prevent the traversal over certain edges of the graph. The first argument +// contains the operation into which the traversal would lead. The second +// argument contains the edge's destination operand. For forward traversal +// the operand's owner is identical to the first argument. For reverse +// traversal the first argument is identical to the operand's defining +// operation. +// +// Output ordering +// --------------- +// SCCs are available in a topological order of the condensation DAG via +// the iterator returned by topological(), and in the reverse via +// reverseTopological(). The returned order is deterministic (identical graphs +// will result in an identical order). However, if more than one topological +// order exists, there is no guarantee on the specific order. Equally, the +// order of operations within an SCC is deterministic but unspecified. +// +// Blocks and Regions +// ------------------ +// The traversal does not follow through block arguments. It does not consider +// control flow. It will descend into / ascend from regions without considering +// the parent operation. The filter predicate can be used to restrict the +// traversal to certain blocks or regions. +// +// Operation Graph Mutation +// ------------------------ +// The SparseOpSCC class internally stores the result of the SCC analysis +// and is only updated when visit(...) is called. The IR should not be mutated +// between visit calls. Calling visit invalidates all iterators. +// It is safe to mutate the IR while iterating. However, the iteration sequence +// may contain invalid operation pointers, if the underlying operation is erased +// after visiting the graph. To reflect changes to the graph in the analysis, +// reset() must be called and the graph must be re-visited. +// +// Usage examples +// ------------- +// +// Check if seedOp can be reached from someOp: +// +// SparseOpSCC sccs(regFilter); +// sccs.visit(seedOp); +// if (OpSCC someScc = sccs.getSCC(someOp)) { +// if (someScc == sccs.getSCC(seedOp)) { +// if (auto cycScc = llvm::dyn_cast(someScc)) { +// // seedOp and someOp are on at least one common cycle +// if (cycScc.size() == 1) { +// // seedOp and someOp are equal and there is a self-loop (i.e., at +// // least one operand of seedOp is a result of itself) +// } +// } else { +// // seedOp and someOp are equal and there is no self-loop +// // (trivial SCC) +// } +// } else { +// // seedOp is reachable from someOp but not the other way around +// // (someOp discovered during backwards traversal, but different SCCs) +// } +// } else { +// // seedOp is not reachable from someOp (someOp not discovered during +// // backwards traversal) +// } +// +// +// Collect all ops reachable from seedOp, excluding register ops, and process +// them in topological order: +// +// auto regFilter = [](Operation *op, OpOperand&) { +// return !isa(op); +// }; +// SparseOpSCC sccs(regFilter); +// sccs.visit(seedOp); +// +// for (OpSCC entry : sccs.topological()) { +// if (Operation *op = llvm::dyn_cast(entry)) { +// // Trivial SCC: a single op with no cycle. +// processSingle(op); +// } else { +// // Cyclic SCC: a group of mutually-reachable ops (or a self-loop). +// for (Operation *op : llvm::cast(entry)) +// processInCycle(op); +// } +// } +// +// Alternative filter that traverses registers through their clock and reset +// values but not the "next" data values: +// +// auto regEdgeFilter = [](Operation*, OpOperand& operand) { +// if (auto regOp = dyn_cast(operand.getOwner())) +// return operand != regOp.getNextMutable(); +// return true; +// }; +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_SPARSEOPSCC_H +#define CIRCT_SUPPORT_SPARSEOPSCC_H + +#include "mlir/IR/Operation.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/PointerEmbeddedInt.h" +#include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include + +namespace circt { + +/// Filter predicate passed to the SparseOpSCC constructor. Return `true` to +/// include an edge in the traversal, `false` to skip it. The first argument is +/// the operation the traversal would enter. The second argument is the +/// `OpOperand` being followed: for forward traversal its owner equals the +/// first argument; for backward traversal its defining op equals the first +/// argument. +using OpSCCFilter = std::function; +namespace detail { +/// Backing storage for a cyclic SCC (implementation detail). +using CyclicOpSCCStorage = llvm::SmallVector; +} // namespace detail + +/// A cyclic SCC: a pointer-sized, directly-iterable reference to a group of +/// mutually-reachable operations (or a single op with a self-loop). +/// +/// Instances are obtained via llvm::cast on an OpSCC entry. +/// The referenced storage is owned by the SparseOpSCC that produced the entry. +class CyclicOpSCC { +public: + using iterator = detail::CyclicOpSCCStorage::const_iterator; + + CyclicOpSCC() : storage(nullptr) {} + CyclicOpSCC(const detail::CyclicOpSCCStorage *storage) : storage(storage) {} + + iterator begin() const { return storage->begin(); } + iterator end() const { return storage->end(); } + size_t size() const { return storage->size(); } + mlir::Operation *const *data() const { return storage->data(); } + mlir::Operation *operator[](size_t i) const { return (*storage)[i]; } + + operator bool() const { return storage != nullptr; } + + bool operator==(CyclicOpSCC other) const { return storage == other.storage; } + bool operator!=(CyclicOpSCC other) const { return storage != other.storage; } + + // Interface for PointerLikeTypeTraits. + void *getAsVoidPointer() const { + return const_cast(storage); + } + static CyclicOpSCC getFromVoidPointer(void *p) { + return CyclicOpSCC(static_cast(p)); + } + static constexpr int NumLowBitsAvailable = llvm::PointerLikeTypeTraits< + const detail::CyclicOpSCCStorage *>::NumLowBitsAvailable; + +private: + const detail::CyclicOpSCCStorage *storage; +}; + +} // namespace circt + +namespace llvm { +template <> +struct PointerLikeTypeTraits { + static void *getAsVoidPointer(circt::CyclicOpSCC scc) { + return scc.getAsVoidPointer(); + } + static circt::CyclicOpSCC getFromVoidPointer(void *p) { + return circt::CyclicOpSCC::getFromVoidPointer(p); + } + static constexpr int NumLowBitsAvailable = + circt::CyclicOpSCC::NumLowBitsAvailable; +}; +} // namespace llvm + +namespace circt { + +/// One entry in the SCC output: a null sentinel, a trivial (non-cyclic) +/// operation, or a cyclic group. Use llvm::isa / llvm::cast / llvm::dyn_cast +/// to distinguish. +/// Note: void * must be placed first in the union so that the all-zero +/// (default-constructed) state identifies unambiguously as invalid, not as a +/// null Operation*. +using OpSCC = llvm::PointerUnion; + +/// Traversal direction for SparseOpSCC. +/// - Forward: follow def-use edges forward (defining op -> users). +/// - Backward: follow def-use edges backward (user -> defining op). +enum class OpSCCDirection { Forward, Backward }; + +template +class SparseOpSCC; + +namespace detail { +using OpSccEmbeddedIndex = llvm::PointerEmbeddedInt; +using OpOrIndex = llvm::PointerUnion; + +// Iterator template resolving indices to CyclicOpSCC +template +class OpSCCIterator final + : public llvm::mapped_iterator_base, + BaseIteratorT, OpSCC> { +public: + using llvm::mapped_iterator_base, BaseIteratorT, + OpSCC>::mapped_iterator_base; + + OpSCC mapElement(OpOrIndex opOrIndex) const { + if (llvm::isa(opOrIndex)) + return llvm::cast(opOrIndex); + unsigned index = llvm::cast(opOrIndex); + return CyclicOpSCC(&cyclicSccs[index]); + } + +private: + template + friend class circt::SparseOpSCC; + + OpSCCIterator(BaseIteratorT it, + const llvm::ArrayRef cyclicSccs) + : llvm::mapped_iterator_base, BaseIteratorT, + OpSCC>(it), + cyclicSccs(cyclicSccs) {} + + const llvm::ArrayRef cyclicSccs; +}; + +} // namespace detail + +/// Iterative Tarjan SCC analysis on a sparse subgraph of MLIR operations. +/// +/// Call visit() with one or more seed operations to trigger the DFS. Results +/// accumulate across multiple visit() calls, so the discovered subgraph can be +/// expanded incrementally. +/// +/// The optional filter passed to the constructor is applied to every +/// discovered edge before it is traversed. An edge that fails the filter is +/// treated as if it did not exist in the graph. +/// +/// Iterators obtained from topological() and reverseTopological() hold a +/// reference into this object and are invalidated by calling visit() or +// reset(). +template +class SparseOpSCC { +public: + explicit SparseOpSCC(OpSCCFilter shouldTraverseFn = {}) + : shouldTraverseFn(shouldTraverseFn) {} + + /// Clear all accumulated state. + void reset() { + opToSccIndex.clear(); + sccs.clear(); + cyclicSccs.clear(); + } + + /// Seed `op` into the DFS if it has not already been discovered. + void visit(mlir::Operation *op) { + if (!opToSccIndex.contains(op)) + tarjanImpl(op); + } + + /// Visit each operation in `ops`, skipping already-discovered ones. + void visit(llvm::ArrayRef ops) { + for (auto *op : ops) + visit(op); + } + + /// Return true if `op` was discovered (as a seed or transitively) by any + /// previous visit() call. + bool hasDiscovered(mlir::Operation *op) const { + return opToSccIndex.contains(op); + } + + /// Return the SCC that `op` belongs to. If the operation has not been + /// discovered, it returns a `nullptr` sentinel. + OpSCC getSCC(mlir::Operation *op) const { + auto it = opToSccIndex.find(op); + if (it == opToSccIndex.end()) + return OpSCC(nullptr); + detail::OpOrIndex entry = sccs[it->second]; + if (llvm::isa(entry)) + return OpSCC(llvm::cast(entry)); + unsigned cyclicIdx = llvm::cast(entry); + return OpSCC(CyclicOpSCC(&cyclicSccs[cyclicIdx])); + } + + /// Number of operations discovered so far across all visit() calls. + unsigned getNumDiscovered() const { return opToSccIndex.size(); } + /// Total number of SCC entries emitted (trivial ops + cyclic groups). + unsigned getNumSCCs() const { return sccs.size(); } + /// Number of cyclic SCC groups (excludes trivial ops). + unsigned getNumCyclicSCCs() const { return cyclicSccs.size(); } + + /// Iterate over SCCs in topological order (sources/seeds first, leaves last). + auto topological() const { + return llvm::iterator_range(topological_begin(), topological_end()); + } + + // NOLINTNEXTLINE(readability-identifier-naming) + auto topological_begin() const { + if constexpr (Direction == OpSCCDirection::Backward) + return detail::OpSCCIterator( + sccs.begin(), cyclicSccs); + else + return detail::OpSCCIterator< + typename decltype(sccs)::const_reverse_iterator>(sccs.rbegin(), + cyclicSccs); + } + + // NOLINTNEXTLINE(readability-identifier-naming) + auto topological_end() const { + if constexpr (Direction == OpSCCDirection::Backward) + return detail::OpSCCIterator( + sccs.end(), cyclicSccs); + else + return detail::OpSCCIterator< + typename decltype(sccs)::const_reverse_iterator>(sccs.rend(), + cyclicSccs); + } + + /// Iterate over SCCs in reverse topological order (leaves first). + auto reverseTopological() const { + return llvm::iterator_range(reverseTopological_begin(), + reverseTopological_end()); + } + + // NOLINTNEXTLINE(readability-identifier-naming) + auto reverseTopological_begin() const { + if constexpr (Direction == OpSCCDirection::Forward) + return detail::OpSCCIterator( + sccs.begin(), cyclicSccs); + else + return detail::OpSCCIterator< + typename decltype(sccs)::const_reverse_iterator>(sccs.rbegin(), + cyclicSccs); + } + + // NOLINTNEXTLINE(readability-identifier-naming) + auto reverseTopological_end() const { + if constexpr (Direction == OpSCCDirection::Forward) + return detail::OpSCCIterator( + sccs.end(), cyclicSccs); + else + return detail::OpSCCIterator< + typename decltype(sccs)::const_reverse_iterator>(sccs.rend(), + cyclicSccs); + } + +private: + // DFS stack frame for forward traversal. Skips over unused results. + struct ForwardFrame { + mlir::Operation *op; + std::optional useIt; + unsigned resultIdx; + bool hasSelfLoop = false; + + explicit ForwardFrame(mlir::Operation *op) + : op(op), useIt(std::nullopt), resultIdx(0) { + if (op->getNumResults() > 0) + useIt = op->getResult(0).use_begin(); + } + + mlir::Operation *nextChild(OpSCCFilter shouldTraverseFn) { + while (resultIdx < op->getNumResults()) { + auto useEnd = op->getResult(resultIdx).use_end(); + while (*useIt != useEnd) { + mlir::OpOperand &use = **useIt; + ++(*useIt); + if (!shouldTraverseFn || shouldTraverseFn(use.getOwner(), use)) + return use.getOwner(); + } + ++resultIdx; + if (resultIdx < op->getNumResults()) + useIt = op->getResult(resultIdx).use_begin(); + } + return nullptr; + } + }; + + // DFS stack frame for backward traversal. Skips over block arguments. + struct BackwardFrame { + mlir::Operation *op; + unsigned operandIdx; + bool hasSelfLoop = false; + + explicit BackwardFrame(mlir::Operation *op) : op(op), operandIdx(0) {} + + mlir::Operation *nextChild(OpSCCFilter shouldTraverseFn) { + while (operandIdx < op->getNumOperands()) { + mlir::OpOperand &operand = op->getOpOperand(operandIdx++); + auto *defOp = operand.get().getDefiningOp(); + if (defOp && (!shouldTraverseFn || shouldTraverseFn(defOp, operand))) + return defOp; + } + return nullptr; + } + }; + + using FrameT = std::conditional_t; + + void tarjanImpl(mlir::Operation *startOp) { + unsigned nextIdx = 0; + llvm::SmallDenseMap> + idxAndLowLinkMap; + llvm::SetVector sccStack; + llvm::SmallVector dfsStack; + + auto pushFrame = [&](mlir::Operation *op) { + idxAndLowLinkMap[op] = {nextIdx, nextIdx}; + ++nextIdx; + sccStack.insert(op); + dfsStack.push_back(FrameT(op)); + }; + + pushFrame(startOp); + + while (!dfsStack.empty()) { + FrameT &frame = dfsStack.back(); + mlir::Operation *op = frame.op; + + if (auto *child = frame.nextChild(shouldTraverseFn)) { + if (child == op) { + // Self-loop — record it in the frame; no lowlink update needed. + frame.hasSelfLoop = true; + } else { + auto it = idxAndLowLinkMap.find(child); + if (it != idxAndLowLinkMap.end()) { + // Already seen in this DFS. + if (sccStack.contains(child)) + // Back edge — update lowlink. + idxAndLowLinkMap[op].second = + std::min(idxAndLowLinkMap[op].second, it->second.first); + // else: forward/cross edge within this DFS — ignore. + } else if (!opToSccIndex.contains(child)) { + // Not yet seen in any DFS — recurse. + pushFrame(child); + } + // else: completed in a previous visit() call — cross edge, ignore. + } + continue; + } + + // All children processed — backtrack. + bool selfLoop = frame.hasSelfLoop; + auto [opIndex, opLowLink] = idxAndLowLinkMap.at(op); + dfsStack.pop_back(); + + // If op is the root of its SCC, pop and emit it. + if (opLowLink == opIndex) { + detail::CyclicOpSCCStorage sccOps; + do { + sccOps.push_back(sccStack.pop_back_val()); + } while (sccOps.back() != op); + + // Store the SCC index of the discovered ops + unsigned sccIdx = sccs.size(); + for (auto *sccOp : sccOps) { + bool inserted = opToSccIndex.insert({sccOp, sccIdx}).second; + (void)inserted; + assert(inserted && "Unexpectedly revisited node"); + } + + // Insert the pointers into the persistent storage + if (sccOps.size() == 1 && !selfLoop) { + sccs.push_back(detail::OpOrIndex(sccOps.front())); + } else { + unsigned cyclicIdx = cyclicSccs.size(); + cyclicSccs.emplace_back(std::move(sccOps)); + sccs.push_back(detail::OpOrIndex(cyclicIdx)); + } + continue; + } + + // Not an SCC root — back-propagate lowlink to the parent frame. + auto &parentLowLink = idxAndLowLinkMap.at(dfsStack.back().op).second; + parentLowLink = std::min(parentLowLink, opLowLink); + } + assert(sccStack.empty()); + } + + /// Optional edge filter supplied at construction time. + OpSCCFilter shouldTraverseFn; + + /// Maps each visited op to the index of its SCC in `sccs`. Persists across + /// visit() calls and is the authoritative "already visited" guard. + llvm::SmallDenseMap opToSccIndex; + /// Flat list of SCC entries emitted by Tarjan, in emission order. + /// Trivial SCCs are stored directly. Cyclic SCCs are stored as index into + /// the `cyclicSccs` vector. + llvm::SmallVector sccs; + /// Backing storage for cyclic SCCs; CyclicOpSCC holds a pointer into here. + llvm::SmallVector cyclicSccs; +}; + +} // namespace circt + +#endif // CIRCT_SUPPORT_SPARSEOPSCC_H diff --git a/unittests/Support/CMakeLists.txt b/unittests/Support/CMakeLists.txt index 489869937ac7..cf30b563fc3b 100644 --- a/unittests/Support/CMakeLists.txt +++ b/unittests/Support/CMakeLists.txt @@ -3,10 +3,16 @@ add_circt_unittest(CIRCTSupportTests JSONTest.cpp PrettyPrinterTest.cpp SATSolverTest.cpp + SparseOpSCCTest.cpp TruthTableTest.cpp ) target_link_libraries(CIRCTSupportTests PRIVATE CIRCTSupport + CIRCTHW + CIRCTComb + CIRCTSeq + MLIRIR + MLIRParser ) diff --git a/unittests/Support/SparseOpSCCTest.cpp b/unittests/Support/SparseOpSCCTest.cpp new file mode 100644 index 000000000000..45ec41f13216 --- /dev/null +++ b/unittests/Support/SparseOpSCCTest.cpp @@ -0,0 +1,964 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/SparseOpSCC.h" +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/Seq/SeqOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Parser/Parser.h" +#include "llvm/ADT/SmallVector.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace circt; + +namespace { + +// Graph region with a simple chain: and -> or -> output +// (each op's result is consumed by the next) +const char *ir = R"MLIR( + hw.module private @test(in %a: i1, in %b: i1, out x: i1) { + %and = comb.and %a, %b : i1 + %or = comb.or %and, %a : i1 + hw.output %or : i1 + } +)MLIR"; + +// Graph with a register-and cycle: reg uses and's result as next/resetValue, +// and uses reg's result as an operand. +// Forward edges: reg->and, and->{reg(next), reg(resetValue), output} +const char *cycleIr = R"MLIR( + hw.module private @cycle(in %clock: !seq.clock, in %reset: i1, in %a: i1, out x: i1) { + %reg = seq.firreg %and clock %clock reset sync %reset, %and : i1 + %and = comb.and %reg, %a : i1 + hw.output %and : i1 + } +)MLIR"; + +TEST(SparseOpSCCsTest, SimpleChain) { + MLIRContext context; + context.loadDialect(); + context.loadDialect(); + + OwningOpRef module = parseSourceString(ir, &context); + ASSERT_TRUE(module); + + SymbolTable symbolTable(module.get()); + auto hwModule = symbolTable.lookup("test"); + ASSERT_TRUE(hwModule); + + auto it = hwModule.getBodyBlock()->begin(); + Operation *andOp = &*it++; + Operation *orOp = &*it++; + Operation *outputOp = hwModule.getBodyBlock()->getTerminator(); + + // Forward reachability from andOp. + { + SparseOpSCC opScc; + opScc.visit(andOp); + + EXPECT_EQ(opScc.getNumDiscovered(), 3u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + + auto topoIt = opScc.topological_begin(); + EXPECT_EQ(andOp, cast(*(topoIt++))); + EXPECT_EQ(orOp, cast(*(topoIt++))); + EXPECT_EQ(outputOp, cast(*(topoIt++))); + EXPECT_EQ(topoIt, opScc.topological_end()); + + auto revTopoIt = opScc.reverseTopological_begin(); + EXPECT_EQ(outputOp, cast(*(revTopoIt++))); + EXPECT_EQ(orOp, cast(*(revTopoIt++))); + EXPECT_EQ(andOp, cast(*(revTopoIt++))); + EXPECT_EQ(revTopoIt, opScc.reverseTopological_end()); + + // getSCC maps each visited op to its own trivial SCC entry. + EXPECT_EQ(cast(opScc.getSCC(andOp)), andOp); + EXPECT_EQ(cast(opScc.getSCC(orOp)), orOp); + EXPECT_EQ(cast(opScc.getSCC(outputOp)), outputOp); + } + + // Inverse reachability from outputOp. + // Follows operands backward; reverse topo of inverse = forward topo: + // [andOp, orOp, outputOp]. + { + SparseOpSCC opScc; + opScc.visit(outputOp); + + EXPECT_EQ(opScc.getNumDiscovered(), 3u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + + auto topoIt = opScc.topological_begin(); + EXPECT_EQ(andOp, cast(*(topoIt++))); + EXPECT_EQ(orOp, cast(*(topoIt++))); + EXPECT_EQ(outputOp, cast(*(topoIt++))); + EXPECT_EQ(topoIt, opScc.topological_end()); + + auto revTopoIt = opScc.reverseTopological_begin(); + EXPECT_EQ(outputOp, cast(*(revTopoIt++))); + EXPECT_EQ(orOp, cast(*(revTopoIt++))); + EXPECT_EQ(andOp, cast(*(revTopoIt++))); + EXPECT_EQ(revTopoIt, opScc.reverseTopological_end()); + } +} + +// reset() clears all accumulated state including cyclic SCC storage; +// re-visiting produces fresh results. +TEST(SparseOpSCCsTest, Reset) { + MLIRContext context; + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + + OwningOpRef module = parseSourceString(cycleIr, &context); + ASSERT_TRUE(module); + + SymbolTable symbolTable(module.get()); + auto hwModule = symbolTable.lookup("cycle"); + ASSERT_TRUE(hwModule); + + auto it = hwModule.getBodyBlock()->begin(); + Operation *regOp = &*it++; + Operation *andOp = &*it++; + Operation *outputOp = hwModule.getBodyBlock()->getTerminator(); + + SparseOpSCC opScc; + opScc.visit(andOp); // discovers CyclicOpSCC{reg,and} + trivial outputOp + EXPECT_EQ(opScc.getNumDiscovered(), 3u); + EXPECT_EQ(opScc.getNumSCCs(), 2u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 1u); + + opScc.reset(); + + // All state cleared, including cyclic SCC storage. + EXPECT_EQ(opScc.getNumDiscovered(), 0u); + EXPECT_EQ(opScc.getNumSCCs(), 0u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + EXPECT_FALSE(opScc.hasDiscovered(regOp)); + EXPECT_FALSE(opScc.hasDiscovered(andOp)); + EXPECT_FALSE(opScc.hasDiscovered(outputOp)); + EXPECT_EQ(opScc.topological_begin(), opScc.topological_end()); + EXPECT_EQ(opScc.reverseTopological_begin(), opScc.reverseTopological_end()); + + // Re-visiting with a boundary seed produces the correct fresh result. + opScc.visit(outputOp); // hw.output has no results: only outputOp discovered. + EXPECT_EQ(opScc.getNumDiscovered(), 1u); + EXPECT_EQ(opScc.getNumSCCs(), 1u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + EXPECT_TRUE(opScc.hasDiscovered(outputOp)); + EXPECT_FALSE(opScc.hasDiscovered(andOp)); +} + +// Passing an empty seed list leaves all counts at zero. +TEST(SparseOpSCCsTest, EmptySeeds) { + SparseOpSCC fwd; + fwd.visit(ArrayRef{}); + EXPECT_EQ(fwd.getNumDiscovered(), 0u); + EXPECT_EQ(fwd.getNumSCCs(), 0u); + EXPECT_EQ(fwd.getNumCyclicSCCs(), 0u); + + SparseOpSCC inv; + inv.visit(ArrayRef{}); + EXPECT_EQ(inv.getNumDiscovered(), 0u); + EXPECT_EQ(inv.getNumSCCs(), 0u); + EXPECT_EQ(inv.getNumCyclicSCCs(), 0u); +} + +// A seed that is a boundary node in the traversal direction is returned as a +// single trivial SCC with no further expansion. +// - Forward from outputOp: hw.output produces no results, so it has no +// forward edges; only outputOp is emitted. +// - Inverse from andOp: both operands are block arguments (no defining op), +// so there are no inverse edges; only andOp is emitted. +TEST(SparseOpSCCsTest, BoundarySeed) { + MLIRContext context; + context.loadDialect(); + context.loadDialect(); + + OwningOpRef module = parseSourceString(ir, &context); + ASSERT_TRUE(module); + + SymbolTable symbolTable(module.get()); + auto hwModule = symbolTable.lookup("test"); + ASSERT_TRUE(hwModule); + + auto it = hwModule.getBodyBlock()->begin(); + Operation *andOp = &*it++; + Operation *outputOp = hwModule.getBodyBlock()->getTerminator(); + + { + SparseOpSCC opScc; + opScc.visit(outputOp); + + EXPECT_EQ(opScc.getNumDiscovered(), 1u); + EXPECT_EQ(opScc.getNumSCCs(), 1u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + + auto topoIt = opScc.topological_begin(); + EXPECT_EQ(cast(*(topoIt++)), outputOp); + EXPECT_EQ(topoIt, opScc.topological_end()); + + // getSCC returns a valid entry for the visited op and null for + // unvisited. + EXPECT_EQ(cast(opScc.getSCC(outputOp)), outputOp); + EXPECT_FALSE(static_cast(opScc.getSCC(andOp))); + } + + { + SparseOpSCC opScc; + opScc.visit(andOp); + + EXPECT_EQ(opScc.getNumDiscovered(), 1u); + EXPECT_EQ(opScc.getNumSCCs(), 1u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + + auto topoIt = opScc.topological_begin(); + EXPECT_EQ(cast(*(topoIt++)), andOp); + EXPECT_EQ(topoIt, opScc.topological_end()); + } +} + +// When one seed is already reachable from another, it is visited during the +// first seed's DFS and silently skipped when the loop encounters it as a seed. +// The result is identical to seeding with just the upstream op. +TEST(SparseOpSCCsTest, OverlappingSeeds) { + MLIRContext context; + context.loadDialect(); + context.loadDialect(); + + OwningOpRef module = parseSourceString(ir, &context); + ASSERT_TRUE(module); + + SymbolTable symbolTable(module.get()); + auto hwModule = symbolTable.lookup("test"); + ASSERT_TRUE(hwModule); + + auto it = hwModule.getBodyBlock()->begin(); + Operation *andOp = &*it++; + Operation *orOp = &*it++; + Operation *outputOp = hwModule.getBodyBlock()->getTerminator(); + + SparseOpSCC opScc; + opScc.visit(andOp); + opScc.visit(orOp); // already visited via andOp — skipped + + EXPECT_EQ(opScc.getNumSCCs(), 3u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + + auto topoIt = opScc.topological_begin(); + EXPECT_EQ(cast(*(topoIt++)), andOp); + EXPECT_EQ(cast(*(topoIt++)), orOp); + EXPECT_EQ(cast(*(topoIt++)), outputOp); + EXPECT_EQ(topoIt, opScc.topological_end()); +} + +// Incremental visit: two independent roots feeding a shared merge node. +// +// %av = comb.and %a, %a : i1 seed 1 (independent) +// %bv = comb.or %b, %b : i1 seed 2 (independent) +// %merge = comb.xor %av, %bv : i1 shared successor +// hw.output %merge : i1 +// +// visit(avOp) discovers {avOp, mergeOp, outputOp}. +// visit(bvOp) discovers only bvOp: mergeOp is already discovered and is +// treated as a cross-edge, not re-entered. +const char *incrementalIr = R"MLIR( + hw.module private @incremental(in %a: i1, in %b: i1, out x: i1) { + %av = comb.and %a, %a : i1 + %bv = comb.or %b, %b : i1 + %merge = comb.xor %av, %bv : i1 + hw.output %merge : i1 + } +)MLIR"; + +TEST(SparseOpSCCsTest, IncrementalVisit) { + MLIRContext context; + context.loadDialect(); + context.loadDialect(); + + OwningOpRef module = + parseSourceString(incrementalIr, &context); + ASSERT_TRUE(module); + + SymbolTable symbolTable(module.get()); + auto hwModule = symbolTable.lookup("incremental"); + ASSERT_TRUE(hwModule); + + auto it = hwModule.getBodyBlock()->begin(); + Operation *avOp = &*it++; + Operation *bvOp = &*it++; + Operation *mergeOp = &*it++; + Operation *outputOp = hwModule.getBodyBlock()->getTerminator(); + + SparseOpSCC opScc; + + // First visit: avOp → mergeOp → outputOp. + opScc.visit(avOp); + EXPECT_EQ(opScc.getNumDiscovered(), 3u); + EXPECT_TRUE(opScc.hasDiscovered(avOp)); + EXPECT_TRUE(opScc.hasDiscovered(mergeOp)); + EXPECT_TRUE(opScc.hasDiscovered(outputOp)); + EXPECT_FALSE(opScc.hasDiscovered(bvOp)); + + // Second visit: bvOp is new; mergeOp/outputOp are cross-edges — not + // re-entered, so only bvOp is added. + opScc.visit(bvOp); + EXPECT_EQ(opScc.getNumDiscovered(), 4u); + EXPECT_TRUE(opScc.hasDiscovered(bvOp)); + EXPECT_EQ(opScc.getNumSCCs(), 4u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + + // mergeOp's SCC was assigned by the first visit and is unchanged. + EXPECT_EQ(cast(opScc.getSCC(mergeOp)), mergeOp); + + // Reverse-topological (leaves first): outputOp, mergeOp, avOp, bvOp. + // bvOp trails because it was emitted by the second DFS. + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), outputOp); + EXPECT_EQ(cast(*(revIt++)), mergeOp); + EXPECT_EQ(cast(*(revIt++)), avOp); + EXPECT_EQ(cast(*(revIt++)), bvOp); + EXPECT_EQ(revIt, opScc.reverseTopological_end()); +} + +// Diamond: %and splits into two branches that join at %merge; no cycle. +// +// %and = comb.and %a, %b : i1 split node +// %or = comb.or %and, %a : i1 left branch +// %xor = comb.xor %and, %b : i1 right branch +// %merge = comb.or %or, %xor : i1 join node +// hw.output %merge : i1 +// +// Forward edges: and->{or,xor}, or->merge, xor->merge, merge->output. +// Reverse-topo order: outputOp, mergeOp, {orOp,xorOp} (order unspecified), +// andOp. +const char *diamondIr = R"MLIR( + hw.module private @diamond(in %a: i1, in %b: i1, out x: i1) { + %and = comb.and %a, %b : i1 + %or = comb.or %and, %a : i1 + %xor = comb.xor %and, %b : i1 + %merge = comb.or %or, %xor : i1 + hw.output %merge : i1 + } +)MLIR"; + +TEST(SparseOpSCCsTest, DiamondNoCycle) { + MLIRContext context; + context.loadDialect(); + context.loadDialect(); + + OwningOpRef module = + parseSourceString(diamondIr, &context); + ASSERT_TRUE(module); + + SymbolTable symbolTable(module.get()); + auto hwModule = symbolTable.lookup("diamond"); + ASSERT_TRUE(hwModule); + + auto it = hwModule.getBodyBlock()->begin(); + Operation *andOp = &*it++; + Operation *orOp = &*it++; + Operation *xorOp = &*it++; + Operation *mergeOp = &*it++; + Operation *outputOp = hwModule.getBodyBlock()->getTerminator(); + + // Without filter: all five ops, orOp/xorOp order is DFS-dependent. + { + SparseOpSCC opScc; + opScc.visit(andOp); + + EXPECT_EQ(opScc.getNumSCCs(), 5u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), outputOp); + EXPECT_EQ(cast(*(revIt++)), mergeOp); + std::array middle = {cast(*(revIt++)), + cast(*(revIt++))}; + EXPECT_TRUE(llvm::is_contained(middle, orOp)); + EXPECT_TRUE(llvm::is_contained(middle, xorOp)); + EXPECT_EQ(cast(*(revIt++)), andOp); + EXPECT_EQ(revIt, opScc.reverseTopological_end()); + } + + // Forward direction with edge filter blocking xorOp as destination: only + // the left branch (andOp->orOp->mergeOp->outputOp) is visited. xorOp is + // never reachable, so 4 trivial SCCs in deterministic reverseTopological + // order: outputOp, mergeOp, orOp, andOp. + { + auto filter = [&](Operation *op, OpOperand &) { return op != xorOp; }; + SparseOpSCC opScc(filter); + opScc.visit(andOp); + + EXPECT_EQ(opScc.getNumSCCs(), 4u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), outputOp); + EXPECT_EQ(cast(*(revIt++)), mergeOp); + EXPECT_EQ(cast(*(revIt++)), orOp); + EXPECT_EQ(cast(*(revIt++)), andOp); + EXPECT_EQ(revIt, opScc.reverseTopological_end()); + } + + // Backward direction from outputOp with edge filter blocking xorOp as + // source: the backward path through xorOp is cut, so only the left branch + // (orOp) is visited (outputOp<-mergeOp<-orOp<-andOp). 4 trivial SCCs in + // reverseTopological order (sinks-first): outputOp, mergeOp, orOp, andOp. + { + auto filter = [&](Operation *src, OpOperand &) { return src != xorOp; }; + SparseOpSCC opScc(filter); + opScc.visit(outputOp); + + EXPECT_EQ(opScc.getNumSCCs(), 4u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), outputOp); + EXPECT_EQ(cast(*(revIt++)), mergeOp); + EXPECT_EQ(cast(*(revIt++)), orOp); + EXPECT_EQ(cast(*(revIt++)), andOp); + EXPECT_EQ(revIt, opScc.reverseTopological_end()); + } +} + +// Graph with a cycle: "and" uses "reg"'s result, "reg" uses "and"'s result. +// +// %reg = seq.firreg %and clock %clock reset sync %reset, %and : i1 +// %and = comb.and %reg, %a : i1 +// hw.output %and : i1 +// +// %and feeds reg.next (the cycle edge) and reg.resetValue. +// Block order: regOp, andOp, outputOp (terminator). +// Forward edges: reg->and, and->{reg(next), reg(resetValue), output}. + +TEST(SparseOpSCCsTest, CycleWithRegister) { + MLIRContext context; + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + + OwningOpRef module = parseSourceString(cycleIr, &context); + ASSERT_TRUE(module); + + SymbolTable symbolTable(module.get()); + auto hwModule = symbolTable.lookup("cycle"); + ASSERT_TRUE(hwModule); + + auto it = hwModule.getBodyBlock()->begin(); + Operation *regOp = &*it++; + Operation *andOp = &*it++; + Operation *outputOp = hwModule.getBodyBlock()->getTerminator(); + + // Without filter: "reg" and "and" form a cycle -> one CyclicOpSCC. + { + SparseOpSCC opScc; + opScc.visit(andOp); + + EXPECT_EQ(opScc.getNumSCCs(), 2u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 1u); + + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), outputOp); + + CyclicOpSCC scc = cast(*revIt++); + EXPECT_EQ(scc.size(), 2u); + EXPECT_TRUE(llvm::is_contained(scc, regOp)); + EXPECT_TRUE(llvm::is_contained(scc, andOp)); + EXPECT_EQ(revIt, opScc.reverseTopological_end()); + + // getSCC: both cycle members map to the same CyclicOpSCC entry. + EXPECT_EQ(cast(opScc.getSCC(regOp)), scc); + EXPECT_EQ(cast(opScc.getSCC(andOp)), scc); + EXPECT_EQ(cast(opScc.getSCC(outputOp)), outputOp); + } + + // With filter that excludes regOp: no cycle, both ops are trivial. + { + auto filter = [&](Operation *op, OpOperand &) { return op != regOp; }; + SparseOpSCC opScc(filter); + opScc.visit(andOp); + + EXPECT_EQ(opScc.getNumSCCs(), 2u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + + auto topoIt = opScc.topological_begin(); + EXPECT_EQ(cast(*(topoIt++)), andOp); + EXPECT_EQ(cast(*(topoIt++)), outputOp); + EXPECT_EQ(topoIt, opScc.topological_end()); + } + + // With edge filter blocking only reg's 'next' operand: %and also drives + // reg.resetValue, so the cycle persists through the reset edge. + { + auto regEdgeFilter = [](Operation *, OpOperand &operand) -> bool { + if (auto firReg = dyn_cast(operand.getOwner())) + return operand != firReg.getNextMutable(); + return true; + }; + SparseOpSCC opScc(regEdgeFilter); + opScc.visit(andOp); + + EXPECT_EQ(opScc.getNumSCCs(), 2u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 1u); + + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), outputOp); + + CyclicOpSCC scc = cast(*revIt++); + ASSERT_EQ(scc.size(), 2u); + EXPECT_TRUE(llvm::is_contained(scc, regOp)); + EXPECT_TRUE(llvm::is_contained(scc, andOp)); + EXPECT_EQ(revIt, opScc.reverseTopological_end()); + } + + // Backward direction from outputOp with regEdgeFilter: the backward path + // through reg.resetValue keeps the cycle alive even though reg.next is + // blocked. reverseTopological order (sinks-first): outputOp, then + // CyclicOpSCC{regOp,andOp}. + { + auto regEdgeFilter = [](Operation *, OpOperand &operand) -> bool { + if (auto firReg = dyn_cast(operand.getOwner())) + return operand != firReg.getNextMutable(); + return true; + }; + SparseOpSCC opScc(regEdgeFilter); + opScc.visit(outputOp); + + EXPECT_EQ(opScc.getNumSCCs(), 2u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 1u); + + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), outputOp); + CyclicOpSCC scc = cast(*revIt++); + ASSERT_EQ(scc.size(), 2u); + EXPECT_TRUE(llvm::is_contained(scc, regOp)); + EXPECT_TRUE(llvm::is_contained(scc, andOp)); + EXPECT_EQ(revIt, opScc.reverseTopological_end()); + } +} + +// Two disjoint register cycles sharing a common downstream merge node. +// +// %reg1 = seq.firreg %and1 clock %clock : i1 -. cycle 1 +// %and1 = comb.and %reg1, %a : i1 -' +// %reg2 = seq.firreg %or1 clock %clock : i1 -. cycle 2 +// %or1 = comb.or %reg2, %a : i1 -' +// %xor = comb.xor %and1, %or1 : i1 merge node +// hw.output %xor : i1 +// +// Seeds: {and1Op, or1Op}. +// Expected reverse-topo: [outputOp, xorOp, SCC{reg1,and1}, SCC{reg2,or1}] +const char *twoCyclesIr = R"MLIR( + hw.module private @twocycles(in %clock: !seq.clock, in %a: i1, out x: i1) { + %reg1 = seq.firreg %and1 clock %clock : i1 + %and1 = comb.and %reg1, %a : i1 + %reg2 = seq.firreg %or1 clock %clock : i1 + %or1 = comb.or %reg2, %a : i1 + %xor = comb.xor %and1, %or1 : i1 + hw.output %xor : i1 + } +)MLIR"; + +TEST(SparseOpSCCsTest, TwoCycles) { + MLIRContext context; + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + + OwningOpRef module = + parseSourceString(twoCyclesIr, &context); + ASSERT_TRUE(module); + + SymbolTable symbolTable(module.get()); + auto hwModule = symbolTable.lookup("twocycles"); + ASSERT_TRUE(hwModule); + + auto it = hwModule.getBodyBlock()->begin(); + Operation *reg1Op = &*it++; + Operation *and1Op = &*it++; + Operation *reg2Op = &*it++; + Operation *or1Op = &*it++; + Operation *xorOp = &*it++; + Operation *outputOp = hwModule.getBodyBlock()->getTerminator(); + + SparseOpSCC opScc; + opScc.visit(and1Op); + opScc.visit(or1Op); + + EXPECT_EQ(opScc.getNumSCCs(), 4u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 2u); + + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), outputOp); + EXPECT_EQ(cast(*(revIt++)), xorOp); + + // First cyclic SCC: reg1 <-> and1. + CyclicOpSCC scc0 = cast(*revIt++); + EXPECT_EQ(scc0.size(), 2u); + EXPECT_TRUE(llvm::is_contained(scc0, reg1Op)); + EXPECT_TRUE(llvm::is_contained(scc0, and1Op)); + + EXPECT_EQ(opScc.getSCC(reg1Op), opScc.getSCC(and1Op)); + EXPECT_NE(opScc.getSCC(reg1Op), opScc.getSCC(or1Op)); + EXPECT_NE(opScc.getSCC(reg1Op), opScc.getSCC(xorOp)); + EXPECT_NE(opScc.getSCC(reg1Op), opScc.getSCC(hwModule)); + + // Second cyclic SCC: reg2 <-> or1. + CyclicOpSCC scc1 = cast(*revIt++); + EXPECT_EQ(scc1.size(), 2u); + EXPECT_TRUE(llvm::is_contained(scc1, reg2Op)); + EXPECT_TRUE(llvm::is_contained(scc1, or1Op)); + + EXPECT_EQ(opScc.getSCC(reg2Op), opScc.getSCC(or1Op)); + EXPECT_NE(opScc.getSCC(or1Op), opScc.getSCC(and1Op)); + EXPECT_NE(opScc.getSCC(or1Op), opScc.getSCC(xorOp)); + EXPECT_NE(opScc.getSCC(or1Op), opScc.getSCC(hwModule)); + + EXPECT_EQ(revIt, opScc.reverseTopological_end()); + + // Backward from xorOp: discovers both cycles plus xorOp itself. + // outputOp is downstream of xorOp in the forward graph and is not reached. + // reverseTopological: xorOp first (forward-graph leaf), then the two cyclic + // SCCs in unspecified order. + { + SparseOpSCC opScc; + opScc.visit(xorOp); + + EXPECT_EQ(opScc.getNumDiscovered(), 5u); // reg1, and1, reg2, or1, xor + EXPECT_EQ(opScc.getNumSCCs(), 3u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 2u); + EXPECT_FALSE(opScc.hasDiscovered(outputOp)); + + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), xorOp); + + std::array cycles = {cast(*revIt++), + cast(*revIt++)}; + EXPECT_TRUE(llvm::any_of(cycles, [&](CyclicOpSCC scc) { + return scc.size() == 2 && llvm::is_contained(scc, reg1Op) && + llvm::is_contained(scc, and1Op); + })); + EXPECT_TRUE(llvm::any_of(cycles, [&](CyclicOpSCC scc) { + return scc.size() == 2 && llvm::is_contained(scc, reg2Op) && + llvm::is_contained(scc, or1Op); + })); + + EXPECT_EQ(revIt, opScc.reverseTopological_end()); + } +} + +// One large SCC containing two internal cycles and five operations. +// +// %r0 = seq.firreg %and1 ... -. cycle A: r0 <-> and1 +// %and1 = comb.and %r0, %r2 -' (and1 also drives xor4 and output) +// %r2 = seq.firreg %or3 ... -. cycle B: r2 <-> or3 +// %or3 = comb.or %r2, %xor4 -' +// %xor4 = comb.xor %and1, %a bridge: and1 -> xor4 -> or3 -> r2 -> and1 +// +// All five are mutually reachable -> single CyclicOpSCC of size 5. +const char *twoInternalCyclesIr = R"MLIR( + hw.module private @largeSCC(in %clock: !seq.clock, in %a: i1, out x: i1) { + %r0 = seq.firreg %and1 clock %clock : i1 + %and1 = comb.and %r0, %r2 : i1 + %r2 = seq.firreg %or3 clock %clock : i1 + %or3 = comb.or %r2, %xor4 : i1 + %xor4 = comb.xor %and1, %a : i1 + hw.output %and1 : i1 + } +)MLIR"; + +TEST(SparseOpSCCsTest, TwoInternalCycles) { + MLIRContext context; + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + + OwningOpRef module = + parseSourceString(twoInternalCyclesIr, &context); + ASSERT_TRUE(module); + + SymbolTable symbolTable(module.get()); + auto hwModule = symbolTable.lookup("largeSCC"); + ASSERT_TRUE(hwModule); + + auto it = hwModule.getBodyBlock()->begin(); + Operation *r0Op = &*it++; + Operation *and1Op = &*it++; + Operation *r2Op = &*it++; + Operation *or3Op = &*it++; + Operation *xor4Op = &*it++; + Operation *outputOp = hwModule.getBodyBlock()->getTerminator(); + + auto checkFiveOpSCC = [&](CyclicOpSCC scc) { + EXPECT_EQ(scc.size(), 5u); + EXPECT_TRUE(llvm::is_contained(scc, r0Op)); + EXPECT_TRUE(llvm::is_contained(scc, and1Op)); + EXPECT_TRUE(llvm::is_contained(scc, r2Op)); + EXPECT_TRUE(llvm::is_contained(scc, or3Op)); + EXPECT_TRUE(llvm::is_contained(scc, xor4Op)); + }; + + // Forward direction from and1Op: outputOp is the only leaf; the five-op SCC + // follows. + { + SparseOpSCC opScc; + opScc.visit(and1Op); + + EXPECT_EQ(opScc.getNumSCCs(), 2u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 1u); + + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), outputOp); + auto cyclicSCC = cast(*revIt++); + checkFiveOpSCC(cyclicSCC); + EXPECT_TRUE(llvm::all_equal(llvm::map_range( + cyclicSCC, [&](Operation *op) -> OpSCC { return opScc.getSCC(op); }))); + EXPECT_EQ(revIt, opScc.reverseTopological_end()); + } + + // Inverse direction from outputOp: the five-op SCC precedes outputOp + // (topological order: predecessors before successors). + { + SparseOpSCC opScc; + opScc.visit(outputOp); + + EXPECT_EQ(opScc.getNumSCCs(), 2u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 1u); + + auto topoIt = opScc.topological_begin(); + auto cyclicSCC = cast(*topoIt++); + checkFiveOpSCC(cyclicSCC); + EXPECT_TRUE(llvm::all_equal(llvm::map_range( + cyclicSCC, [&](Operation *op) -> OpSCC { return opScc.getSCC(op); }))); + EXPECT_EQ(cast(*(topoIt++)), outputOp); + EXPECT_EQ(topoIt, opScc.topological_end()); + } +} + +// Self-loop: comb.and that uses its own result as one operand. +const char *selfLoopIr = R"MLIR( + hw.module private @selfloop(in %a: i1, out x: i1) { + %and = comb.and %and, %a : i1 + hw.output %and : i1 + } +)MLIR"; + +TEST(SparseOpSCCsTest, SelfLoop) { + MLIRContext context; + context.loadDialect(); + context.loadDialect(); + + OwningOpRef module = + parseSourceString(selfLoopIr, &context); + ASSERT_TRUE(module); + + SymbolTable symbolTable(module.get()); + auto hwModule = symbolTable.lookup("selfloop"); + ASSERT_TRUE(hwModule); + + auto it = hwModule.getBodyBlock()->begin(); + Operation *andOp = &*it++; + Operation *outputOp = hwModule.getBodyBlock()->getTerminator(); + + // Without filter: andOp is a size-1 CyclicOpSCC due to self-loop. + { + SparseOpSCC opScc; + opScc.visit(andOp); + + EXPECT_EQ(opScc.getNumSCCs(), 2u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 1u); + + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), outputOp); + EXPECT_EQ(dyn_cast(*revIt), nullptr); + CyclicOpSCC scc = cast(*revIt++); + EXPECT_EQ(scc.size(), 1u); + EXPECT_EQ(scc[0], andOp); + EXPECT_EQ(revIt, opScc.reverseTopological_end()); + } + + // With edge filter blocking the self-loop edge: andOp becomes trivial. + { + auto filter = [&](Operation *dest, OpOperand &) { return dest != andOp; }; + SparseOpSCC opScc(filter); + opScc.visit(andOp); + + EXPECT_EQ(opScc.getNumSCCs(), 2u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), outputOp); + EXPECT_EQ(cast(*(revIt++)), andOp); + EXPECT_EQ(revIt, opScc.reverseTopological_end()); + } +} + +// Self-loop via register: both 'next' and 'resetValue' feed back into the +// register itself, creating two independent self-loop edges. +// +// %reg = seq.firreg %reg clock %clock reset sync %reset, %reg : i1 +// +// Operand 0 (next) and operand 3 (resetValue) are both self-loop edges. +const char *selfLoopRegIr = R"MLIR( + hw.module private @selfloopReg(in %clock: !seq.clock, in %reset: i1, out x: i1) { + %reg = seq.firreg %reg clock %clock reset sync %reset, %reg : i1 + hw.output %reg : i1 + } +)MLIR"; + +TEST(SparseOpSCCsTest, SelfLoopRegWithEdgeFilter) { + MLIRContext context; + context.loadDialect(); + context.loadDialect(); + + OwningOpRef module = + parseSourceString(selfLoopRegIr, &context); + ASSERT_TRUE(module); + + SymbolTable symbolTable(module.get()); + auto hwModule = symbolTable.lookup("selfloopReg"); + ASSERT_TRUE(hwModule); + + auto it = hwModule.getBodyBlock()->begin(); + Operation *regOp = &*it++; + Operation *outputOp = hwModule.getBodyBlock()->getTerminator(); + + // Without filter: both self-loop edges are visible -> CyclicOpSCC of size 1. + { + SparseOpSCC opScc; + opScc.visit(regOp); + + EXPECT_EQ(opScc.getNumSCCs(), 2u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 1u); + + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), outputOp); + CyclicOpSCC scc = cast(*revIt++); + ASSERT_EQ(scc.size(), 1u); + EXPECT_EQ(scc[0], regOp); + EXPECT_EQ(revIt, opScc.reverseTopological_end()); + } + + // Filter blocking only 'next': resetValue self-loop edge is still + // traversable -> still classified as CyclicOpSCC. + { + auto regEdgeFilter = [](Operation *, OpOperand &operand) -> bool { + if (auto firReg = dyn_cast(operand.getOwner())) + return operand != firReg.getNextMutable(); + return true; + }; + SparseOpSCC opScc(regEdgeFilter); + opScc.visit(regOp); + + EXPECT_EQ(opScc.getNumSCCs(), 2u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 1u); + } + + // Filter blocking all edges into regOp: both self-loop edges blocked -> + // regOp becomes trivial. + { + auto filter = [&](Operation *dest, OpOperand &) { return dest != regOp; }; + SparseOpSCC opScc(filter); + opScc.visit(regOp); + + EXPECT_EQ(opScc.getNumSCCs(), 2u); + EXPECT_EQ(opScc.getNumCyclicSCCs(), 0u); + + auto revIt = opScc.reverseTopological_begin(); + EXPECT_EQ(cast(*(revIt++)), outputOp); + EXPECT_EQ(cast(*(revIt++)), regOp); + EXPECT_EQ(revIt, opScc.reverseTopological_end()); + } +} + +// IR for casting test: a reg<->and cycle, a leaf output, and a disconnected +// or-op that is never reached by a forward DFS from andOp. +const char *castingIr = R"MLIR( + hw.module private @casting(in %clock: !seq.clock, in %reset: i1, in %a: i1, out x: i1) { + %reg = seq.firreg %and clock %clock reset sync %reset, %and : i1 + %and = comb.and %reg, %a : i1 + %unused = comb.or %a, %a : i1 + hw.output %and : i1 + } +)MLIR"; + +// Verify isa / dyn_cast behaviour across all three OpSCC variants. +TEST(SparseOpSCCsTest, OpSCCCasting) { + MLIRContext context; + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + + OwningOpRef module = + parseSourceString(castingIr, &context); + ASSERT_TRUE(module); + + SymbolTable symbolTable(module.get()); + auto hwModule = symbolTable.lookup("casting"); + ASSERT_TRUE(hwModule); + + auto it = hwModule.getBodyBlock()->begin(); + Operation *regOp = &*it++; + Operation *andOp = &*it++; + Operation *unusedOp = &*it++; + Operation *outputOp = hwModule.getBodyBlock()->getTerminator(); + + SparseOpSCC opScc; + opScc.visit(andOp); + + // Obtain one entry of each kind via getSCC. + OpSCC trivialEntry = opScc.getSCC(outputOp); // trivial: no cycle + OpSCC cyclicEntry = opScc.getSCC(andOp); // cyclic: reg <-> and + OpSCC nullEntry = opScc.getSCC(unusedOp); // not reachable from andOp + + // isa<> correctly identifies each variant. + EXPECT_TRUE(isa(trivialEntry)); + EXPECT_FALSE(isa(trivialEntry)); + + EXPECT_FALSE(isa(cyclicEntry)); + EXPECT_TRUE(isa(cyclicEntry)); + + EXPECT_FALSE(isa(nullEntry)); + EXPECT_FALSE(isa(nullEntry)); + + // An undiscovered entry is bool-false; present entries are bool-true. + EXPECT_FALSE(static_cast(nullEntry)); + EXPECT_TRUE(static_cast(trivialEntry)); + EXPECT_TRUE(static_cast(cyclicEntry)); + + // dyn_cast on known-present values. + EXPECT_EQ(dyn_cast(trivialEntry), outputOp); + EXPECT_TRUE(static_cast(dyn_cast(trivialEntry))); + EXPECT_FALSE(static_cast(dyn_cast(trivialEntry))); + + EXPECT_EQ(dyn_cast(cyclicEntry), nullptr); + EXPECT_FALSE(static_cast(dyn_cast(cyclicEntry))); + EXPECT_TRUE(static_cast(dyn_cast(cyclicEntry))); + + EXPECT_EQ(dyn_cast(cyclicEntry), opScc.getSCC(regOp)); + EXPECT_NE(dyn_cast(cyclicEntry), opScc.getSCC(outputOp)); + EXPECT_NE(dyn_cast(cyclicEntry), opScc.getSCC(unusedOp)); + + EXPECT_FALSE(static_cast(dyn_cast(trivialEntry))); + if (auto scc = dyn_cast(cyclicEntry)) { + EXPECT_EQ(scc.size(), 2u); + EXPECT_TRUE(llvm::is_contained(scc, regOp)); + EXPECT_TRUE(llvm::is_contained(scc, andOp)); + } else { + FAIL() << "expected CyclicOpSCC for cyclicEntry"; + } + + // dyn_cast_if_present additionally handles null entries gracefully. + EXPECT_EQ(dyn_cast_if_present(nullEntry), nullptr); + EXPECT_FALSE(static_cast(dyn_cast_if_present(nullEntry))); + EXPECT_FALSE(static_cast(dyn_cast_if_present(nullEntry))); +} + +} // namespace