diff --git a/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp b/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp index 54642b13..367456bc 100644 --- a/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp +++ b/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp @@ -70,7 +70,7 @@ void getBlocksInPostOrder(Block *startBlock, SmallVectorImpl &postOrder // Creates phi nodes for all live-in values in the given block. void createPhiNodesForBlock( - Block *block, OpBuilder &builder, DenseMap &value_map, + Block *block, OpBuilder &builder, SmallVectorImpl> &deferred_ctrl_movs) { if (block->hasNoPredecessors()) { // Skips phi insertion for entry block. @@ -94,7 +94,6 @@ void createPhiNodesForBlock( } } - // Creates a phi node for each live-in value. builder.setInsertionPointToStart(block); for (Value live_in : live_ins) { // Creates predicated type for phi node. @@ -107,72 +106,97 @@ void createPhiNodesForBlock( Location loc = block->empty() ? block->getParent()->getLoc() : block->front().getLoc(); - SmallVector phi_operands; + llvm::SmallDenseSet just_created_consumer_ops; BlockArgument arg = dyn_cast(live_in); - // Handles the case where live_in is not a block argument. - if (!arg) { - phi_operands.push_back(live_in); - } else { - // Finds index of live_in in block arguments. - unsigned arg_index = arg.getArgNumber(); - for (Block *pred : block->getPredecessors()) { - Value incoming; - Value branch_pred; - Operation *term = pred->getTerminator(); - - // If it's a branch or cond_br, get the value passed into this block argument - if (auto br = dyn_cast(term)) { - auto args = br.getArgs(); + // TODO: Following logic needs to be refactored. + for (Block *pred : block->getPredecessors()) { + Value incoming; + Value branch_pred; + Operation *term = pred->getTerminator(); + // If it's a branch or cond_br, get the value passed into this block argument + if (auto br = dyn_cast(term)) { + auto args = br.getArgs(); + if (arg) { + unsigned arg_index = arg.getArgNumber(); assert(arg_index < args.size()); incoming = args[arg_index]; - } else if (auto condBr = dyn_cast(term)) { - Value cond = condBr.getCondition(); - branch_pred = cond; // by default - - OpBuilder pred_builder(condBr); - Location pred_loc = condBr.getLoc(); - - if (condBr.getTrueDest() == block) { + } else if (live_in.getDefiningOp()->getBlock() == pred) { + // Handles the case where live_in is not a block argument. + incoming = live_in; + } else { + // If live_in is not a block argument and not defined in the block, skips. + continue; + } + } else if (auto condBr = dyn_cast(term)) { + Value cond = condBr.getCondition(); + branch_pred = cond; // by default + OpBuilder pred_builder(condBr); + Location pred_loc = condBr.getLoc(); + + if (condBr.getTrueDest() == block) { + if (arg) { auto trueArgs = condBr.getTrueArgs(); + unsigned arg_index = arg.getArgNumber(); assert(arg_index < trueArgs.size()); incoming = trueArgs[arg_index]; - // Keep branch_pred = cond - } else if (condBr.getFalseDest() == block) { + } else if (live_in.getDefiningOp()->getBlock() == pred) { + // Handles the case where live_in is not a block argument. + incoming = live_in; + } else { + // If live_in is not a block argument and not defined in the block, skips. + continue; + } + // Applies grant_predicate. + incoming = pred_builder.create( + pred_loc, incoming.getType(), incoming, cond); + just_created_consumer_ops.insert(incoming.getDefiningOp()); + // Keep branch_pred = cond + } else if (condBr.getFalseDest() == block) { + if (arg) { auto falseArgs = condBr.getFalseArgs(); + unsigned arg_index = arg.getArgNumber(); assert(arg_index < falseArgs.size()); incoming = falseArgs[arg_index]; - // Negates cond for false edge. - branch_pred = pred_builder.create(pred_loc, cond.getType(), cond); - + } else if (live_in.getDefiningOp()->getBlock() == pred) { + // Handles the case where live_in is not a block argument. + incoming = live_in; } else { - llvm::errs() << "cond_br does not target block:\n" << *block << "\n"; - assert(false); + // If live_in is not a block argument and not defined in the block, skips. + continue; } - + // Negates cond for false edge. + branch_pred = pred_builder.create(pred_loc, cond.getType(), cond); // Applies grant_predicate. incoming = pred_builder.create( pred_loc, incoming.getType(), incoming, branch_pred); - + just_created_consumer_ops.insert(incoming.getDefiningOp()); } else { - llvm::errs() << "Unknown branch terminator in block: " << *pred << "\n"; - continue; + llvm::errs() << "cond_br does not target block:\n" << *block << "\n"; + assert(false); } + } else { + llvm::errs() << "Unknown branch terminator in block: " << *pred << "\n"; + continue; + } - // If the incoming value is defined in the same block, inserts a `neura.reserve` - // and defer a backward ctrl move. - if (incoming.getDefiningOp() && incoming.getDefiningOp()->getBlock() == block) { - builder.setInsertionPointToStart(block); - auto placeholder = builder.create(loc, incoming.getType()); - phi_operands.push_back(placeholder.getResult()); - // Defers the backward ctrl move operation to be inserted after all phi operands - // are defined. Inserted: - // (real_defined_value, just_created_reserve, branch_pred, current_block). - deferred_ctrl_movs.emplace_back( - incoming, placeholder.getResult(), branch_pred, block); - } else { - phi_operands.push_back(incoming); - } + // If the incoming value is defined in the same block, inserts a `neura.reserve` + // and defer a backward ctrl move. + if (incoming.getDefiningOp() && incoming.getDefiningOp()->getBlock() == block) { + builder.setInsertionPointToStart(block); + auto placeholder = builder.create(loc, incoming.getType()); + phi_operands.push_back(placeholder.getResult()); + // Defers the backward ctrl move operation to be inserted after all phi operands + // are defined. Inserted: + // (real_defined_value, just_created_reserve, branch_pred, current_block). + deferred_ctrl_movs.emplace_back( + incoming, placeholder.getResult(), branch_pred, block); + } else { + phi_operands.push_back(incoming); + } + // If live_in is not a block argument, we don't need to check for uniqueness. + if (!arg) { + continue; } } @@ -202,12 +226,15 @@ void createPhiNodesForBlock( Value single = *unique_operands.begin(); SmallVector uses; for (OpOperand &use : live_in.getUses()) { - uses.push_back(&use); + // Skip uses that were just created by the grant_predicate. + if (!just_created_consumer_ops.contains(use.getOwner())) { + uses.push_back(&use); + } } for (OpOperand *use : uses) { use->set(single); } - value_map[live_in] = single; + // No need to proceed further to create a phi node, as we have a single unique operand. continue; } @@ -225,7 +252,6 @@ void createPhiNodesForBlock( for (OpOperand *use : uses_to_be_replaced) { use->set(phi_op.getResult()); } - value_map[live_in] = phi_op.getResult(); } } @@ -265,13 +291,10 @@ struct TransformCtrlToDataFlowPass DenseSet visited; getBlocksInPostOrder(&func.getBody().front(), postOrder, visited); - // Value mapping for phi node creation. - DenseMap value_map; - // Process blocks bottom-up for (Block *block : postOrder) { // Creates phi nodes for live-ins. - createPhiNodesForBlock(block, builder, value_map, deferred_ctrl_movs); + createPhiNodesForBlock(block, builder, deferred_ctrl_movs); } // Flattens blocks into the entry block. @@ -332,7 +355,6 @@ struct TransformCtrlToDataFlowPass mov_builder.create(insert_loc, guarded_val, placeholder); } - } }; } // namespace diff --git a/test/neura/ctrl/branch_no_arg.mlir b/test/neura/ctrl/branch_no_arg.mlir index 5d759c6d..5aa9808f 100644 --- a/test/neura/ctrl/branch_no_arg.mlir +++ b/test/neura/ctrl/branch_no_arg.mlir @@ -50,7 +50,6 @@ func.func @test(%in: i64) -> f32 { // CHECK-NEXT: "neura.return"(%10) : (!neura.data) -> () // CHECK-NEXT: } -// FIXME: Seems the bb1 is not depending on condition's NOT. // CTRL2DATA: func.func @test(%arg0: i64) -> f32 attributes {accelerator = "neura"} { // CTRL2DATA-NEXT: %0 = "neura.constant"() <{predicate = true, value = 0 : i64}> : () -> !neura.data // CTRL2DATA-NEXT: %1 = "neura.constant"() <{predicate = true, value = 1.000000e+00 : f32}> : () -> !neura.data @@ -65,8 +64,12 @@ func.func @test(%in: i64) -> f32 { // CTRL2DATA-NEXT: %10 = "neura.grant_once"(%9) : (!neura.data) -> !neura.data // CTRL2DATA-NEXT: %11 = neura.grant_predicate %6, %10 : !neura.data, !neura.data -> !neura.data // CTRL2DATA-NEXT: %12 = neura.grant_predicate %8, %10 : !neura.data, !neura.data -> !neura.data -// CTRL2DATA-NEXT: %13 = "neura.fadd"(%2, %4) : (!neura.data, !neura.data) -> !neura.data -// CTRL2DATA-NEXT: %14 = "neura.fmul"(%11, %12) : (!neura.data, !neura.data) -> !neura.data -// CTRL2DATA-NEXT: %15 = "neura.phi"(%13, %14) : (!neura.data, !neura.data) -> !neura.data -// CTRL2DATA-NEXT: "neura.return"(%15) : (!neura.data) -> () +// CTRL2DATA-NEXT: %13 = "neura.not"(%10) : (!neura.data) -> !neura.data +// CTRL2DATA-NEXT: %14 = neura.grant_predicate %2, %13 : !neura.data, !neura.data -> !neura.data +// CTRL2DATA-NEXT: %15 = "neura.not"(%10) : (!neura.data) -> !neura.data +// CTRL2DATA-NEXT: %16 = neura.grant_predicate %4, %15 : !neura.data, !neura.data -> !neura.data +// CTRL2DATA-NEXT: %17 = "neura.fadd"(%14, %16) : (!neura.data, !neura.data) -> !neura.data +// CTRL2DATA-NEXT: %18 = "neura.fmul"(%11, %12) : (!neura.data, !neura.data) -> !neura.data +// CTRL2DATA-NEXT: %19 = "neura.phi"(%17, %18) : (!neura.data, !neura.data) -> !neura.data +// CTRL2DATA-NEXT: "neura.return"(%19) : (!neura.data) -> () // CTRL2DATA-NEXT: }