diff --git a/include/NeuraDialect/NeuraOps.td b/include/NeuraDialect/NeuraOps.td index 01e54159..910da7c5 100644 --- a/include/NeuraDialect/NeuraOps.td +++ b/include/NeuraDialect/NeuraOps.td @@ -210,11 +210,11 @@ def Neura_PhiOp : Op { neura.ctrl_mov %next to %v // Connect next iteration }]; - let arguments = (ins AnyType:$init_val, AnyType:$loop_val); + let arguments = (ins Variadic:$inputs); let results = (outs AnyType:$result); // Explicitly specify types for operands in the assembly format - let assemblyFormat = "$init_val `:` type($init_val) `,` $loop_val `:` type($loop_val) attr-dict `:` type($result)"; + // let assemblyFormat = "$init_val `:` type($init_val) `,` $loop_val `:` type($loop_val) attr-dict `,` type($result)"; } // Control movement extending base move but with different signature. diff --git a/lib/NeuraDialect/Transforms/LeveragePredicatedValuePass.cpp b/lib/NeuraDialect/Transforms/LeveragePredicatedValuePass.cpp index 74f8b93a..e894c2d0 100644 --- a/lib/NeuraDialect/Transforms/LeveragePredicatedValuePass.cpp +++ b/lib/NeuraDialect/Transforms/LeveragePredicatedValuePass.cpp @@ -13,51 +13,6 @@ using namespace mlir; #include "NeuraDialect/NeuraPasses.h.inc" namespace { -struct applyPredicatedDataType : public RewritePattern { - applyPredicatedDataType(MLIRContext *context) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} - - LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - llvm::errs() << "Processing op: " << *op << "\n"; - - // Skips if not a Neura op or already using predicated values. - if (op->getDialect()->getNamespace() != "neura") { - llvm::errs() << "Skipping non-Neura op\n"; - return failure(); - } - - if (llvm::any_of(op->getResultTypes(), - [](Type t) { return mlir::isa(t); })) { - llvm::errs() << "Skipping already predicated op\n"; - return failure(); - } - - // Converts result types to predicated form. 
- SmallVector newResults; - for (Type t : op->getResultTypes()) { - auto predicatedTy = mlir::neura::PredicatedValue::get( - op->getContext(), - t, - rewriter.getI1Type()); - newResults.push_back(predicatedTy); - } - - // Clones the operation with new result types. - OperationState state(op->getLoc(), op->getName()); - state.addOperands(op->getOperands()); - state.addTypes(newResults); - state.addAttributes(op->getAttrs()); - Operation *newOp = rewriter.create(state); - - // Replaces the old op with the new one. - rewriter.replaceOp(op, newOp->getResults()); - llvm::errs() << "Converted op to predicated form: " << *newOp << "\n"; - if (!newResults.empty()) { - assert(false); - } - return success(); - } -}; struct LeveragePredicatedValuePass : public PassWrapper> { @@ -77,7 +32,26 @@ struct LeveragePredicatedValuePass // Processes each function. module.walk([&](func::FuncOp func) { - // Get operations in topological order (operands before users) + // Converts block argument types to predicated values. + func.walk([&](Block *block) { + // skips the entry (first) block of the function. + if (block == &block->getParent()->front()) { + return; + } + for (BlockArgument arg : block->getArguments()) { + Type origType = arg.getType(); + + // Avoid double-wrapping if already predicated + if (llvm::isa(origType)) + continue; + + auto predicated_type = neura::PredicatedValue::get( + func.getContext(), origType, IntegerType::get(func.getContext(), 1)); + arg.setType(predicated_type); + } + }); + + // Gets operations in topological order (operands before users). SmallVector orderedOps; getOperationsInTopologicalOrder(func, orderedOps); @@ -122,11 +96,8 @@ struct LeveragePredicatedValuePass // Converts a single operation to use predicated values. LogicalResult applyPredicatedDataType(Operation *op) { - llvm::errs() << "Processing op: " << *op << "\n"; - // Skips if not a Neura op. 
if (op->getDialect()->getNamespace() != "neura") { - llvm::errs() << "Skipping non-Neura op\n"; return success(); } @@ -141,11 +112,11 @@ struct LeveragePredicatedValuePass OpBuilder builder(op); SmallVector newResults; for (Type t : op->getResultTypes()) { - auto predicatedTy = mlir::neura::PredicatedValue::get( + auto predicated_type = mlir::neura::PredicatedValue::get( op->getContext(), t, builder.getI1Type()); - newResults.push_back(predicatedTy); + newResults.push_back(predicated_type); } // Clones with new result types. diff --git a/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp b/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp index 25461697..f66ddaa1 100644 --- a/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp +++ b/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp @@ -1,6 +1,7 @@ #include "Common/AcceleratorAttrs.h" #include "NeuraDialect/NeuraDialect.h" #include "NeuraDialect/NeuraOps.h" +#include "NeuraDialect/NeuraTypes.h" #include "NeuraDialect/NeuraPasses.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -11,136 +12,114 @@ using namespace mlir; #define GEN_PASS_DEF_TransformCtrlToDataFlow #include "NeuraDialect/NeuraPasses.h.inc" -// Processes a block recursively, cloning its operations into the entry block. -void processBlockRecursively(Block *block, Block &entry_block, Value predicate, OpBuilder &builder, - SmallVector &results, DenseSet &visited_blocks, - DenseMap &arg_mapping, - DenseMap &value_mapping) { - // Checks if the block has already been visited. - if (visited_blocks.contains(block)) { - llvm::errs() << "Skipping already visited block:\n"; - block->dump(); +// Returns blocks in post-order traversal order. +void getBlocksInPostOrder(Block *startBlock, SmallVectorImpl &postOrder, + DenseSet &visited) { + if (!visited.insert(startBlock).second) return; - } - - // Marks the block as visited. 
- visited_blocks.insert(block); - llvm::errs() << "Processing block:\n"; - block->dump(); + // Visits successors first. + for (Block *succ : startBlock->getSuccessors()) + getBlocksInPostOrder(succ, postOrder, visited); - // Handle. block arguments first. - for (BlockArgument arg : block->getArguments()) { - llvm::errs() << "Processing block argument: " << arg << "\n"; - - // Checks if we already have a mapping for this argument. - if (auto mapped = arg_mapping.lookup(arg)) { - llvm::errs() << "Found existing mapping for argument\n"; - continue; - } + // Adds current block to post-order sequence. + postOrder.push_back(startBlock); +} - builder.setInsertionPointToEnd(&entry_block); - // Creates a new constant operation with zero value and true predicate. - OperationState state(arg.getLoc(), neura::ConstantOp::getOperationName()); - state.addAttribute("value", builder.getZeroAttr(arg.getType())); - state.addAttribute("predicate", builder.getBoolAttr(true)); - state.addTypes(arg.getType()); - Value false_val = builder.create(state)->getResult(0); - - llvm::errs() << "Creating false_val: \n"; - false_val.dump(); - auto sel = builder.create( - arg.getLoc(), arg.getType(), arg, false_val, predicate); - - llvm::errs() << "Created sel operation for argument:\n"; - sel->dump(); - - // Stores mapping. - arg_mapping.try_emplace(arg, sel.getResult()); - value_mapping[arg] = sel.getResult(); - results.push_back(sel.getResult()); +// Creates phi nodes for all live-in values in the given block. +void createPhiNodesForBlock(Block *block, OpBuilder &builder, + DenseMap &value_map) { + if (block->hasNoPredecessors()) { + // Skips phi insertion for entry block. + return; } - // Processes operations. - SmallVector ops_to_process; + // Collects all live-in values. 
+ std::vector live_ins; for (Operation &op : *block) { - ops_to_process.push_back(&op); - } - - for (Operation *op : ops_to_process) { - llvm::errs() << "Processing operation:\n"; - op->dump(); - - if (op->hasTrait()) { - if (auto br = dyn_cast(op)) { - llvm::errs() << "Found unconditional branch\n"; - for (Value operand : br.getOperands()) { - if (auto mapped = value_mapping.lookup(operand)) { - results.push_back(mapped); - } else { - results.push_back(operand); + for (Value operand : op.getOperands()) { + // Identifies operands defined in other blocks. + if (operand.getDefiningOp() && + operand.getDefiningOp()->getBlock() != block) { + // Checks if the live-in is a block argument. SSA form forces this rule. + bool found_in_block_argument = false; + for (BlockArgument arg : block->getArguments()) { + if (arg == operand) { + found_in_block_argument = true; + break; } } - } else if (auto cond_br = dyn_cast(op)) { - llvm::errs() << "Found conditional branch\n"; - Value cond = cond_br.getCondition(); - auto not_cond = builder.create(cond_br.getLoc(), cond.getType(), cond); - - SmallVector true_results, false_results; - processBlockRecursively(cond_br.getTrueDest(), entry_block, cond, - builder, true_results, visited_blocks, arg_mapping, value_mapping); - processBlockRecursively(cond_br.getFalseDest(), entry_block, not_cond.getResult(), - builder, false_results, visited_blocks, arg_mapping, value_mapping); - - builder.setInsertionPointToEnd(&entry_block); - for (auto [true_result, false_result] : llvm::zip(true_results, false_results)) { - auto sel = builder.create( - op->getLoc(), true_result.getType(), true_result, false_result, cond); - value_mapping[sel.getResult()] = sel.getResult(); - results.push_back(sel.getResult()); - } - } else if (auto ret = dyn_cast(op)) { - llvm::errs() << "Found Return\n"; - for (Value operand : ret.getOperands()) { - if (auto mapped = value_mapping.lookup(operand)) { - results.push_back(mapped); - } else { - results.push_back(operand); 
- } - } - } else { - // Handle other terminators if needed - llvm::errs() << "Found unexpected terminator operation:\n"; - op->dump(); - assert(false && "Unexpected terminator operation in block"); + assert(found_in_block_argument && "Live-in value defined outside the block must be passed as a block argument"); + live_ins.push_back(operand); } - } - builder.setInsertionPointToEnd(&entry_block); - Operation *cloned_op = builder.clone(*op); - - // Replaces operands with mapped values. - for (unsigned i = 0; i < cloned_op->getNumOperands(); ++i) { - Value operand = cloned_op->getOperand(i); - if (auto mapped = value_mapping.lookup(operand)) { - cloned_op->setOperand(i, mapped); + // Collects all block arguments. + if (auto blockArg = llvm::dyn_cast(operand)) { + live_ins.push_back(operand); } } + } - if (!cloned_op->hasTrait()) { - cloned_op->insertOperands(cloned_op->getNumOperands(), predicate); + // Creates a phi node for each live-in value. + builder.setInsertionPointToStart(block); + for (Value live_in : live_ins) { + // Creates predicated type for phi node. + Type live_in_type = live_in.getType(); + Type predicated_type = isa(live_in_type) + ? live_in_type + : neura::PredicatedValue::get(builder.getContext(), live_in_type, builder.getI1Type()); + + // Uses the location from the first operation in the block or block's parent operation. + Location loc = block->empty() ? + block->getParent()->getLoc() : + block->front().getLoc(); + + SmallVector phi_operands; + + // Finds index of live_in in block arguments. 
+ auto arg = dyn_cast(live_in); + assert(arg && "Expected live_in to be a block argument"); + unsigned arg_index = arg.getArgNumber(); + + for (Block *pred : block->getPredecessors()) { + Value incoming; + Operation *term = pred->getTerminator(); + + // If it's a branch or cond_br, get the value passed into this block argument + if (auto br = dyn_cast(term)) { + auto args = br.getArgs(); + assert(arg_index < args.size()); + incoming = args[arg_index]; + } else if (auto condBr = dyn_cast(term)) { + if (condBr.getTrueDest() == block) { + auto trueArgs = condBr.getTrueArgs(); + assert(arg_index < trueArgs.size()); + incoming = trueArgs[arg_index]; + } else if (condBr.getFalseDest() == block) { + auto falseArgs = condBr.getFalseArgs(); + assert(arg_index < falseArgs.size()); + incoming = falseArgs[arg_index]; + } else { + llvm::errs() << "cond_br does not target block:\n" << *block << "\n"; + continue; + } + } else { + llvm::errs() << "Unknown branch terminator in block: " << *pred << "\n"; + continue; + } + phi_operands.push_back(incoming); } - // Stores mappings and results. 
- for (unsigned i = 0; i < op->getNumResults(); ++i) { - Value orig_result = op->getResult(i); - Value new_result = cloned_op->getResult(i); - value_mapping[orig_result] = new_result; - results.push_back(new_result); - } + // A phi node requires at least one incoming value; SSA guarantees each predecessor passes one. + assert(!phi_operands.empty()); + + // Creates the phi node with a dynamic number of operands. + auto phi_op = builder.create(loc, predicated_type, phi_operands); + + // Replaces block argument uses with the phi result. + arg.replaceAllUsesWith(phi_op.getResult()); + value_map[live_in] = phi_op.getResult(); } - llvm::errs() << "[cheng] after processing entry_block:\n"; - entry_block.dump(); } namespace { @@ -150,7 +129,7 @@ struct TransformCtrlToDataFlowPass StringRef getArgument() const override { return "transform-ctrl-to-data-flow"; } StringRef getDescription() const override { - return "Flattens control flow into predicated linear SSA for Neura dialect."; + return "Transforms control flow into data flow using predicated execution"; } void getDependentDialects(DialectRegistry &registry) const override { @@ -161,85 +140,57 @@ struct TransformCtrlToDataFlowPass ModuleOp module = getOperation(); module.walk([&](func::FuncOp func) { - llvm::errs() << "Processing function: "; - func.dump(); - - if (!func->hasAttr(mlir::accel::kAcceleratorAttr)) - return; - - auto target = func->getAttrOfType(mlir::accel::kAcceleratorAttr); - if (!target || target.getValue() != mlir::accel::kNeuraTarget) - return; - - Block &entry_block = func.getBody().front(); - llvm::errs() << "Entry block before processing:\n"; - entry_block.dump(); - - OpBuilder builder(&entry_block, entry_block.begin()); - - // Check for terminator - Operation *terminator = nullptr; - if (!entry_block.empty()) { - terminator = &entry_block.back(); + // Gets blocks in post-order. + SmallVector postOrder; + DenseSet visited; + getBlocksInPostOrder(&func.getBody().front(), postOrder, visited); + + // Value mapping for phi node creation. 
+ DenseMap value_map; + OpBuilder builder(func.getContext()); + + // Process blocks bottom-up + for (Block *block : postOrder) { + // Creates phi nodes for live-ins. + createPhiNodesForBlock(block, builder, value_map); } - auto cond_br = dyn_cast_or_null(terminator); - if (!cond_br) { - llvm::errs() << "No conditional branch found in entry block\n"; - return; + // Flattens blocks into the entry block. + Block *entryBlock = &func.getBody().front(); + SmallVector blocks_to_flatten; + for (Block &block : func.getBody()) { + if (&block != entryBlock) + blocks_to_flatten.push_back(&block); } - // Get condition and create not condition - Location loc = cond_br.getLoc(); - Value cond = cond_br.getCondition(); - builder.setInsertionPoint(cond_br); - auto not_cond = builder.create(loc, cond.getType(), cond); - - // Processes branches. - DenseMap arg_mapping; - DenseMap value_mapping; - DenseSet visited_blocks; - SmallVector true_results, false_results; - - processBlockRecursively(cond_br.getTrueDest(), entry_block, cond, - builder, true_results, visited_blocks, arg_mapping, value_mapping); - processBlockRecursively(cond_br.getFalseDest(), entry_block, not_cond.getResult(), - builder, false_results, visited_blocks, arg_mapping, value_mapping); - - llvm::errs() << "Entry block after processing:\n"; - entry_block.dump(); - - // Creates final return operation. - if (!true_results.empty() && !false_results.empty()) { - builder.setInsertionPoint(cond_br); - auto sel = builder.create( - loc, true_results[0].getType(), true_results[0], false_results[0], cond); - builder.create(loc, sel.getResult()); + // Erases terminators before moving ops into entry block. + for (Block *block : blocks_to_flatten) { + for (Operation &op : llvm::make_early_inc_range(*block)) { + if (isa(op) || isa(op)) { + op.erase(); + } + } } - // Replaces all uses with mapped values. 
- for (auto &[orig, mapped] : value_mapping) { - orig.replaceAllUsesWith(mapped); + // Moves all operations from blocks to the entry block before the terminator. + for (Block *block : blocks_to_flatten) { + auto &ops = block->getOperations(); + while (!ops.empty()) { + Operation &op = ops.front(); + op.moveBefore(&entryBlock->back()); + } } - // Erases the conditional branch. - cond_br->erase(); - - // Finally erases all other blocks. - SmallVector blocks_to_erase; - for (Block &block : llvm::make_early_inc_range(func.getBody())) { - if (&block != &entry_block) { - blocks_to_erase.push_back(&block); + // Erases any remaining br/cond_br that were moved into the entry block. + for (Operation &op : llvm::make_early_inc_range(*entryBlock)) { + if (isa(op) || isa(op)) { + op.erase(); } } - for (Block *block : blocks_to_erase) { - block->dropAllReferences(); + for (Block *block : blocks_to_flatten) { block->erase(); } - - llvm::errs() << "Function after transformation:\n"; - func.dump(); }); } };