Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 54 additions & 16 deletions lib/NeuraDialect/Transforms/CanonicalizeLiveInPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,41 +82,78 @@ LogicalResult promoteLiveInValuesToBlockArgs(Region &region) {
// If we update a branch or conditional branch, we may introduce new live-ins
// for a block. So we need to propagate live-in values until a fixed point is
// reached.

// *************************************************************************
// For example, consider this control flow:
//
// Block A:
// %0 = constant 1
// %1 = constant 2
// br B
//
// Block B:
// br C
//
// Block C:
// %2 = add %0, %1 // %0 and %1 are live-ins for block C
// return
//
// Initial direct_live_ins analysis:
// - Block C: {%0, %1} (directly used values defined outside C)
// - Block B: {} (no direct use of external values)
//
// After propagation:
// - Block C: {%0, %1}
// - Block B: {%0, %1} (needs to pass these values to C)
//
// The transformation adds block arguments:
// Block A:
// %0 = constant 1
// %1 = constant 2
// br B(%0, %1)
//
// Block B(%b0, %b1):
// br C(%b0, %b1)
//
// Block C(%c0, %c1):
// %2 = add %c0, %c1
// return
// *************************************************************************
DenseMap<Block *, SetVector<Value>> all_live_ins = direct_live_ins;
bool changed = true;

while (changed) {
changed = false;

for (Block &pred_block : region.getBlocks()) {
if (&pred_block == &region.front()) {
for (Block &current_block : region.getBlocks()) {
if (&current_block == &region.front()) {
continue;
}

// Checks if the predecessor block has successor blocks and if they have
// Checks if current block has successor blocks and if they have
// any live-ins.
for (Block *succ_block : pred_block.getSuccessors()) {
for (Block *succ_block : current_block.getSuccessors()) {
auto succ_live_in_iter = all_live_ins.find(succ_block);
if (succ_live_in_iter == all_live_ins.end()) {
continue;
}

SetVector<Value> &succ_live_ins = succ_live_in_iter->second;
SetVector<Value> &block_live_ins = all_live_ins[&pred_block];
SetVector<Value> &block_live_ins = all_live_ins[&current_block];

unsigned old_block_live_in_size = block_live_ins.size();

// Checks if the live-in value in successor block is defined in the
// predecessor block.
// current block.
for (Value live_in : succ_live_ins) {
// If it is defined in the predecessor block, that means it is not a
// live-in value for the predecessor block. We can skip it.
// If it is defined in the current block, that means it is not a
// live-in value for the current block. We can skip it.
if (Operation *def_op = live_in.getDefiningOp()) {
if (def_op->getBlock() == &pred_block) {
if (def_op->getBlock() == &current_block) {
continue;
}
} else if (auto block_arg = dyn_cast<BlockArgument>(live_in)) {
if (block_arg.getOwner() == &pred_block) {
if (block_arg.getOwner() == &current_block) {
continue;
}
}
Expand Down Expand Up @@ -177,12 +214,12 @@ LogicalResult promoteLiveInValuesToBlockArgs(Region &region) {

// Updates the terminators of predecessor blocks to use the new block
// arguments instead of the live-in values.
for (auto &[succ_block, live_ins] : all_live_ins) {
for (Block *pred_block : succ_block->getPredecessors()) {
for (auto &[current_block, live_ins] : all_live_ins) {
for (Block *pred_block : current_block->getPredecessors()) {
Operation *term_op = pred_block->getTerminator();

if (auto br_op = dyn_cast<neura::Br>(term_op)) {
if (br_op.getDest() == succ_block) {
if (br_op.getDest() == current_block) {
SmallVector<Value> new_operands(br_op.getOperands().begin(),
br_op.getOperands().end());
for (Value live_in : live_ins) {
Expand All @@ -199,7 +236,8 @@ LogicalResult promoteLiveInValuesToBlockArgs(Region &region) {
}
}
OpBuilder builder(br_op);
builder.create<neura::Br>(br_op.getLoc(), new_operands, succ_block);
builder.create<neura::Br>(br_op.getLoc(), new_operands,
current_block);
br_op.erase();
}
} else if (auto cond_br_op = dyn_cast<neura::CondBr>(term_op)) {
Expand All @@ -209,7 +247,7 @@ LogicalResult promoteLiveInValuesToBlockArgs(Region &region) {
SmallVector<Value> false_operands(cond_br_op.getFalseArgs().begin(),
cond_br_op.getFalseArgs().end());
// Handles the true branch.
if (cond_br_op.getTrueDest() == succ_block) {
if (cond_br_op.getTrueDest() == current_block) {
needs_update = true;
for (Value live_in : live_ins) {
Operation *def_op = live_in.getDefiningOp();
Expand All @@ -228,7 +266,7 @@ LogicalResult promoteLiveInValuesToBlockArgs(Region &region) {
}

// Handles the false branch.
if (cond_br_op.getFalseDest() == succ_block) {
if (cond_br_op.getFalseDest() == current_block) {
needs_update = true;
for (Value live_in : live_ins) {
Operation *def_op = live_in.getDefiningOp();
Expand Down