Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ build/
!arch_spec_example.yaml
!architecture.yaml
/test/**/*.asm
/test/**/*.json
/test/**/.sh
.lit_test_times.txt
lit.cfg
*.dot
Expand Down
280 changes: 272 additions & 8 deletions lib/NeuraDialect/Transforms/CanonicalizeLiveInPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <string>

Expand All @@ -17,24 +21,275 @@ using namespace mlir;
#include "NeuraDialect/NeuraPasses.h.inc"

namespace {
LogicalResult promoteLiveInValuesToBlockArgs(Region &region) {
struct DirectDataflowLiveIn {
// The live-in value.
Value value;
// The block where the live-in value is defined.
Block *defining_block;
// The block where the live-in value is used.
Block *using_block;
};

bool pathsCrossConditionalBranch(Block *defining_block, Block *using_block,
DominanceInfo &dom_info,
PostDominanceInfo &post_dom_info) {

// 1. defining_block must dominate using_block.
// This ensures that all paths to using_block go through defining_block.
if (!dom_info.dominates(defining_block, using_block)) {
return false;
}

// 2. using_block must post-dominate defining_block.
// This ensures that all paths from defining_block eventually reach
// using_block.
if (!post_dom_info.postDominates(using_block, defining_block)) {
return false;
}

// 3. If defining_block and using_block are the same, or using_block is a
// direct successor of defining_block, then there are no conditional branches
// on the path.
if (defining_block == using_block) {
return false;
}

if (defining_block == &defining_block->getParent()->front()) {
// If defining_block is the entry block of the region, it is not considered
// as crossing a conditional branch.
// Avoids violating assertions in TransformCtrlToDataFlowPass.cpp.
return false;
}

// 4. Checks if using_block is a direct successor (no intermediate blocks) of
// defining_block.
for (Block *succ : defining_block->getSuccessors()) {
if (succ == using_block) {
Operation *term_op = defining_block->getTerminator();
// If the terminator is an unconditional branch, then no conditional
// branch exists on the path.
if (isa<neura::Br>(term_op)) {
return false;
}
// If it is a conditional branch, but both targets are using_block, it is
// also considered no real branch.
if (auto cond_br = dyn_cast<neura::CondBr>(term_op)) {
if (cond_br.getTrueDest() == using_block &&
cond_br.getFalseDest() == using_block) {
return false;
}
}
}
}

// 5. Finds any conditional branch on the paths from defining_block to
// using_block.
bool found_conditional_branch = false;
Block *conditional_branch_block = nullptr;

Region *region = defining_block->getParent();
for (Block &block : region->getBlocks()) {
if (&block == defining_block || &block == using_block) {
continue;
}

// Checks if this block is on the path from defining_block to using_block.
if (dom_info.dominates(defining_block, &block) &&
dom_info.dominates(&block, using_block)) {

// Checks if this block's terminator is a conditional branch.
Operation *term_op = block.getTerminator();
if (auto cond_br = dyn_cast<neura::CondBr>(term_op)) {
Block *true_dest = cond_br.getTrueDest();
Block *false_dest = cond_br.getFalseDest();

// Ensures both branch targets are different (true conditional branch)
if (true_dest != false_dest) {
found_conditional_branch = true;
conditional_branch_block = &block;
break;
}
}
}
}

// 6. Checks the terminator of defining_block itself.
Operation *defining_term = defining_block->getTerminator();
if (auto cond_br = dyn_cast<neura::CondBr>(defining_term)) {
Block *true_dest = cond_br.getTrueDest();
Block *false_dest = cond_br.getFalseDest();
if (true_dest != false_dest) {
found_conditional_branch = true;
conditional_branch_block = defining_block;
}
}

if (!found_conditional_branch) {
return false;
}

// 7. Key Constraint: Verifies that BOTH branches eventually reach using_block
// WITHOUT creating a loop back to conditional_branch_block or earlier.
assert(conditional_branch_block &&
"Must have found a conditional branch block");

Operation *cond_term = conditional_branch_block->getTerminator();
auto cond_br = dyn_cast<neura::CondBr>(cond_term);
assert(cond_br && "Must be a conditional branch");

Block *true_dest = cond_br.getTrueDest();
Block *false_dest = cond_br.getFalseDest();

// Checks loop back edge: If either branch goes back to the conditional branch
// block or any of its dominators, it creates a loop.
if (true_dest == conditional_branch_block ||
dom_info.dominates(true_dest, conditional_branch_block)) {
llvm::errs()
<< "[CanoLiveIn] True branch creates a back edge (loop pattern)\n";
return false;
}

if (false_dest == conditional_branch_block ||
dom_info.dominates(false_dest, conditional_branch_block)) {
llvm::errs()
<< "[CanoLiveIn] False branch creates a back edge (loop pattern)\n";
return false;
}

// Checks if both branches can reach using_block.
bool true_reaches = (true_dest == using_block);
if (!true_reaches) {
if (dom_info.dominates(true_dest, using_block)) {
true_reaches = true;
} else {
for (Block *pred : using_block->getPredecessors()) {
if (pred == true_dest || dom_info.dominates(true_dest, pred)) {
true_reaches = true;
break;
}
}
}
}

bool false_reaches = (false_dest == using_block);
if (!false_reaches) {
if (dom_info.dominates(false_dest, using_block)) {
false_reaches = true;
} else {
for (Block *pred : using_block->getPredecessors()) {
if (pred == false_dest || dom_info.dominates(false_dest, pred)) {
false_reaches = true;
break;
}
}
}
}

if (!true_reaches || !false_reaches) {
return false;
}

return true;
}

DenseMap<Block *, SmallVector<DirectDataflowLiveIn>>
identifyDirectDataflowLiveIns(Region &region, DominanceInfo &dom_info,
PostDominanceInfo &post_dom_info) {
DenseMap<Block *, SmallVector<DirectDataflowLiveIn>>
using_block_to_direct_dataflow_live_ins;
for (Block &block : region.getBlocks()) {
// Skips the entry block.
if (&block == &region.front()) {
continue;
}

// Collects direct live-in values for the block.
SetVector<Value> live_ins;
for (Operation &op : block.getOperations()) {
for (Value operand : op.getOperands()) {
// If the operand is defined in another block, it is a live-in value.
if (auto block_arg = dyn_cast<BlockArgument>(operand)) {
if (block_arg.getOwner() != &block) {
live_ins.insert(operand);
}
} else {
Operation *def_op = operand.getDefiningOp();
if (def_op && def_op->getBlock() != &block) {
live_ins.insert(operand);
}
}
}
}

// Checks each live-in value to see if it has direct dataflow dependency.
for (Value live_in : live_ins) {
Block *defining_block = nullptr;

if (auto block_arg = dyn_cast<BlockArgument>(live_in)) {
defining_block = block_arg.getOwner();
} else {
Operation *def_op = live_in.getDefiningOp();
if (def_op) {
defining_block = def_op->getBlock();
}
}

if (!defining_block) {
continue;
}

if (pathsCrossConditionalBranch(defining_block, &block, dom_info,
post_dom_info)) {
DirectDataflowLiveIn direct_dataflow_live_in;
direct_dataflow_live_in.value = live_in;
direct_dataflow_live_in.defining_block = defining_block;
direct_dataflow_live_in.using_block = &block;

using_block_to_direct_dataflow_live_ins[&block].push_back(
direct_dataflow_live_in);
}
}
}
return using_block_to_direct_dataflow_live_ins;
}

LogicalResult promoteLiveInValuesToBlockArgs(Region &region,
DominanceInfo &dom_info,
PostDominanceInfo &post_dom_info) {
if (region.empty()) {
return success();
}

DenseMap<Block *, SmallVector<DirectDataflowLiveIn>>
direct_dataflow_live_ins =
identifyDirectDataflowLiveIns(region, dom_info, post_dom_info);

// Maps each block to its direct dataflow live-in values.
DenseMap<Block *, SetVector<Value>> direct_dataflow_live_in_values;
for (auto &[block, dataflow_live_ins] : direct_dataflow_live_ins) {
for (auto &dataflow_live_in : dataflow_live_ins) {
direct_dataflow_live_in_values[block].insert(dataflow_live_in.value);
}
}

// Collects direct live-in values for each block in the region.
// Without considering the transitive dependencies.
DenseMap<Block *, SetVector<Value>> direct_live_ins;

Block &entry_block = region.front();
// Initializes the direct live-ins for each block.
for (Block &block : region.getBlocks()) {
if (&block == &entry_block) {
if (&block == &region.front()) {
continue;
}

SetVector<Value> live_ins;
for (Operation &op : block.getOperations()) {
for (Value operand : op.getOperands()) {
// If the operand is a direct dataflow live-in value, skip it.
if (direct_dataflow_live_in_values[&block].contains(operand)) {
continue;
}

// If the operand is defined in another block, it is a live-in value.
if (auto block_arg = dyn_cast<BlockArgument>(operand)) {
if (block_arg.getOwner() != &block) {
Expand All @@ -54,9 +309,9 @@ LogicalResult promoteLiveInValuesToBlockArgs(Region &region) {
}
}

// If we update a branch or conditional branch, we may introduce new live-ins
// for a block. So we need to propagate live-in values until a fixed point is
// reached.
// If we update a branch or conditional branch, we may introduce new
// live-ins for a block. So we need to propagate live-in values until a
// fixed point is reached.

// *************************************************************************
// For example, consider this control flow:
Expand Down Expand Up @@ -119,6 +374,12 @@ LogicalResult promoteLiveInValuesToBlockArgs(Region &region) {
// Checks if the live-in value in successor block is defined in the
// current block.
for (Value live_in : succ_live_ins) {
// If it is a direct dataflow live-in value for the successor block,
// we skip it.
if (direct_dataflow_live_in_values[succ_block].contains(live_in)) {
continue;
}

// If it is defined in the current block, that means it is not a
// live-in value for the current block. We can skip it.
if (Operation *def_op = live_in.getDefiningOp()) {
Expand Down Expand Up @@ -271,7 +532,6 @@ LogicalResult promoteLiveInValuesToBlockArgs(Region &region) {
}
}
}

return success();
}

Expand Down Expand Up @@ -313,7 +573,11 @@ struct CanonicalizeLiveInPass
return;
}

if (failed(promoteLiveInValuesToBlockArgs(*region))) {
DominanceInfo dom_info(op);
PostDominanceInfo post_dom_info(op);

if (failed(promoteLiveInValuesToBlockArgs(*region, dom_info,
post_dom_info))) {
signalPassFailure();
return;
}
Expand Down
16 changes: 8 additions & 8 deletions lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,19 +488,19 @@ void transformControlFlowToDataFlow(Region &region, ControlFlowInfo &ctrl_info,
// Sorts blocks by reverse post-order traversal to maintain SSA dominance.
Block *entry_block = &region.front();
SmallVector<Block *> blocks_to_flatten;

// Uses reverse post-order: visit successors before predecessors.
// This ensures that when we move blocks, definitions come before uses.
llvm::SetVector<Block *> visited;
// Post-order traversal result, used for sorting blocks.
SmallVector<Block *> po_order;
std::function<void(Block*)> po_traverse = [&](Block *block) {

std::function<void(Block *)> po_traverse = [&](Block *block) {
// Records visited block and skips if already visited.
if (!visited.insert(block)) {
return;
}

// Visits successors first (post-order).
Operation *terminator = block->getTerminator();
if (auto br = dyn_cast<neura::Br>(terminator)) {
Expand All @@ -509,16 +509,16 @@ void transformControlFlowToDataFlow(Region &region, ControlFlowInfo &ctrl_info,
po_traverse(cond_br.getTrueDest());
po_traverse(cond_br.getFalseDest());
}

// Adds to post-order.
po_order.push_back(block);
};

po_traverse(entry_block);

// Reverses post-order for forward traversal.
SmallVector<Block *> rpo_order(po_order.rbegin(), po_order.rend());

// Collects non-entry blocks in RPO order.
for (Block *block : rpo_order) {
if (block != entry_block) {
Expand Down
Loading