Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 80 additions & 58 deletions lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void getBlocksInPostOrder(Block *startBlock, SmallVectorImpl<Block *> &postOrder

// Creates phi nodes for all live-in values in the given block.
void createPhiNodesForBlock(
Block *block, OpBuilder &builder, DenseMap<Value, Value> &value_map,
Block *block, OpBuilder &builder,
SmallVectorImpl<std::tuple<Value, Value, Value, Block *>> &deferred_ctrl_movs) {
if (block->hasNoPredecessors()) {
// Skips phi insertion for entry block.
Expand All @@ -94,7 +94,6 @@ void createPhiNodesForBlock(
}
}

// Creates a phi node for each live-in value.
builder.setInsertionPointToStart(block);
for (Value live_in : live_ins) {
// Creates predicated type for phi node.
Expand All @@ -107,72 +106,97 @@ void createPhiNodesForBlock(
Location loc = block->empty() ?
block->getParent()->getLoc() :
block->front().getLoc();

SmallVector<Value> phi_operands;
llvm::SmallDenseSet<Operation*, 4> just_created_consumer_ops;
BlockArgument arg = dyn_cast<BlockArgument>(live_in);
// Handles the case where live_in is not a block argument.
if (!arg) {
phi_operands.push_back(live_in);
} else {
// Finds index of live_in in block arguments.
unsigned arg_index = arg.getArgNumber();
for (Block *pred : block->getPredecessors()) {
Value incoming;
Value branch_pred;
Operation *term = pred->getTerminator();

// If it's a branch or cond_br, get the value passed into this block argument
if (auto br = dyn_cast<neura::Br>(term)) {
auto args = br.getArgs();
// TODO: Following logic needs to be refactored.
for (Block *pred : block->getPredecessors()) {
Value incoming;
Value branch_pred;
Operation *term = pred->getTerminator();
// If it's a branch or cond_br, get the value passed into this block argument
if (auto br = dyn_cast<neura::Br>(term)) {
auto args = br.getArgs();
if (arg) {
unsigned arg_index = arg.getArgNumber();
assert(arg_index < args.size());
incoming = args[arg_index];
} else if (auto condBr = dyn_cast<neura::CondBr>(term)) {
Value cond = condBr.getCondition();
branch_pred = cond; // by default

OpBuilder pred_builder(condBr);
Location pred_loc = condBr.getLoc();

if (condBr.getTrueDest() == block) {
} else if (live_in.getDefiningOp()->getBlock() == pred) {
// Handles the case where live_in is not a block argument.
incoming = live_in;
} else {
// If live_in is neither a block argument nor defined in this predecessor block, skips this predecessor.
continue;
}
} else if (auto condBr = dyn_cast<neura::CondBr>(term)) {
Value cond = condBr.getCondition();
branch_pred = cond; // by default
OpBuilder pred_builder(condBr);
Location pred_loc = condBr.getLoc();

if (condBr.getTrueDest() == block) {
if (arg) {
auto trueArgs = condBr.getTrueArgs();
unsigned arg_index = arg.getArgNumber();
assert(arg_index < trueArgs.size());
incoming = trueArgs[arg_index];
// Keep branch_pred = cond
} else if (condBr.getFalseDest() == block) {
} else if (live_in.getDefiningOp()->getBlock() == pred) {
// Handles the case where live_in is not a block argument.
incoming = live_in;
} else {
// If live_in is neither a block argument nor defined in this predecessor block, skips this predecessor.
continue;
}
// Applies grant_predicate.
incoming = pred_builder.create<neura::GrantPredicateOp>(
pred_loc, incoming.getType(), incoming, cond);
just_created_consumer_ops.insert(incoming.getDefiningOp());
// Keep branch_pred = cond
} else if (condBr.getFalseDest() == block) {
if (arg) {
auto falseArgs = condBr.getFalseArgs();
unsigned arg_index = arg.getArgNumber();
assert(arg_index < falseArgs.size());
incoming = falseArgs[arg_index];
// Negates cond for false edge.
branch_pred = pred_builder.create<neura::NotOp>(pred_loc, cond.getType(), cond);

} else if (live_in.getDefiningOp()->getBlock() == pred) {
// Handles the case where live_in is not a block argument.
incoming = live_in;
} else {
llvm::errs() << "cond_br does not target block:\n" << *block << "\n";
assert(false);
// If live_in is neither a block argument nor defined in this predecessor block, skips this predecessor.
continue;
}

// Negates cond for false edge.
branch_pred = pred_builder.create<neura::NotOp>(pred_loc, cond.getType(), cond);
// Applies grant_predicate.
incoming = pred_builder.create<neura::GrantPredicateOp>(
pred_loc, incoming.getType(), incoming, branch_pred);

just_created_consumer_ops.insert(incoming.getDefiningOp());
} else {
llvm::errs() << "Unknown branch terminator in block: " << *pred << "\n";
continue;
llvm::errs() << "cond_br does not target block:\n" << *block << "\n";
assert(false);
}
} else {
llvm::errs() << "Unknown branch terminator in block: " << *pred << "\n";
continue;
}

// If the incoming value is defined in the same block, inserts a `neura.reserve`
// and defers a backward ctrl move.
if (incoming.getDefiningOp() && incoming.getDefiningOp()->getBlock() == block) {
builder.setInsertionPointToStart(block);
auto placeholder = builder.create<neura::ReserveOp>(loc, incoming.getType());
phi_operands.push_back(placeholder.getResult());
// Defers the backward ctrl move operation to be inserted after all phi operands
// are defined. Inserted:
// (real_defined_value, just_created_reserve, branch_pred, current_block).
deferred_ctrl_movs.emplace_back(
incoming, placeholder.getResult(), branch_pred, block);
} else {
phi_operands.push_back(incoming);
}
// If the incoming value is defined in the same block, inserts a `neura.reserve`
// and defers a backward ctrl move.
if (incoming.getDefiningOp() && incoming.getDefiningOp()->getBlock() == block) {
builder.setInsertionPointToStart(block);
auto placeholder = builder.create<neura::ReserveOp>(loc, incoming.getType());
phi_operands.push_back(placeholder.getResult());
// Defers the backward ctrl move operation to be inserted after all phi operands
// are defined. Inserted:
// (real_defined_value, just_created_reserve, branch_pred, current_block).
deferred_ctrl_movs.emplace_back(
incoming, placeholder.getResult(), branch_pred, block);
} else {
phi_operands.push_back(incoming);
}
// If live_in is not a block argument, we don't need to check for uniqueness.
if (!arg) {
continue;
}
}

Expand Down Expand Up @@ -202,12 +226,15 @@ void createPhiNodesForBlock(
Value single = *unique_operands.begin();
SmallVector<OpOperand *, 4> uses;
for (OpOperand &use : live_in.getUses()) {
uses.push_back(&use);
// Skips uses owned by the grant_predicate ops that were just created.
if (!just_created_consumer_ops.contains(use.getOwner())) {
uses.push_back(&use);
}
}
for (OpOperand *use : uses) {
use->set(single);
}
value_map[live_in] = single;
// No need to proceed further to create a phi node, as we have a single unique operand.
continue;
}

Expand All @@ -225,7 +252,6 @@ void createPhiNodesForBlock(
for (OpOperand *use : uses_to_be_replaced) {
use->set(phi_op.getResult());
}
value_map[live_in] = phi_op.getResult();
}
}

Expand Down Expand Up @@ -265,13 +291,10 @@ struct TransformCtrlToDataFlowPass
DenseSet<Block *> visited;
getBlocksInPostOrder(&func.getBody().front(), postOrder, visited);

// Value mapping for phi node creation.
DenseMap<Value, Value> value_map;

// Process blocks bottom-up
for (Block *block : postOrder) {
// Creates phi nodes for live-ins.
createPhiNodesForBlock(block, builder, value_map, deferred_ctrl_movs);
createPhiNodesForBlock(block, builder, deferred_ctrl_movs);
}

// Flattens blocks into the entry block.
Expand Down Expand Up @@ -332,7 +355,6 @@ struct TransformCtrlToDataFlowPass

mov_builder.create<neura::CtrlMovOp>(insert_loc, guarded_val, placeholder);
}

}
};
} // namespace
Expand Down
13 changes: 8 additions & 5 deletions test/neura/ctrl/branch_no_arg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ func.func @test(%in: i64) -> f32 {
// CHECK-NEXT: "neura.return"(%10) : (!neura.data<f32, i1>) -> ()
// CHECK-NEXT: }

// FIXME: Seems the bb1 is not depending on condition's NOT.
// CTRL2DATA: func.func @test(%arg0: i64) -> f32 attributes {accelerator = "neura"} {
// CTRL2DATA-NEXT: %0 = "neura.constant"() <{predicate = true, value = 0 : i64}> : () -> !neura.data<i64, i1>
// CTRL2DATA-NEXT: %1 = "neura.constant"() <{predicate = true, value = 1.000000e+00 : f32}> : () -> !neura.data<f32, i1>
Expand All @@ -65,8 +64,12 @@ func.func @test(%in: i64) -> f32 {
// CTRL2DATA-NEXT: %10 = "neura.grant_once"(%9) : (!neura.data<i1, i1>) -> !neura.data<i1, i1>
// CTRL2DATA-NEXT: %11 = neura.grant_predicate %6, %10 : !neura.data<f32, i1>, !neura.data<i1, i1> -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %12 = neura.grant_predicate %8, %10 : !neura.data<f32, i1>, !neura.data<i1, i1> -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %13 = "neura.fadd"(%2, %4) : (!neura.data<f32, i1>, !neura.data<f32, i1>) -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %14 = "neura.fmul"(%11, %12) : (!neura.data<f32, i1>, !neura.data<f32, i1>) -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %15 = "neura.phi"(%13, %14) : (!neura.data<f32, i1>, !neura.data<f32, i1>) -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: "neura.return"(%15) : (!neura.data<f32, i1>) -> ()
// CTRL2DATA-NEXT: %13 = "neura.not"(%10) : (!neura.data<i1, i1>) -> !neura.data<i1, i1>
// CTRL2DATA-NEXT: %14 = neura.grant_predicate %2, %13 : !neura.data<f32, i1>, !neura.data<i1, i1> -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %15 = "neura.not"(%10) : (!neura.data<i1, i1>) -> !neura.data<i1, i1>
// CTRL2DATA-NEXT: %16 = neura.grant_predicate %4, %15 : !neura.data<f32, i1>, !neura.data<i1, i1> -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %17 = "neura.fadd"(%14, %16) : (!neura.data<f32, i1>, !neura.data<f32, i1>) -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %18 = "neura.fmul"(%11, %12) : (!neura.data<f32, i1>, !neura.data<f32, i1>) -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %19 = "neura.phi"(%17, %18) : (!neura.data<f32, i1>, !neura.data<f32, i1>) -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: "neura.return"(%19) : (!neura.data<f32, i1>) -> ()
// CTRL2DATA-NEXT: }