diff --git a/include/TaskflowDialect/TaskflowPasses.h b/include/TaskflowDialect/TaskflowPasses.h index c50544c9..71a3b510 100644 --- a/include/TaskflowDialect/TaskflowPasses.h +++ b/include/TaskflowDialect/TaskflowPasses.h @@ -23,6 +23,7 @@ std::unique_ptr createMapTaskOnCgraPass(); // Optimization Passes //=========================================================// std::unique_ptr createAffineLoopTreeSerializationPass(); +std::unique_ptr createAffineLoopPerfectionPass(); #define GEN_PASS_REGISTRATION #include "TaskflowDialect/TaskflowPasses.h.inc" diff --git a/include/TaskflowDialect/TaskflowPasses.td b/include/TaskflowDialect/TaskflowPasses.td index d41ae666..4bb69caf 100644 --- a/include/TaskflowDialect/TaskflowPasses.td +++ b/include/TaskflowDialect/TaskflowPasses.td @@ -21,6 +21,19 @@ def AffineLoopTreeSerialization : Pass<"affine-loop-tree-serialization", "Module "mlir::func::FuncDialect"]; } +def AffineLoopPerfection : Pass<"affine-loop-perfection", "func::FuncOp">{ + let summary = "Perfectionizes affine.for loops into perfect nested loop bands"; + let description = [{ + This pass transforms affine.for loops into perfect nested loop bands by + applying loop transformations such as loop fusion, loop interchange, and + loop tiling. 
+ }]; + let constructor = "taskflow::createAffineLoopPerfectionPass()"; + let dependentDialects = [ + "mlir::affine::AffineDialect", + "mlir::func::FuncDialect"]; +} + //=========================================================// // Passes for the Taskflow dialect //=========================================================// diff --git a/lib/TaskflowDialect/Transforms/Optimizations/AffineLoopPerfectionPass.cpp b/lib/TaskflowDialect/Transforms/Optimizations/AffineLoopPerfectionPass.cpp new file mode 100644 index 00000000..f22c1a66 --- /dev/null +++ b/lib/TaskflowDialect/Transforms/Optimizations/AffineLoopPerfectionPass.cpp @@ -0,0 +1,412 @@ +#include "TaskflowDialect/TaskflowPasses.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/TypeID.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::taskflow; + +namespace { +//================================================================= +// Affine Loop Band Structure. +//================================================================= + +// A loop band can be classified into two types: + +// 1) Perfect Loop Band: A sequence of perfectly nested loops where each loop +// (except the innermost) has exactly one child loop and no other operations (no +// prologue/epilogue); +// 2) Imperfect Loop Band: A sequence of nested loops that +// do not satisfy the perfect nesting condition (e.g., loops with exactly one +// child loop, but with other operations in the body). +using AffineLoopBand = SmallVector; + +// Checks whether an operation has side effects (returns false for pure computation). 
+static bool hasSideEffect(Operation *op) { + // Yield operations are terminators, not computations. + if (isa(op)) { + return true; + } + + // Arithmetic and pure operations. + if (isa(op->getDialect())) { + return false; + } + + // affine.load or memref.load is considered side-effect-free (read-only). + if (isa(op) || isa(op)) { + return false; + } + + // affine.store and memref.store are side-effecting (write operations). + if (isa(op)) { + return true; + } + + // For other operations, conservatively assumes they have side effects. + return true; +} + +// Collects loop bands from a function. +static void collectLoopBands(func::FuncOp func_op, + SmallVector &loop_bands) { + func_op.walk([&](affine::AffineForOp for_op) { + // Only processes outermost loops (skips nested loops). + if (for_op->getParentOfType()) { + return; + } + + AffineLoopBand current_band; + affine::AffineForOp current_loop = for_op; + + // Follows the nesting chain to build the (possibly imperfect) loop band. + while (current_loop) { + current_band.push_back(current_loop); + + // Counts the nested loops directly inside the body. + Block &body = current_loop.getRegion().front(); + affine::AffineForOp nested_loop = nullptr; + size_t num_loops = 0; + + for (Operation &body_op : body) { + if (auto nested_for = dyn_cast(&body_op)) { + nested_loop = nested_for; + num_loops++; + } + } + + // Loop bands condition: exactly 1 nested loop, any number of other ops + // (other ops will be perfectized). + if (num_loops == 1) { + current_loop = nested_loop; + } else { + // Has zero or multiple nested loops, so the band ends here. + break; + } + } + + if (!current_band.empty()) { + loop_bands.push_back(current_band); + } + }); +} + +//================================================================= +// Loop Perfection Logic. +//================================================================= + +// Creates a condition checking if all inner loop indices are at their lower +// bounds. Used for prologue condition. 
+static Value +createPrologueCondition(OpBuilder &builder, Location loc, + ArrayRef inner_loops) { + // Builds condition for prologue code: (i1 == lb1) && (i2 == lb2) && ... + Value condition = nullptr; + + for (affine::AffineForOp loop : inner_loops) { + Value idx = loop.getInductionVar(); + Value lb; + + if (loop.hasConstantLowerBound()) { + lb = builder.create(loc, + loop.getConstantLowerBound()); + } else { + llvm::errs() + << "[LoopPerfection] Non-constant lower bound not supported.\n"; + return nullptr; + } + + Value eq = + builder.create(loc, arith::CmpIPredicate::eq, idx, lb); + + if (condition) { + condition = builder.create(loc, condition, eq); + } else { + condition = eq; + } + } + + return condition; +} + +// Creates a condition checking if all inner loops are in their last +// iteration. Used for epilogue condition. +static Value +createEpilogueCondition(OpBuilder &builder, Location loc, + ArrayRef inner_loops) { + // Builds condition for epilogue code: (i1 + step1 >= ub1) && + // (i2 + step2 >= ub2) && ... + Value condition = nullptr; + + for (affine::AffineForOp loop : inner_loops) { + Value idx = loop.getInductionVar(); + Value next_idx; // idx + step + Value ub; + + // Gets step. + int32_t step_val = 1; + if (loop.getStepAsInt()) { + step_val = loop.getStepAsInt(); + } else { + llvm::errs() << "[LoopPerfection] Non-constant step not supported.\n"; + return nullptr; + } + + // Computes next_idx = idx + step. 
+ Value step = builder.create(loc, step_val); + next_idx = builder.create(loc, idx, step); + + if (loop.hasConstantUpperBound()) { + ub = builder.create(loc, + loop.getConstantUpperBound()); + } else { + llvm::errs() + << "[LoopPerfection] Non-constant upper bound not supported.\n"; + return nullptr; + } + + Value is_last = builder.create( + loc, arith::CmpIPredicate::sge, next_idx, ub); + + if (condition) { + condition = builder.create(loc, condition, is_last); + } else { + condition = is_last; + } + } + + return condition; +} + +// Applies loop perfection to a single loop band. +// Sinks all operations into the innermost loop with conditional execution. +static LogicalResult applyLoopPerfection(AffineLoopBand &loop_band) { + if (loop_band.empty()) { + return failure(); + } + + llvm::errs() << "[LoopPerfection] Processing loop band with " + << loop_band.size() << " loops.\n"; + + affine::AffineForOp innermost_loop = loop_band.back(); + OpBuilder builder(innermost_loop); + + // Processes each parent loop in the band from innermost to outermost. + for (size_t i = loop_band.size() - 1; i > 0; i--) { + affine::AffineForOp loop = loop_band[i - 1]; + affine::AffineForOp child_loop = loop_band[i]; + + // Collects prologue and epilogue operations in the current loop + // (excluding the child loop). + SmallVector prologue_ops; // Before child loop. + SmallVector epilogue_ops; // After child loop. + + bool is_prologue = true; + for (Operation &op : loop.getRegion().front()) { + if (&op == child_loop) { + is_prologue = false; + continue; + } + + if (isa(&op)) { + // Skips yield operations. + continue; + } + + // Rejects operations that cannot be perfectized. 
+ if (llvm::any_of(op.getResultTypes(), + [](Type type) { return isa(type); })) { + llvm::errs() << "[LoopPerfection] Memref-producing op cannot be " + "perfectized.\n"; + op.dump(); + return failure(); + } + + if (isa(&op)) { + llvm::errs() + << "[LoopPerfection] Function call op cannot be perfectized.\n"; + op.dump(); + return failure(); + } + + if (is_prologue) { + prologue_ops.push_back(&op); + } else { + epilogue_ops.push_back(&op); + } + } + + if (prologue_ops.empty() && epilogue_ops.empty()) { + // No operations to perfect, continues to next loop. + continue; + } + + Location loc = loop.getLoc(); + Block &innermost_body = innermost_loop.getRegion().front(); + + // Gets all inner loops (from current child to innermost loop). + ArrayRef inner_loops = + ArrayRef(loop_band).drop_front(i); + + // Handles prologue operations. + if (!prologue_ops.empty()) { + llvm::errs() << " Moving " << prologue_ops.size() + << " prologue operations\n"; + + Operation *insert_point = &innermost_body.front(); + + // Separates pure and side-effecting operations in the prologue. + SmallVector pure_ops; + SmallVector side_effect_ops; + + for (Operation *op : prologue_ops) { + if (hasSideEffect(op)) { + side_effect_ops.push_back(op); + } else { + pure_ops.push_back(op); + } + } + + // Moves pure operations directly into the innermost loop (will be CSE'd + // if redundant). + for (Operation *op : pure_ops) { + op->moveBefore(insert_point); + } + + // Moves side-effecting operations into the innermost loop with + // conditional execution. 
+ if (!side_effect_ops.empty()) { + builder.setInsertionPoint(insert_point); + Value condition = createPrologueCondition(builder, loc, inner_loops); + + if (condition) { + scf::IfOp if_op = builder.create(loc, condition, + /*withElseRegion*/ false); + + Block *then_block = if_op.thenBlock(); + + for (Operation *op : side_effect_ops) { + op->moveBefore(then_block->getTerminator()); + } + } else { + // If condition creation fails, returns failure to avoid + // incorrect transformation. + llvm::errs() + << "[LoopPerfection] Failed to create prologue condition.\n"; + return failure(); + } + } + } + + // Handles epilogue operations. + if (!epilogue_ops.empty()) { + llvm::errs() << " Moving " << epilogue_ops.size() + << " epilogue operations\n"; + + Operation *insert_point = innermost_body.getTerminator(); + + // Separates pure and side-effecting operations in the epilogue. + SmallVector pure_ops; + SmallVector side_effect_ops; + + for (Operation *op : epilogue_ops) { + if (hasSideEffect(op)) { + side_effect_ops.push_back(op); + } else { + pure_ops.push_back(op); + } + } + + // Moves pure operations directly into the innermost loop (will be CSE'd + // if redundant). + for (Operation *op : pure_ops) { + op->moveBefore(insert_point); + } + + // Moves side-effecting operations into the innermost loop with + // condition execution. + if (!side_effect_ops.empty()) { + builder.setInsertionPoint(insert_point); + Value condition = createEpilogueCondition(builder, loc, inner_loops); + + if (condition) { + scf::IfOp if_op = builder.create(loc, condition, + /*withElseRegion*/ false); + + Block *then_block = if_op.thenBlock(); + + for (Operation *op : side_effect_ops) { + op->moveBefore(then_block->getTerminator()); + } + } else { + // If condition creation fails, returns failure to avoid + // incorrect transformation. 
+ llvm::errs() + << "[LoopPerfection] Failed to create epilogue condition.\n"; + return failure(); + } + } + } + } + + return success(); +} + +//================================================================= +// Pass Implementation. +//================================================================= +struct AffineLoopPerfectionPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineLoopPerfectionPass) + + StringRef getArgument() const final { return "affine-loop-perfection"; } + StringRef getDescription() const final { + return "Apply loop perfection for affine loops."; + } + + void getDependentDialects(DialectRegistry &registry) const override { + registry + .insert(); + } + + void runOnOperation() override { + func::FuncOp func_op = getOperation(); + // Collects all loop bands in the function. + SmallVector loop_bands; + collectLoopBands(func_op, loop_bands); + + if (loop_bands.empty()) { + llvm::errs() << "[LoopPerfection] No loop bands found in function: " + << func_op.getName() << "\n"; + return; + } + + llvm::errs() << "[LoopPerfection] Found " << loop_bands.size() + << " loop bands in function: " << func_op.getName() << "\n"; + + // Applies loop perfection to each loop band. 
+ for (AffineLoopBand &band : loop_bands) { + if (failed(applyLoopPerfection(band))) { + signalPassFailure(); + return; + } + } + } +}; +} // namespace + +std::unique_ptr mlir::taskflow::createAffineLoopPerfectionPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/lib/TaskflowDialect/Transforms/Optimizations/CMakeLists.txt b/lib/TaskflowDialect/Transforms/Optimizations/CMakeLists.txt index 3e1ce5cd..2200f5b1 100644 --- a/lib/TaskflowDialect/Transforms/Optimizations/CMakeLists.txt +++ b/lib/TaskflowDialect/Transforms/Optimizations/CMakeLists.txt @@ -2,6 +2,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}) add_mlir_conversion_library(MLIRTaskflowOptimization AffineLoopTreeSerializationPass.cpp + AffineLoopPerfectionPass.cpp DEPENDS MLIRTaskflowTransformsIncGen diff --git a/test/benchmark/Zeonica_Testbench b/test/benchmark/Zeonica_Testbench index 45e85e44..c7590d83 160000 --- a/test/benchmark/Zeonica_Testbench +++ b/test/benchmark/Zeonica_Testbench @@ -1 +1 @@ -Subproject commit 45e85e44d58670e08a88dfcebfd471909699ae2c +Subproject commit c7590d836df404dca078c4c5104c39673100a4af diff --git a/test/multi-cgra/taskflow/irregular-loop/irregular-loop.mlir b/test/multi-cgra/taskflow/irregular-loop/irregular-loop.mlir index 19fa277b..2a4eb496 100644 --- a/test/multi-cgra/taskflow/irregular-loop/irregular-loop.mlir +++ b/test/multi-cgra/taskflow/irregular-loop/irregular-loop.mlir @@ -2,6 +2,11 @@ // RUN: -o %t.serialized.mlir // RUN: FileCheck %s --input-file=%t.serialized.mlir --check-prefixes=SERIALIZED +// RUN: mlir-neura-opt %s --affine-loop-tree-serialization \ +// RUN: --affine-loop-perfection \ +// RUN: -o %t.perfect.mlir +// RUN: FileCheck %s --input-file=%t.perfect.mlir --check-prefixes=PERFECT + // RUN: mlir-neura-opt %s --affine-loop-tree-serialization \ // RUN: --convert-affine-to-taskflow \ // RUN: -o %t.taskflow.mlir @@ -103,6 +108,45 @@ module attributes {} { // SERIALIZED-NEXT: } // SERIALIZED-NEXT: } +// PERFECT: module { +// 
PERFECT-NEXT: func.func @_Z21irregularLoopExample1v() -> i32 attributes {llvm.linkage = #llvm.linkage} { +// PERFECT-NEXT: %c2_i32 = arith.constant 2 : i32 +// PERFECT-NEXT: %c8_i32 = arith.constant 8 : i32 +// PERFECT-NEXT: %c0_i32 = arith.constant 0 : i32 +// PERFECT-NEXT: %alloca = memref.alloca() : memref +// PERFECT-NEXT: %alloca_0 = memref.alloca() : memref<4x8xi32> +// PERFECT-NEXT: %0 = affine.for %arg0 = 0 to 5 iter_args(%arg1 = %c0_i32) -> (i32) { +// PERFECT-NEXT: %2 = arith.index_cast %arg0 : index to i32 +// PERFECT-NEXT: %3 = arith.addi %arg1, %2 : i32 +// PERFECT-NEXT: affine.yield %3 : i32 +// PERFECT-NEXT: } +// PERFECT-NEXT: affine.for %arg0 = 0 to 4 { +// PERFECT-NEXT: affine.for %arg1 = 0 to 8 { +// PERFECT-NEXT: %2 = arith.index_cast %arg0 : index to i32 +// PERFECT-NEXT: %3 = arith.muli %2, %c8_i32 : i32 +// PERFECT-NEXT: %4 = arith.index_cast %arg1 : index to i32 +// PERFECT-NEXT: %5 = arith.addi %3, %4 : i32 +// PERFECT-NEXT: affine.store %5, %alloca_0[%arg0, %arg1] : memref<4x8xi32> +// PERFECT-NEXT: } +// PERFECT-NEXT: } +// PERFECT-NEXT: affine.for %arg0 = 0 to 4 { +// PERFECT-NEXT: affine.for %arg1 = 0 to 8 { +// PERFECT-NEXT: %2 = arith.index_cast %arg0 : index to i32 +// PERFECT-NEXT: %3 = arith.muli %2, %c8_i32 : i32 +// PERFECT-NEXT: %4 = affine.load %alloca_0[%arg0, %arg1] : memref<4x8xi32> +// PERFECT-NEXT: %5 = arith.addi %4, %0 : i32 +// PERFECT-NEXT: affine.if #set(%arg0, %arg1) { +// PERFECT-NEXT: affine.store %5, %alloca[] : memref +// PERFECT-NEXT: %6 = arith.muli %5, %c2_i32 : i32 +// PERFECT-NEXT: affine.store %6, %alloca[] : memref +// PERFECT-NEXT: } +// PERFECT-NEXT: } +// PERFECT-NEXT: } +// PERFECT-NEXT: %1 = affine.load %alloca[] : memref +// PERFECT-NEXT: return %1 : i32 +// PERFECT-NEXT: } +// PERFECT-NEXT: } + // TASKFLOW: #set = affine_set<(d0, d1) : (d0 - 3 == 0, d1 - 7 == 0)> // TASKFLOW-NEXT: module { // TASKFLOW-NEXT: func.func @_Z21irregularLoopExample1v() -> i32 attributes {llvm.linkage = #llvm.linkage} { 
diff --git a/test/multi-cgra/taskflow/multi-nested/multi-nested.mlir b/test/multi-cgra/taskflow/multi-nested/multi-nested.mlir index ebdbe079..e6376f44 100644 --- a/test/multi-cgra/taskflow/multi-nested/multi-nested.mlir +++ b/test/multi-cgra/taskflow/multi-nested/multi-nested.mlir @@ -2,6 +2,11 @@ // RUN: -o %t.serialized.mlir // RUN: FileCheck %s --input-file=%t.serialized.mlir --check-prefixes=SERIALIZED +// RUN: mlir-neura-opt %s --affine-loop-tree-serialization \ +// RUN: --affine-loop-perfection \ +// RUN: -o %t.perfect.mlir +// RUN: FileCheck %s --input-file=%t.perfect.mlir --check-prefixes=PERFECT + // RUN: mlir-neura-opt %s --affine-loop-tree-serialization \ // RUN: --convert-affine-to-taskflow \ // RUN: -o %t.taskflow.mlir @@ -118,6 +123,57 @@ module attributes {} { // SERIALIZED-NEXT: } // SERIALIZED-NEXT: } +// PERFECT: module { +// PERFECT-NEXT: func.func @_Z21pureNestedLoopExamplePA8_A6_iPA8_A5_iS4_PA7_iPA9_iPiS9_S9_S9_S9_(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: memref, %arg6: memref, %arg7: memref, %arg8: memref, %arg9: memref) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// PERFECT-NEXT: affine.for %arg10 = 0 to 4 { +// PERFECT-NEXT: affine.for %arg11 = 0 to 8 { +// PERFECT-NEXT: affine.for %arg12 = 0 to 6 { +// PERFECT-NEXT: %1 = affine.load %arg0[%arg10, %arg11, %arg12] : memref +// PERFECT-NEXT: affine.store %1, %arg5[%arg12] : memref +// PERFECT-NEXT: } +// PERFECT-NEXT: } +// PERFECT-NEXT: } +// PERFECT-NEXT: affine.for %arg10 = 0 to 4 { +// PERFECT-NEXT: affine.for %arg11 = 0 to 8 { +// PERFECT-NEXT: affine.for %arg12 = 0 to 5 { +// PERFECT-NEXT: %1 = affine.load %arg1[%arg10, %arg11, %arg12] : memref +// PERFECT-NEXT: %2 = affine.load %arg2[%arg10, %arg11, %arg12] : memref +// PERFECT-NEXT: %3 = arith.addi %1, %2 : i32 +// PERFECT-NEXT: affine.store %3, %arg6[%arg12] : memref +// PERFECT-NEXT: } +// PERFECT-NEXT: } +// PERFECT-NEXT: } +// PERFECT-NEXT: affine.for %arg10 = 0 to 4 { +// 
PERFECT-NEXT: affine.for %arg11 = 0 to 8 { +// PERFECT-NEXT: affine.for %arg12 = 0 to 6 { +// PERFECT-NEXT: %1 = affine.load %arg5[%arg12] : memref +// PERFECT-NEXT: %2 = affine.load %arg6[%arg12] : memref +// PERFECT-NEXT: %3 = arith.addi %1, %2 : i32 +// PERFECT-NEXT: %4 = affine.load %arg9[0] : memref +// PERFECT-NEXT: %5 = arith.addi %4, %3 : i32 +// PERFECT-NEXT: affine.store %5, %arg9[0] : memref +// PERFECT-NEXT: } +// PERFECT-NEXT: } +// PERFECT-NEXT: } +// PERFECT-NEXT: affine.for %arg10 = 0 to 4 { +// PERFECT-NEXT: affine.for %arg11 = 0 to 7 { +// PERFECT-NEXT: %1 = affine.load %arg3[%arg10, %arg11] : memref +// PERFECT-NEXT: affine.store %1, %arg7[%arg11] : memref +// PERFECT-NEXT: } +// PERFECT-NEXT: } +// PERFECT-NEXT: affine.for %arg10 = 0 to 4 { +// PERFECT-NEXT: affine.for %arg11 = 0 to 9 { +// PERFECT-NEXT: %1 = affine.load %arg4[%arg10, %arg11] : memref +// PERFECT-NEXT: %2 = affine.load %arg7[%arg11] : memref +// PERFECT-NEXT: %3 = arith.addi %1, %2 : i32 +// PERFECT-NEXT: affine.store %3, %arg8[%arg11] : memref +// PERFECT-NEXT: } +// PERFECT-NEXT: } +// PERFECT-NEXT: %0 = affine.load %arg9[0] : memref +// PERFECT-NEXT: return %0 : i32 +// PERFECT-NEXT: } +// PERFECT-NEXT: } + // TASKFLOW: module { // TASKFLOW-NEXT: func.func @_Z21pureNestedLoopExamplePA8_A6_iPA8_A5_iS4_PA7_iPA9_iPiS9_S9_S9_S9_(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: memref, %arg6: memref, %arg7: memref, %arg8: memref, %arg9: memref) -> i32 attributes {llvm.linkage = #llvm.linkage} { // TASKFLOW-NEXT: %write_outputs = taskflow.task @Task_0 read_memrefs(%arg0 : memref) write_memrefs(%arg5 : memref) [original_read_memrefs(%arg0 : memref), original_write_memrefs(%arg5 : memref)] : (memref, memref) -> (memref) {