From 16bae0aae3020665ccefd880887ad8323040ecfa Mon Sep 17 00:00:00 2001 From: ShangkunLI Date: Wed, 11 Feb 2026 16:22:59 +0800 Subject: [PATCH 1/3] prototype memory access streaming fusion pass --- include/TaskflowDialect/TaskflowPasses.h | 1 + include/TaskflowDialect/TaskflowPasses.td | 17 + .../Transforms/Optimizations/CMakeLists.txt | 1 + .../MemoryAccessStreamingFusion.cpp | 961 ++++++++++++++++++ 4 files changed, 980 insertions(+) create mode 100644 lib/TaskflowDialect/Transforms/Optimizations/MemoryAccessStreamingFusion.cpp diff --git a/include/TaskflowDialect/TaskflowPasses.h b/include/TaskflowDialect/TaskflowPasses.h index bccdb84d..a407b37f 100644 --- a/include/TaskflowDialect/TaskflowPasses.h +++ b/include/TaskflowDialect/TaskflowPasses.h @@ -28,6 +28,7 @@ std::unique_ptr createMapTaskOnCgraPass(); //=========================================================// std::unique_ptr createAffineLoopTreeSerializationPass(); std::unique_ptr createAffineLoopPerfectionPass(); +std::unique_ptr createMemoryAccessStreamingFusionPass(); #define GEN_PASS_REGISTRATION #include "TaskflowDialect/TaskflowPasses.h.inc" diff --git a/include/TaskflowDialect/TaskflowPasses.td b/include/TaskflowDialect/TaskflowPasses.td index 4bb69caf..cd35d62a 100644 --- a/include/TaskflowDialect/TaskflowPasses.td +++ b/include/TaskflowDialect/TaskflowPasses.td @@ -72,4 +72,21 @@ def MapTaskOnCgra : Pass<"map-task-on-cgra", "func::FuncOp"> { }]; let constructor = "taskflow::createMapTaskOnCgraPass()"; } + +def MemoryAccessStreamingFusion : Pass<"memory-access-streaming-fusion", "func::FuncOp"> { + let summary = "Fuses tasks connected by memory dependencies for streaming execution"; + let description = [{ + Identifies and fuses taskflow.task operations that have memory access + dependencies (one task writes a memref, another task reads it). + Eliminates intermediate memref allocations and converts memory access + dependencies into direct SSA value dependencies. 
+ + This is distinct from SSA value producer-consumer dependencies which + already exist in the IR (value-output -> value-input). + + Uses greedy fusion strategy to iteratively fuse the most beneficial + task pairs until no more fusion opportunities exist. + }]; + let constructor = "taskflow::createMemoryAccessStreamingFusionPass()"; +} #endif // TASKFLOW_PASSES_TD \ No newline at end of file diff --git a/lib/TaskflowDialect/Transforms/Optimizations/CMakeLists.txt b/lib/TaskflowDialect/Transforms/Optimizations/CMakeLists.txt index 2200f5b1..9f56a1f3 100644 --- a/lib/TaskflowDialect/Transforms/Optimizations/CMakeLists.txt +++ b/lib/TaskflowDialect/Transforms/Optimizations/CMakeLists.txt @@ -3,6 +3,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}) add_mlir_conversion_library(MLIRTaskflowOptimization AffineLoopTreeSerializationPass.cpp AffineLoopPerfectionPass.cpp + MemoryAccessStreamingFusion.cpp DEPENDS MLIRTaskflowTransformsIncGen diff --git a/lib/TaskflowDialect/Transforms/Optimizations/MemoryAccessStreamingFusion.cpp b/lib/TaskflowDialect/Transforms/Optimizations/MemoryAccessStreamingFusion.cpp new file mode 100644 index 00000000..e3ebd36f --- /dev/null +++ b/lib/TaskflowDialect/Transforms/Optimizations/MemoryAccessStreamingFusion.cpp @@ -0,0 +1,961 @@ +//===- MemoryAccessStreamingFusion.cpp - Fuse tasks by memory deps -------===// +// +// This pass identifies and fuses taskflow.task operations that are connected +// by memory access dependencies (one task writes a memref, another reads it). +// It eliminates intermediate memref allocations and converts memory access +// dependencies into direct SSA value dependencies. 
+// +//===----------------------------------------------------------------------===// + +#include "TaskflowDialect/TaskflowDialect.h" +#include "TaskflowDialect/TaskflowOps.h" +#include "TaskflowDialect/TaskflowPasses.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "memory-access-streaming-fusion" + +namespace mlir { +namespace taskflow { + +#define GEN_PASS_DEF_MEMORYACCESSSTREAMINGFUSION +#include "TaskflowDialect/TaskflowPasses.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// Task Information and Dependency Analysis +//===----------------------------------------------------------------------===// + +/// Stores information about a single task and its memory dependencies. +struct TaskInfo { + taskflow::TaskflowTaskOp task_op; + + // Memrefs accessed by this task. + SmallVector read_memrefs; + SmallVector write_memrefs; + + // Memory dependency tracking (distinct from SSA value dependencies). + // Tasks that write memrefs which this task reads. + SmallVector memory_writers; + + // Tasks that read memrefs which this task writes. + SmallVector memory_readers; + + // SSA value inputs (true producer-consumer dependencies). + SmallVector value_inputs; + + TaskInfo() : task_op(nullptr) {} + TaskInfo(taskflow::TaskflowTaskOp op) : task_op(op) {} +}; + +/// Represents a candidate pair of tasks that can be fused. +struct FusionCandidate { + TaskInfo *memory_writer; // Task that writes the intermediate memref. + TaskInfo *memory_reader; // Task that reads the intermediate memref. 
+ Value intermediate_memref; // The memref to be eliminated. + int fusion_benefit; // Fusion benefit score for greedy selection. + + FusionCandidate(TaskInfo *writer, TaskInfo *reader, Value memref, int benefit) + : memory_writer(writer), memory_reader(reader), + intermediate_memref(memref), fusion_benefit(benefit) {} +}; + +//===----------------------------------------------------------------------===// +// Memory Dependency Graph Builder +//===----------------------------------------------------------------------===// + +/// Builds the memory dependency graph for all tasks in the function. +class MemoryDependencyAnalysis { +public: + MemoryDependencyAnalysis(func::FuncOp func) : function(func) {} + + /// Analyzes all tasks and builds the memory dependency graph. + void analyze(DenseMap &task_map) { + // Collects all task operations. + SmallVector tasks; + function.walk([&](taskflow::TaskflowTaskOp task_op) { + tasks.push_back(task_op); + task_map[task_op.getOperation()] = TaskInfo(task_op); + }); + + // Extracts memref accesses for each task. + for (auto task_op : tasks) { + auto &task_info = task_map[task_op.getOperation()]; + extractMemrefAccesses(task_op, task_info); + } + + // Builds memory dependency edges. + buildMemoryDependencies(tasks, task_map); + } + +private: + /// Extracts read and write memrefs from a task operation. + void extractMemrefAccesses(taskflow::TaskflowTaskOp task_op, + TaskInfo &task_info) { + // Extracts read memrefs from the task operands. + for (Value memref : task_op.getReadMemrefs()) { + task_info.read_memrefs.push_back(memref); + } + + // Extracts write memrefs from the task operands. + for (Value memref : task_op.getWriteMemrefs()) { + task_info.write_memrefs.push_back(memref); + } + + // Extracts value inputs (SSA producer-consumer dependencies). + for (Value input : task_op.getValueInputs()) { + task_info.value_inputs.push_back(input); + } + } + + /// Builds memory dependency edges between tasks. 
+ /// Uses original_read/write_memrefs for matching because the IR + /// passes writer's write_outputs (SSA results) as reader's read_memrefs, + /// so the raw alloc values are only preserved in original_read/write_memrefs. + void buildMemoryDependencies(ArrayRef tasks, + DenseMap &task_map) { + // Maps each original write memref to the task that writes it. + DenseMap orig_memref_writers; + + // First pass: records all original memref writers. + for (auto task_op : tasks) { + auto &task_info = task_map[task_op.getOperation()]; + for (Value orig_write : task_op.getOriginalWriteMemrefs()) { + orig_memref_writers[orig_write] = &task_info; + } + } + + // Second pass: establishes memory dependencies using original memrefs. + for (auto task_op : tasks) { + auto &task_info = task_map[task_op.getOperation()]; + for (Value orig_read : task_op.getOriginalReadMemrefs()) { + auto it = orig_memref_writers.find(orig_read); + if (it != orig_memref_writers.end()) { + TaskInfo *writer = it->second; + // Don't create self-dependency. + if (writer == &task_info) + continue; + // Establishes bidirectional memory dependency. + task_info.memory_writers.push_back(writer); + writer->memory_readers.push_back(&task_info); + } + } + } + } + + func::FuncOp function; +}; + +//===----------------------------------------------------------------------===// +// Fusion Candidate Identification +//===----------------------------------------------------------------------===// + +/// Identifies viable fusion candidates using greedy strategy. +class FusionCandidateIdentifier { +public: + FusionCandidateIdentifier(DenseMap &map) + : task_map(map) {} + + /// Identifies all valid fusion candidates and sorts them by benefit. + SmallVector identify() { + SmallVector candidates; + + // Iterates through all tasks to find fusion opportunities. + for (auto &entry : task_map) { + TaskInfo &task_info = entry.second; + + // Checks if this task has exactly one memory reader (fusion condition). 
+ if (task_info.memory_readers.size() == 1) { + TaskInfo *reader = task_info.memory_readers[0]; + + // Checks if fusion is valid. + if (canFuse(&task_info, reader)) { + // Finds the intermediate memref to eliminate. + Value intermediate = findIntermediateMemref(&task_info, reader); + if (intermediate) { + int benefit = calculateFusionBenefit(&task_info, reader); + candidates.emplace_back(&task_info, reader, intermediate, benefit); + } + } + } + } + + // Sorts candidates by fusion benefit (highest first) for greedy selection. + llvm::sort(candidates, + [](const FusionCandidate &a, const FusionCandidate &b) { + return a.fusion_benefit > b.fusion_benefit; + }); + + return candidates; + } + +private: + /// Checks if two tasks can be fused. + bool canFuse(TaskInfo *writer, TaskInfo *reader) { + auto writer_op = writer->task_op; + auto reader_op = reader->task_op; + + // 1. Extracts loop nests from both tasks and checks bounds compatibility. + auto writer_loops = extractOutermostLoopNest(writer_op); + auto reader_loops = extractOutermostLoopNest(reader_op); + + if (writer_loops.empty() || reader_loops.empty()) { + LLVM_DEBUG(llvm::dbgs() << " canFuse: empty loop nest\n"); + return false; + } + + if (!areLoopBoundsCompatible(writer_loops, reader_loops)) { + LLVM_DEBUG(llvm::dbgs() << " canFuse: incompatible loop bounds\n"); + return false; + } + + // 2. Checks that the intermediate memref is not used outside + // writer/reader (e.g., not returned by the function). + Value intermediate = findIntermediateMemref(writer, reader); + if (!intermediate) + return false; + + for (Operation *user : intermediate.getUsers()) { + if (user == writer_op.getOperation() || user == reader_op.getOperation()) + continue; + // Allow memref.alloc which defines it. + if (isa(user)) + continue; + LLVM_DEBUG(llvm::dbgs() + << " canFuse: intermediate memref has external use\n"); + return false; + } + + // 3. 
No cyclic dependency: writer must not read any original memref + // that reader writes (excluding the intermediate itself). + for (Value w_read : writer_op.getOriginalReadMemrefs()) { + if (w_read == intermediate) + continue; + for (Value r_write : reader_op.getOriginalWriteMemrefs()) { + if (r_write == intermediate) + continue; + if (w_read == r_write) { + LLVM_DEBUG(llvm::dbgs() << " canFuse: cyclic dependency\n"); + return false; + } + } + } + + return true; + } + + /// Extracts the outermost chain of perfectly nested affine.for loops + /// from a task body. + SmallVector + extractOutermostLoopNest(taskflow::TaskflowTaskOp task_op) { + SmallVector loops; + Block &body = task_op.getBody().front(); + + // Finds the single outermost affine.for in the task body. + affine::AffineForOp outermost = nullptr; + for (Operation &op : body) { + if (auto for_op = dyn_cast(op)) { + if (outermost) + return {}; // Multiple top-level loops: bail. + outermost = for_op; + } + } + if (!outermost) + return {}; + + // Walks the perfectly nested chain. + auto current = outermost; + while (current) { + loops.push_back(current); + // Checks for a single nested affine.for. + affine::AffineForOp nested = nullptr; + for (Operation &op : current.getBody()->getOperations()) { + if (auto for_op = dyn_cast(op)) { + if (nested) { + nested = nullptr; + break; + } // Non-perfect: stop. + nested = for_op; + } + } + current = nested; + } + return loops; + } + + /// Checks if two loop nests have compatible bounds. 
+ bool areLoopBoundsCompatible(SmallVector &writer_loops, + SmallVector &reader_loops) { + if (writer_loops.size() != reader_loops.size()) + return false; + for (size_t i = 0; i < writer_loops.size(); ++i) { + if (!writer_loops[i].hasConstantBounds() || + !reader_loops[i].hasConstantBounds()) + return false; + if (writer_loops[i].getConstantLowerBound() != + reader_loops[i].getConstantLowerBound() || + writer_loops[i].getConstantUpperBound() != + reader_loops[i].getConstantUpperBound() || + writer_loops[i].getStepAsInt() != reader_loops[i].getStepAsInt()) + return false; + } + return true; + } + + /// Finds the intermediate memref (original alloc) between writer and reader. + /// Uses original_write/read_memrefs since the IR passes SSA results + /// (not raw allocs) as read_memrefs. + Value findIntermediateMemref(TaskInfo *writer, TaskInfo *reader) { + auto writer_op = writer->task_op; + auto reader_op = reader->task_op; + for (Value orig_write : writer_op.getOriginalWriteMemrefs()) { + for (Value orig_read : reader_op.getOriginalReadMemrefs()) { + if (orig_write == orig_read) { + return orig_write; + } + } + } + return nullptr; + } + + /// Calculates the fusion benefit score for greedy selection. + int calculateFusionBenefit(TaskInfo *writer, TaskInfo *reader) { + int benefit = 0; + + // Base benefit: eliminates one memref allocation. + benefit += 100; + + // Bonus for element-wise operations (same loop bounds, same memref shape). + if (writer->write_memrefs.size() == 1 && reader->read_memrefs.size() == 1) + benefit += 50; + + return benefit; + } + + DenseMap &task_map; +}; + +//===----------------------------------------------------------------------===// +// Task Fuser - Performs the actual fusion transformation +//===----------------------------------------------------------------------===// + +/// Fuses two tasks connected by a memory dependency into a single task. 
+class TaskFuser { +public: + TaskFuser(func::FuncOp func) : function(func) {} + + /// Performs fusion of a candidate pair. Returns true if fusion succeeded. + bool performFusion(FusionCandidate &candidate) { + auto writer_op = candidate.memory_writer->task_op; + auto reader_op = candidate.memory_reader->task_op; + Value intermediate = candidate.intermediate_memref; + + LLVM_DEBUG(llvm::dbgs() + << " Fusing writer: " << writer_op.getTaskName() + << " + reader: " << reader_op.getTaskName() << "\n"); + + // Step 1: Builds merged operand lists. + SmallVector fused_read_memrefs; + SmallVector fused_write_memrefs; + SmallVector fused_value_inputs; + SmallVector fused_original_read_memrefs; + SmallVector fused_original_write_memrefs; + + buildMergedOperands(writer_op, reader_op, intermediate, fused_read_memrefs, + fused_write_memrefs, fused_value_inputs, + fused_original_read_memrefs, + fused_original_write_memrefs); + + // Step 2: Builds the result types (same as reader's outputs). + SmallVector write_output_types; + SmallVector value_output_types; + for (Value v : reader_op.getWriteOutputs()) + write_output_types.push_back(v.getType()); + for (Value v : reader_op.getValueOutputs()) + value_output_types.push_back(v.getType()); + + // Step 3: Creates the fused task. + OpBuilder builder(reader_op); + std::string fused_name = + (writer_op.getTaskName() + "_" + reader_op.getTaskName() + "_fused") + .str(); + + auto fused_task = builder.create( + writer_op.getLoc(), write_output_types, value_output_types, + fused_read_memrefs, fused_write_memrefs, fused_value_inputs, fused_name, + fused_original_read_memrefs, fused_original_write_memrefs); + + // Step 4: Builds fused task body by merging loop nests. + if (!buildFusedBody(fused_task, writer_op, reader_op, intermediate, + fused_read_memrefs, fused_write_memrefs, + fused_value_inputs)) { + fused_task.erase(); + return false; + } + + // Step 5: Replaces uses and cleans up. 
+ replaceUsesAndCleanup(writer_op, reader_op, fused_task, intermediate); + + LLVM_DEBUG(llvm::dbgs() << " Fusion succeeded: " << fused_name << "\n"); + return true; + } + +private: + /// Builds merged operand lists for the fused task. + /// The intermediate is the original alloc value. We must use + /// original_read/write_memrefs to identify which operands to exclude + /// from the reader (since reader's read_memrefs contain SSA results). + void buildMergedOperands(taskflow::TaskflowTaskOp writer_op, + taskflow::TaskflowTaskOp reader_op, + Value intermediate, SmallVector &fused_reads, + SmallVector &fused_writes, + SmallVector &fused_values, + SmallVector &fused_orig_reads, + SmallVector &fused_orig_writes) { + + // read_memrefs = writer.reads ∪ reader.reads - intermediate + DenseSet seen; + auto writer_reads = writer_op.getReadMemrefs(); + auto writer_orig_reads = writer_op.getOriginalReadMemrefs(); + for (unsigned i = 0; i < writer_reads.size(); ++i) { + Value orig = (i < writer_orig_reads.size()) ? writer_orig_reads[i] + : writer_reads[i]; + if (orig != intermediate && seen.insert(writer_reads[i]).second) + fused_reads.push_back(writer_reads[i]); + } + + auto reader_reads = reader_op.getReadMemrefs(); + auto reader_orig_reads = reader_op.getOriginalReadMemrefs(); + for (unsigned i = 0; i < reader_reads.size(); ++i) { + Value orig = (i < reader_orig_reads.size()) ? reader_orig_reads[i] + : reader_reads[i]; + if (orig != intermediate && seen.insert(reader_reads[i]).second) + fused_reads.push_back(reader_reads[i]); + } + + // write_memrefs = reader.writes ∪ (writer.writes - intermediate) + seen.clear(); + auto reader_writes = reader_op.getWriteMemrefs(); + auto reader_orig_writes = reader_op.getOriginalWriteMemrefs(); + for (unsigned i = 0; i < reader_writes.size(); ++i) { + Value orig = (i < reader_orig_writes.size()) ? 
reader_orig_writes[i] + : reader_writes[i]; + if (orig != intermediate && seen.insert(reader_writes[i]).second) + fused_writes.push_back(reader_writes[i]); + } + + auto writer_writes = writer_op.getWriteMemrefs(); + auto writer_orig_writes = writer_op.getOriginalWriteMemrefs(); + for (unsigned i = 0; i < writer_writes.size(); ++i) { + Value orig = (i < writer_orig_writes.size()) ? writer_orig_writes[i] + : writer_writes[i]; + if (orig != intermediate && seen.insert(writer_writes[i]).second) + fused_writes.push_back(writer_writes[i]); + } + + // value_inputs = writer.values ∪ reader.values + for (Value v : writer_op.getValueInputs()) + fused_values.push_back(v); + for (Value v : reader_op.getValueInputs()) + fused_values.push_back(v); + + // original_read/write_memrefs: same merge rules (using originals directly). + seen.clear(); + for (Value v : writer_op.getOriginalReadMemrefs()) { + if (v != intermediate && seen.insert(v).second) + fused_orig_reads.push_back(v); + } + for (Value v : reader_op.getOriginalReadMemrefs()) { + if (v != intermediate && seen.insert(v).second) + fused_orig_reads.push_back(v); + } + + seen.clear(); + for (Value v : reader_op.getOriginalWriteMemrefs()) { + if (v != intermediate && seen.insert(v).second) + fused_orig_writes.push_back(v); + } + for (Value v : writer_op.getOriginalWriteMemrefs()) { + if (v != intermediate && seen.insert(v).second) + fused_orig_writes.push_back(v); + } + } + + /// Builds the fused task body by merging writer and reader loop nests. + /// Returns false if fusion fails (e.g., unexpected IR structure). + bool buildFusedBody(taskflow::TaskflowTaskOp fused_task, + taskflow::TaskflowTaskOp writer_op, + taskflow::TaskflowTaskOp reader_op, Value intermediate, + ArrayRef fused_reads, ArrayRef fused_writes, + ArrayRef fused_values) { + + // Creates the entry block with all operands as block arguments. 
+ Block *fused_block = new Block(); + fused_task.getBody().push_back(fused_block); + + // Block args: read_memrefs, write_memrefs, value_inputs. + for (Value v : fused_reads) + fused_block->addArgument(v.getType(), v.getLoc()); + for (Value v : fused_writes) + fused_block->addArgument(v.getType(), v.getLoc()); + for (Value v : fused_values) + fused_block->addArgument(v.getType(), v.getLoc()); + + // Builds a mapping from writer/reader block args to fused block args. + IRMapping writer_mapping; + IRMapping reader_mapping; + + Block &writer_body = writer_op.getBody().front(); + Block &reader_body = reader_op.getBody().front(); + + // Maps writer's block args to fused block args. + unsigned fused_arg_idx = 0; + mapBlockArgs(writer_op, writer_body, fused_block, fused_arg_idx, + writer_mapping, intermediate, fused_reads, fused_writes, + fused_values); + + // Maps reader's block args to fused block args. + mapBlockArgs(reader_op, reader_body, fused_block, fused_arg_idx, + reader_mapping, intermediate, fused_reads, fused_writes, + fused_values); + + // Clones the writer's loop nest into the fused body. + OpBuilder body_builder(fused_block, fused_block->end()); + + // Finds the writer's outermost affine.for. + affine::AffineForOp writer_outer_loop = nullptr; + for (Operation &op : writer_body) { + if (auto for_op = dyn_cast(op)) { + writer_outer_loop = for_op; + break; + } + } + if (!writer_outer_loop) + return false; + + // Finds the reader's outermost affine.for. + affine::AffineForOp reader_outer_loop = nullptr; + for (Operation &op : reader_body) { + if (auto for_op = dyn_cast(op)) { + reader_outer_loop = for_op; + break; + } + } + if (!reader_outer_loop) + return false; + + // Clones the writer's entire loop nest. + Operation *cloned_writer = + body_builder.clone(*writer_outer_loop, writer_mapping); + auto cloned_writer_loop = cast(cloned_writer); + + // Finds the innermost loop body in the cloned writer nest. 
+ affine::AffineForOp innermost_writer = cloned_writer_loop; + while (true) { + affine::AffineForOp nested = nullptr; + for (Operation &op : innermost_writer.getBody()->getOperations()) { + if (auto for_op = dyn_cast(op)) { + nested = for_op; + break; + } + } + if (!nested) + break; + innermost_writer = nested; + } + + // Maps reader's loop induction variables to writer's (cloned) loop IVs. + // This is valid because we verified bounds compatibility in canFuse. + { + auto writer_loops_chain = getLoopChain(cloned_writer_loop); + auto reader_loops_chain = getLoopChain(reader_outer_loop); + for (size_t i = 0; + i < writer_loops_chain.size() && i < reader_loops_chain.size(); + ++i) { + reader_mapping.map(reader_loops_chain[i].getInductionVar(), + writer_loops_chain[i].getInductionVar()); + } + } + + // In the innermost writer loop body, finds affine.store to intermediate. + // Maps: for each store to intermediate, the stored value becomes the + // replacement for the corresponding load in the reader. + // For now, handle the common case: single store to intermediate. + Value store_value = nullptr; + Operation *store_to_intermediate = nullptr; + for (Operation &op : innermost_writer.getBody()->getOperations()) { + if (auto store_op = dyn_cast(op)) { + // Checks if this store writes to the intermediate memref's + // block arg (mapped from the writer's original arg). + // The stored-to memref is the writer's block arg for the intermediate. + store_value = store_op.getValueToStore(); + store_to_intermediate = &op; + } + } + + if (!store_value) { + LLVM_DEBUG(llvm::dbgs() + << " No store to intermediate found in writer\n"); + return false; + } + + // Now clones reader's innermost loop body ops into the writer's + // innermost loop. For affine.load from intermediate, replaces with + // the store_value (SSA direct connection). 
+ affine::AffineForOp reader_innermost = reader_outer_loop; + while (true) { + affine::AffineForOp nested = nullptr; + for (Operation &op : reader_innermost.getBody()->getOperations()) { + if (auto for_op = dyn_cast(op)) { + nested = for_op; + break; + } + } + if (!nested) + break; + reader_innermost = nested; + } + + // Inserts reader's ops before the yield (or end) of the writer's + // innermost loop body. + OpBuilder inner_builder(innermost_writer.getBody(), + innermost_writer.getBody()->end()); + // Positions before the affine.yield terminator if it exists. + if (!innermost_writer.getBody()->empty()) { + Operation *terminator = innermost_writer.getBody()->getTerminator(); + if (terminator) + inner_builder.setInsertionPoint(terminator); + } + + for (Operation &op : reader_innermost.getBody()->getOperations()) { + if (op.hasTrait()) + continue; + + if (auto load_op = dyn_cast(op)) { + // Checks if this load reads from the intermediate memref. + // The reader's block arg corresponding to the intermediate. + Value load_memref = load_op.getMemRef(); + bool is_intermediate_load = false; + + // Checks if this memref is a block arg that maps to intermediate. + // Uses original_read/write_memrefs since reader's read_memrefs + // contain SSA results, not the raw alloc. + if (auto block_arg = dyn_cast(load_memref)) { + unsigned arg_num = block_arg.getArgNumber(); + unsigned total_reads = reader_op.getReadMemrefs().size(); + unsigned total_writes = reader_op.getWriteMemrefs().size(); + + if (arg_num < total_reads) { + // Use original_read_memrefs to check against intermediate. 
+ Value orig_memref = reader_op.getOriginalReadMemrefs()[arg_num]; + if (orig_memref == intermediate) + is_intermediate_load = true; + } else if (arg_num < total_reads + total_writes) { + Value orig_memref = + reader_op.getOriginalWriteMemrefs()[arg_num - total_reads]; + if (orig_memref == intermediate) + is_intermediate_load = true; + } + } + + if (is_intermediate_load) { + // Replaces this load with the store_value (SSA streaming). + reader_mapping.map(load_op.getResult(), store_value); + continue; // Don't clone the load. + } + } + + // Clones the op with the reader mapping. + inner_builder.clone(op, reader_mapping); + } + + // Removes the store to intermediate (no longer needed). + if (store_to_intermediate) + store_to_intermediate->erase(); + + // Step 5: Creates the yield for the fused task. + // Yields the reader's output memrefs. + // Remove any existing terminator (if the block already has one). + if (fused_block->mightHaveTerminator()) { + if (auto *yield_point = fused_block->getTerminator()) + yield_point->erase(); + } + + OpBuilder yield_builder(fused_block, fused_block->end()); + SmallVector yield_writes; + SmallVector yield_values; + + // Finds the yield in the reader's original body for reference. + auto reader_yield = + cast(reader_body.getTerminator()); + + // Maps reader yield's memory results to fused block args. + for (Value v : reader_yield.getMemoryResults()) { + if (reader_mapping.contains(v)) + yield_writes.push_back(reader_mapping.lookup(v)); + else + yield_writes.push_back(v); + } + for (Value v : reader_yield.getValueResults()) { + if (reader_mapping.contains(v)) + yield_values.push_back(reader_mapping.lookup(v)); + else + yield_values.push_back(v); + } + + yield_builder.create(reader_op.getLoc(), + yield_writes, yield_values); + + return true; + } + + /// Maps a task's block args to the corresponding fused block args. 
+ void mapBlockArgs(taskflow::TaskflowTaskOp task_op, Block &original_body, + Block *fused_block, unsigned &fused_arg_idx, + IRMapping &mapping, Value intermediate, + ArrayRef fused_reads, ArrayRef fused_writes, + ArrayRef fused_values) { + + unsigned orig_arg_idx = 0; + unsigned num_reads = task_op.getReadMemrefs().size(); + unsigned num_writes = task_op.getWriteMemrefs().size(); + unsigned num_values = task_op.getValueInputs().size(); + + // Maps read_memrefs block args. + // Uses original_read_memrefs to identify intermediate (since reader's + // read_memrefs contain SSA results, not raw allocs). + auto orig_reads = task_op.getOriginalReadMemrefs(); + for (unsigned i = 0; i < num_reads; ++i) { + Value orig_memref = + (i < orig_reads.size()) ? orig_reads[i] : task_op.getReadMemrefs()[i]; + if (orig_memref == intermediate) { + // Intermediate memref — no corresponding fused arg. Skip. + orig_arg_idx++; + continue; + } + // Finds the fused block arg for this outer memref. + Value outer_memref = task_op.getReadMemrefs()[i]; + int fused_idx = findInFusedArgs(outer_memref, fused_reads, fused_writes, + fused_values); + if (fused_idx >= 0) { + mapping.map(original_body.getArgument(orig_arg_idx), + fused_block->getArgument(fused_idx)); + } + orig_arg_idx++; + } + + // Maps write_memrefs block args. + auto orig_writes = task_op.getOriginalWriteMemrefs(); + for (unsigned i = 0; i < num_writes; ++i) { + Value orig_memref = (i < orig_writes.size()) + ? orig_writes[i] + : task_op.getWriteMemrefs()[i]; + if (orig_memref == intermediate) { + orig_arg_idx++; + continue; + } + Value outer_memref = task_op.getWriteMemrefs()[i]; + int fused_idx = findInFusedArgs(outer_memref, fused_reads, fused_writes, + fused_values); + if (fused_idx >= 0) { + mapping.map(original_body.getArgument(orig_arg_idx), + fused_block->getArgument(fused_idx)); + } + orig_arg_idx++; + } + + // Maps value_inputs block args. 
+ for (unsigned i = 0; i < num_values; ++i) { + Value outer_value = task_op.getValueInputs()[i]; + int fused_idx = + findInFusedArgs(outer_value, fused_reads, fused_writes, fused_values); + if (fused_idx >= 0) { + mapping.map(original_body.getArgument(orig_arg_idx), + fused_block->getArgument(fused_idx)); + } + orig_arg_idx++; + } + } + + /// Finds the index of an outer value in the fused block's argument list. + /// Returns -1 if not found. + int findInFusedArgs(Value outer_val, ArrayRef fused_reads, + ArrayRef fused_writes, + ArrayRef fused_values) { + unsigned idx = 0; + for (Value v : fused_reads) { + if (v == outer_val) + return idx; + idx++; + } + for (Value v : fused_writes) { + if (v == outer_val) + return idx; + idx++; + } + for (Value v : fused_values) { + if (v == outer_val) + return idx; + idx++; + } + return -1; + } + + /// Gets the chain of nested affine.for ops starting from the outermost. + SmallVector getLoopChain(affine::AffineForOp outermost) { + SmallVector chain; + auto current = outermost; + while (current) { + chain.push_back(current); + affine::AffineForOp nested = nullptr; + for (Operation &op : current.getBody()->getOperations()) { + if (auto for_op = dyn_cast(op)) { + nested = for_op; + break; + } + } + current = nested; + } + return chain; + } + + /// Replaces uses of original tasks' results with fused task results + /// and erases original ops. + void replaceUsesAndCleanup(taskflow::TaskflowTaskOp writer_op, + taskflow::TaskflowTaskOp reader_op, + taskflow::TaskflowTaskOp fused_task, + Value intermediate) { + + // Replaces reader's write_outputs with fused task's write_outputs. + for (unsigned i = 0; i < reader_op.getWriteOutputs().size(); ++i) { + reader_op.getWriteOutputs()[i].replaceAllUsesWith( + fused_task.getWriteOutputs()[i]); + } + + // Replaces reader's value_outputs with fused task's value_outputs. 
+ for (unsigned i = 0; i < reader_op.getValueOutputs().size(); ++i) { + reader_op.getValueOutputs()[i].replaceAllUsesWith( + fused_task.getValueOutputs()[i]); + } + + // Erases original tasks (reader first since writer might be used by it + // through the intermediate, but we've already replaced all uses). + reader_op.erase(); + + // Writer's outputs: The intermediate memref output is no longer used. + // Other outputs should have been handled, but let's verify. + // If the writer has other outputs besides the intermediate, those + // should not exist in the single-reader case. + writer_op.erase(); + + // Erases the intermediate memref allocation if it's now dead. + if (auto alloc_op = intermediate.getDefiningOp()) { + if (alloc_op.getResult().use_empty()) + alloc_op.erase(); + } + } + + func::FuncOp function; +}; + +//===----------------------------------------------------------------------===// +// Memory Access Streaming Fusion Pass +//===----------------------------------------------------------------------===// + +struct MemoryAccessStreamingFusionPass + : public impl::MemoryAccessStreamingFusionBase< + MemoryAccessStreamingFusionPass> { + + void runOnOperation() override { + func::FuncOp func = getOperation(); + + LLVM_DEBUG(llvm::dbgs() + << "Running MemoryAccessStreamingFusion on function: " + << func.getName() << "\n"); + + // Iterative fusion: re-analyze after each round to catch chains. + // e.g., A→B→C: first round fuses A+B, second round fuses (A+B)+C. + unsigned total_fusions = 0; + constexpr unsigned kMaxIterations = 100; + + for (unsigned iter = 0; iter < kMaxIterations; ++iter) { + // Re-builds memory dependency graph from current IR state. + DenseMap task_map; + MemoryDependencyAnalysis analysis(func); + analysis.analyze(task_map); + + LLVM_DEBUG(llvm::dbgs() << "Iteration " << iter << ": Found " + << task_map.size() << " tasks\n"); + + // Identifies fusion candidates. 
+      FusionCandidateIdentifier identifier(task_map);
+      auto candidates = identifier.identify();
+
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Found " << candidates.size() << " fusion candidates\n");
+
+      if (candidates.empty())
+        break;
+
+      // Performs greedy fusion for this round.
+      DenseSet<Operation *> fused_tasks;
+      TaskFuser fuser(func);
+      unsigned round_fusions = 0;
+
+      for (auto &candidate : candidates) {
+        Operation *writer_op = candidate.memory_writer->task_op.getOperation();
+        Operation *reader_op = candidate.memory_reader->task_op.getOperation();
+
+        // Skips if either task was already consumed by a previous fusion
+        // in this round.
+        if (fused_tasks.count(writer_op) || fused_tasks.count(reader_op))
+          continue;
+
+        LLVM_DEBUG(llvm::dbgs() << "Attempting to fuse tasks (benefit: "
+                                << candidate.fusion_benefit << ")\n");
+
+        if (fuser.performFusion(candidate)) {
+          fused_tasks.insert(writer_op);
+          fused_tasks.insert(reader_op);
+          ++round_fusions;
+        }
+      }
+
+      LLVM_DEBUG(llvm::dbgs() << "Round " << iter << ": fused " << round_fusions
+                              << " task pairs\n");
+
+      total_fusions += round_fusions;
+
+      // If no fusions happened this round, we've converged.
+      if (round_fusions == 0)
+        break;
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "Total fusions: " << total_fusions << "\n");
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass Registration
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createMemoryAccessStreamingFusionPass() {
+  return std::make_unique<MemoryAccessStreamingFusionPass>();
+}
+
+} // namespace taskflow
+} // namespace mlir

From e41cb8d637248ebfb629c2360049afe727e7e39a Mon Sep 17 00:00:00 2001
From: ShangkunLI
Date: Wed, 11 Feb 2026 21:58:10 +0800
Subject: [PATCH 2/3] change the canFuse logic

---
 include/TaskflowDialect/TaskflowPasses.td     |   3 -
 .../MemoryAccessStreamingFusion.cpp           | 378 +++++++++++-------
 2 files changed, 235 insertions(+), 146 deletions(-)

diff --git a/include/TaskflowDialect/TaskflowPasses.td b/include/TaskflowDialect/TaskflowPasses.td
index cd35d62a..04d40bc4 100644
--- a/include/TaskflowDialect/TaskflowPasses.td
+++ b/include/TaskflowDialect/TaskflowPasses.td
@@ -81,9 +81,6 @@ def MemoryAccessStreamingFusion : Pass<"memory-access-streaming-fusion", "func::FuncOp"> {
     Eliminates intermediate memref allocations and converts memory access
     dependencies into direct SSA value dependencies.
 
-    This is distinct from SSA value producer-consumer dependencies which
-    already exist in the IR (value-output -> value-input).
-
     Uses greedy fusion strategy to iteratively fuse the most beneficial
     task pairs until no more fusion opportunities exist.
}]; diff --git a/lib/TaskflowDialect/Transforms/Optimizations/MemoryAccessStreamingFusion.cpp b/lib/TaskflowDialect/Transforms/Optimizations/MemoryAccessStreamingFusion.cpp index e3ebd36f..f05f68f9 100644 --- a/lib/TaskflowDialect/Transforms/Optimizations/MemoryAccessStreamingFusion.cpp +++ b/lib/TaskflowDialect/Transforms/Optimizations/MemoryAccessStreamingFusion.cpp @@ -1,4 +1,4 @@ -//===- MemoryAccessStreamingFusion.cpp - Fuse tasks by memory deps -------===// +//===- MemoryAccessStreamingFusion.cpp - Fuses tasks by memory deps -------===// // // This pass identifies and fuses taskflow.task operations that are connected // by memory access dependencies (one task writes a memref, another reads it). @@ -20,25 +20,21 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" -#define DEBUG_TYPE "memory-access-streaming-fusion" - -namespace mlir { -namespace taskflow { - -#define GEN_PASS_DEF_MEMORYACCESSSTREAMINGFUSION -#include "TaskflowDialect/TaskflowPasses.h.inc" +using namespace mlir; +using namespace mlir::taskflow; namespace { //===----------------------------------------------------------------------===// -// Task Information and Dependency Analysis +// Task Information and Dependency Analysis. //===----------------------------------------------------------------------===// -/// Stores information about a single task and its memory dependencies. +// Stores information about a single task and its memory dependencies. struct TaskInfo { taskflow::TaskflowTaskOp task_op; @@ -46,11 +42,15 @@ struct TaskInfo { SmallVector read_memrefs; SmallVector write_memrefs; - // Memory dependency tracking (distinct from SSA value dependencies). - // Tasks that write memrefs which this task reads. + // Memory dependency graph edges. + // Upstream producers: tasks whose write_memrefs this task reads from. 
+ // e.g., if Task A writes %alloc and this task reads %alloc, + // then Task A is in this->memory_writers. SmallVector memory_writers; - // Tasks that read memrefs which this task writes. + // Downstream consumers: tasks that read from this task's write_memrefs. + // e.g., if this task writes %alloc and Task B reads %alloc, + // then Task B is in this->memory_readers. SmallVector memory_readers; // SSA value inputs (true producer-consumer dependencies). @@ -60,7 +60,7 @@ struct TaskInfo { TaskInfo(taskflow::TaskflowTaskOp op) : task_op(op) {} }; -/// Represents a candidate pair of tasks that can be fused. +// Represents a candidate pair of tasks that can be fused. struct FusionCandidate { TaskInfo *memory_writer; // Task that writes the intermediate memref. TaskInfo *memory_reader; // Task that reads the intermediate memref. @@ -76,16 +76,16 @@ struct FusionCandidate { // Memory Dependency Graph Builder //===----------------------------------------------------------------------===// -/// Builds the memory dependency graph for all tasks in the function. +// Builds the memory dependency graph for all tasks in the function. class MemoryDependencyAnalysis { public: - MemoryDependencyAnalysis(func::FuncOp func) : function(func) {} + MemoryDependencyAnalysis(func::FuncOp func_op) : func_op(func_op) {} - /// Analyzes all tasks and builds the memory dependency graph. + // Analyzes all tasks and builds the memory dependency graph. void analyze(DenseMap &task_map) { // Collects all task operations. SmallVector tasks; - function.walk([&](taskflow::TaskflowTaskOp task_op) { + this->func_op.walk([&](taskflow::TaskflowTaskOp task_op) { tasks.push_back(task_op); task_map[task_op.getOperation()] = TaskInfo(task_op); }); @@ -101,7 +101,7 @@ class MemoryDependencyAnalysis { } private: - /// Extracts read and write memrefs from a task operation. + // Extracts read and write memrefs from a task operation. 
void extractMemrefAccesses(taskflow::TaskflowTaskOp task_op, TaskInfo &task_info) { // Extracts read memrefs from the task operands. @@ -120,55 +120,85 @@ class MemoryDependencyAnalysis { } } - /// Builds memory dependency edges between tasks. - /// Uses original_read/write_memrefs for matching because the IR - /// passes writer's write_outputs (SSA results) as reader's read_memrefs, - /// so the raw alloc values are only preserved in original_read/write_memrefs. + // Builds memory dependency edges between tasks by tracing SSA value flow. + // + // The two sets of read/write memrefs serve distinct purposes: + // - read/write_memrefs: carry SSA values (write_outputs results from + // upstream tasks). These encode memory ordering (RAW, WAW, WAR deps). + // - original_read/write_memrefs: track the physical %alloc buffers. + // These are used later to identify intermediate memrefs to eliminate. + // + // This function builds the dependency graph using the SSA value flow: + // If Task B's read_memrefs or write_memrefs contains a value produced + // by Task A's write_outputs, then Task A → Task B is a dependency. + // + // Example: + // %alloc = memref.alloc() + // %wo_A = taskflow.task @A write_memrefs(%alloc) ... + // %wo_B = taskflow.task @B read_memrefs(%wo_A) write_memrefs(%wo_A) ... + // Here %wo_A in B's read_memrefs → RAW dep (A→B), + // %wo_A in B's write_memrefs → WAW dep (A→B). void buildMemoryDependencies(ArrayRef tasks, DenseMap &task_map) { - // Maps each original write memref to the task that writes it. - DenseMap orig_memref_writers; + // Maps each write_outputs SSA value to the task that produces it. + DenseMap write_output_to_producer; - // First pass: records all original memref writers. 
for (auto task_op : tasks) { auto &task_info = task_map[task_op.getOperation()]; - for (Value orig_write : task_op.getOriginalWriteMemrefs()) { - orig_memref_writers[orig_write] = &task_info; + for (Value wo : task_op.getWriteOutputs()) { + write_output_to_producer[wo] = &task_info; } } - // Second pass: establishes memory dependencies using original memrefs. + // For each task, check if its read/write_memrefs consume another + // task's write_outputs. If so, establish a dependency edge. for (auto task_op : tasks) { auto &task_info = task_map[task_op.getOperation()]; - for (Value orig_read : task_op.getOriginalReadMemrefs()) { - auto it = orig_memref_writers.find(orig_read); - if (it != orig_memref_writers.end()) { - TaskInfo *writer = it->second; - // Don't create self-dependency. - if (writer == &task_info) - continue; - // Establishes bidirectional memory dependency. - task_info.memory_writers.push_back(writer); - writer->memory_readers.push_back(&task_info); + DenseSet seen_writers; + + auto addDependencyIfProduced = [&](Value operand) { + auto it = write_output_to_producer.find(operand); + if (it == write_output_to_producer.end()) { + return; + } + TaskInfo *writer = it->second; + if (writer == &task_info) { + return; // No self-dependency. } + if (!seen_writers.insert(writer).second) { + return; // Deduplicate (same value may appear in both read and + // write). + } + task_info.memory_writers.push_back(writer); + writer->memory_readers.push_back(&task_info); + }; + + // RAW: read_memrefs consuming a write_outputs value. + for (Value operand : task_op.getReadMemrefs()) { + addDependencyIfProduced(operand); + } + + // WAW/WAR: write_memrefs consuming a write_outputs value. 
+ for (Value operand : task_op.getWriteMemrefs()) { + addDependencyIfProduced(operand); } } } - func::FuncOp function; + func::FuncOp func_op; }; //===----------------------------------------------------------------------===// // Fusion Candidate Identification //===----------------------------------------------------------------------===// -/// Identifies viable fusion candidates using greedy strategy. +// Identifies viable fusion candidates using greedy strategy. class FusionCandidateIdentifier { public: FusionCandidateIdentifier(DenseMap &map) : task_map(map) {} - /// Identifies all valid fusion candidates and sorts them by benefit. + // Identifies all valid fusion candidates and sorts them by benefit. SmallVector identify() { SmallVector candidates; @@ -192,7 +222,8 @@ class FusionCandidateIdentifier { } } - // Sorts candidates by fusion benefit (highest first) for greedy selection. + // Sorts candidates by fusion benefit (highest first) for greedy + // selection. llvm::sort(candidates, [](const FusionCandidate &a, const FusionCandidate &b) { return a.fusion_benefit > b.fusion_benefit; @@ -202,52 +233,66 @@ class FusionCandidateIdentifier { } private: - /// Checks if two tasks can be fused. + // Checks if two tasks can be fused. bool canFuse(TaskInfo *writer, TaskInfo *reader) { auto writer_op = writer->task_op; auto reader_op = reader->task_op; + // 0. Writers with value_outputs are not fusable. + // The fused task only produces the reader's outputs, so the + // writer's value_outputs would be lost. A future enhancement + // could propagate them through the fused task. + if (!writer_op.getValueOutputs().empty()) { + llvm::errs() << " canFuse: writer has value_outputs\n"; + return false; + } + // 1. Extracts loop nests from both tasks and checks bounds compatibility. 
auto writer_loops = extractOutermostLoopNest(writer_op); auto reader_loops = extractOutermostLoopNest(reader_op); if (writer_loops.empty() || reader_loops.empty()) { - LLVM_DEBUG(llvm::dbgs() << " canFuse: empty loop nest\n"); + llvm::errs() << " canFuse: empty loop nest\n"; return false; } if (!areLoopBoundsCompatible(writer_loops, reader_loops)) { - LLVM_DEBUG(llvm::dbgs() << " canFuse: incompatible loop bounds\n"); + llvm::errs() << " canFuse: incompatible loop bounds\n"; return false; } // 2. Checks that the intermediate memref is not used outside // writer/reader (e.g., not returned by the function). Value intermediate = findIntermediateMemref(writer, reader); - if (!intermediate) + if (!intermediate) { return false; + } for (Operation *user : intermediate.getUsers()) { - if (user == writer_op.getOperation() || user == reader_op.getOperation()) + if (user == writer_op.getOperation() || + user == reader_op.getOperation()) { continue; + } // Allow memref.alloc which defines it. - if (isa(user)) + if (isa(user)) { continue; - LLVM_DEBUG(llvm::dbgs() - << " canFuse: intermediate memref has external use\n"); + } + llvm::errs() << " canFuse: intermediate memref has external use\n"; return false; } // 3. No cyclic dependency: writer must not read any original memref // that reader writes (excluding the intermediate itself). for (Value w_read : writer_op.getOriginalReadMemrefs()) { - if (w_read == intermediate) + if (w_read == intermediate) { continue; + } for (Value r_write : reader_op.getOriginalWriteMemrefs()) { - if (r_write == intermediate) + if (r_write == intermediate) { continue; + } if (w_read == r_write) { - LLVM_DEBUG(llvm::dbgs() << " canFuse: cyclic dependency\n"); + llvm::errs() << " canFuse: cyclic dependency\n"; return false; } } @@ -256,8 +301,8 @@ class FusionCandidateIdentifier { return true; } - /// Extracts the outermost chain of perfectly nested affine.for loops - /// from a task body. 
+ // Extracts the outermost chain of perfectly nested affine.for loops + // from a task body. SmallVector extractOutermostLoopNest(taskflow::TaskflowTaskOp task_op) { SmallVector loops; @@ -267,13 +312,15 @@ class FusionCandidateIdentifier { affine::AffineForOp outermost = nullptr; for (Operation &op : body) { if (auto for_op = dyn_cast(op)) { - if (outermost) + if (outermost) { return {}; // Multiple top-level loops: bail. + } outermost = for_op; } } - if (!outermost) + if (!outermost) { return {}; + } // Walks the perfectly nested chain. auto current = outermost; @@ -295,28 +342,31 @@ class FusionCandidateIdentifier { return loops; } - /// Checks if two loop nests have compatible bounds. + // Checks if two loop nests have compatible bounds. bool areLoopBoundsCompatible(SmallVector &writer_loops, SmallVector &reader_loops) { - if (writer_loops.size() != reader_loops.size()) + if (writer_loops.size() != reader_loops.size()) { return false; + } for (size_t i = 0; i < writer_loops.size(); ++i) { if (!writer_loops[i].hasConstantBounds() || - !reader_loops[i].hasConstantBounds()) + !reader_loops[i].hasConstantBounds()) { return false; + } if (writer_loops[i].getConstantLowerBound() != reader_loops[i].getConstantLowerBound() || writer_loops[i].getConstantUpperBound() != reader_loops[i].getConstantUpperBound() || - writer_loops[i].getStepAsInt() != reader_loops[i].getStepAsInt()) + writer_loops[i].getStepAsInt() != reader_loops[i].getStepAsInt()) { return false; + } } return true; } - /// Finds the intermediate memref (original alloc) between writer and reader. - /// Uses original_write/read_memrefs since the IR passes SSA results - /// (not raw allocs) as read_memrefs. + // Finds the intermediate memref (original alloc) between writer and + // reader. Uses original_write/read_memrefs since the IR passes SSA results + // (not raw allocs) as read_memrefs. 
Value findIntermediateMemref(TaskInfo *writer, TaskInfo *reader) { auto writer_op = writer->task_op; auto reader_op = reader->task_op; @@ -330,16 +380,18 @@ class FusionCandidateIdentifier { return nullptr; } - /// Calculates the fusion benefit score for greedy selection. + // Calculates the fusion benefit score for greedy selection. int calculateFusionBenefit(TaskInfo *writer, TaskInfo *reader) { int benefit = 0; // Base benefit: eliminates one memref allocation. benefit += 100; - // Bonus for element-wise operations (same loop bounds, same memref shape). - if (writer->write_memrefs.size() == 1 && reader->read_memrefs.size() == 1) + // Bonus for element-wise operations (same loop bounds, same memref + // shape). + if (writer->write_memrefs.size() == 1 && reader->read_memrefs.size() == 1) { benefit += 50; + } return benefit; } @@ -351,20 +403,19 @@ class FusionCandidateIdentifier { // Task Fuser - Performs the actual fusion transformation //===----------------------------------------------------------------------===// -/// Fuses two tasks connected by a memory dependency into a single task. +// Fuses two tasks connected by a memory dependency into a single task. class TaskFuser { public: TaskFuser(func::FuncOp func) : function(func) {} - /// Performs fusion of a candidate pair. Returns true if fusion succeeded. + // Performs fusion of a candidate pair. Returns true if fusion succeeded. bool performFusion(FusionCandidate &candidate) { auto writer_op = candidate.memory_writer->task_op; auto reader_op = candidate.memory_reader->task_op; Value intermediate = candidate.intermediate_memref; - LLVM_DEBUG(llvm::dbgs() - << " Fusing writer: " << writer_op.getTaskName() - << " + reader: " << reader_op.getTaskName() << "\n"); + llvm::errs() << " Fusing writer: " << writer_op.getTaskName() + << " + reader: " << reader_op.getTaskName() << "\n"; // Step 1: Builds merged operand lists. 
SmallVector fused_read_memrefs; @@ -379,12 +430,16 @@ class TaskFuser { fused_original_write_memrefs); // Step 2: Builds the result types (same as reader's outputs). + // Writer's value_outputs are not included because canFuse rejects + // writers with value_outputs. SmallVector write_output_types; SmallVector value_output_types; - for (Value v : reader_op.getWriteOutputs()) + for (Value v : reader_op.getWriteOutputs()) { write_output_types.push_back(v.getType()); - for (Value v : reader_op.getValueOutputs()) + } + for (Value v : reader_op.getValueOutputs()) { value_output_types.push_back(v.getType()); + } // Step 3: Creates the fused task. OpBuilder builder(reader_op); @@ -408,15 +463,15 @@ class TaskFuser { // Step 5: Replaces uses and cleans up. replaceUsesAndCleanup(writer_op, reader_op, fused_task, intermediate); - LLVM_DEBUG(llvm::dbgs() << " Fusion succeeded: " << fused_name << "\n"); + llvm::errs() << " Fusion succeeded: " << fused_name << "\n"; return true; } private: - /// Builds merged operand lists for the fused task. - /// The intermediate is the original alloc value. We must use - /// original_read/write_memrefs to identify which operands to exclude - /// from the reader (since reader's read_memrefs contain SSA results). + // Builds merged operand lists for the fused task. + // The intermediate is the original alloc value. We must use + // original_read/write_memrefs to identify which operands to exclude + // from the reader (since reader's read_memrefs contain SSA results). void buildMergedOperands(taskflow::TaskflowTaskOp writer_op, taskflow::TaskflowTaskOp reader_op, Value intermediate, SmallVector &fused_reads, @@ -432,8 +487,9 @@ class TaskFuser { for (unsigned i = 0; i < writer_reads.size(); ++i) { Value orig = (i < writer_orig_reads.size()) ? 
writer_orig_reads[i] : writer_reads[i]; - if (orig != intermediate && seen.insert(writer_reads[i]).second) + if (orig != intermediate && seen.insert(writer_reads[i]).second) { fused_reads.push_back(writer_reads[i]); + } } auto reader_reads = reader_op.getReadMemrefs(); @@ -441,8 +497,9 @@ class TaskFuser { for (unsigned i = 0; i < reader_reads.size(); ++i) { Value orig = (i < reader_orig_reads.size()) ? reader_orig_reads[i] : reader_reads[i]; - if (orig != intermediate && seen.insert(reader_reads[i]).second) + if (orig != intermediate && seen.insert(reader_reads[i]).second) { fused_reads.push_back(reader_reads[i]); + } } // write_memrefs = reader.writes ∪ (writer.writes - intermediate) @@ -452,8 +509,9 @@ class TaskFuser { for (unsigned i = 0; i < reader_writes.size(); ++i) { Value orig = (i < reader_orig_writes.size()) ? reader_orig_writes[i] : reader_writes[i]; - if (orig != intermediate && seen.insert(reader_writes[i]).second) + if (orig != intermediate && seen.insert(reader_writes[i]).second) { fused_writes.push_back(reader_writes[i]); + } } auto writer_writes = writer_op.getWriteMemrefs(); @@ -461,40 +519,48 @@ class TaskFuser { for (unsigned i = 0; i < writer_writes.size(); ++i) { Value orig = (i < writer_orig_writes.size()) ? writer_orig_writes[i] : writer_writes[i]; - if (orig != intermediate && seen.insert(writer_writes[i]).second) + if (orig != intermediate && seen.insert(writer_writes[i]).second) { fused_writes.push_back(writer_writes[i]); + } } // value_inputs = writer.values ∪ reader.values - for (Value v : writer_op.getValueInputs()) + for (Value v : writer_op.getValueInputs()) { fused_values.push_back(v); - for (Value v : reader_op.getValueInputs()) + } + for (Value v : reader_op.getValueInputs()) { fused_values.push_back(v); + } - // original_read/write_memrefs: same merge rules (using originals directly). + // original_read/write_memrefs: same merge rules (using originals + // directly). 
seen.clear(); for (Value v : writer_op.getOriginalReadMemrefs()) { - if (v != intermediate && seen.insert(v).second) + if (v != intermediate && seen.insert(v).second) { fused_orig_reads.push_back(v); + } } for (Value v : reader_op.getOriginalReadMemrefs()) { - if (v != intermediate && seen.insert(v).second) + if (v != intermediate && seen.insert(v).second) { fused_orig_reads.push_back(v); + } } seen.clear(); for (Value v : reader_op.getOriginalWriteMemrefs()) { - if (v != intermediate && seen.insert(v).second) + if (v != intermediate && seen.insert(v).second) { fused_orig_writes.push_back(v); + } } for (Value v : writer_op.getOriginalWriteMemrefs()) { - if (v != intermediate && seen.insert(v).second) + if (v != intermediate && seen.insert(v).second) { fused_orig_writes.push_back(v); + } } } - /// Builds the fused task body by merging writer and reader loop nests. - /// Returns false if fusion fails (e.g., unexpected IR structure). + // Builds the fused task body by merging writer and reader loop nests. + // Returns false if fusion fails (e.g., unexpected IR structure). bool buildFusedBody(taskflow::TaskflowTaskOp fused_task, taskflow::TaskflowTaskOp writer_op, taskflow::TaskflowTaskOp reader_op, Value intermediate, @@ -506,12 +572,15 @@ class TaskFuser { fused_task.getBody().push_back(fused_block); // Block args: read_memrefs, write_memrefs, value_inputs. - for (Value v : fused_reads) + for (Value v : fused_reads) { fused_block->addArgument(v.getType(), v.getLoc()); - for (Value v : fused_writes) + } + for (Value v : fused_writes) { fused_block->addArgument(v.getType(), v.getLoc()); - for (Value v : fused_values) + } + for (Value v : fused_values) { fused_block->addArgument(v.getType(), v.getLoc()); + } // Builds a mapping from writer/reader block args to fused block args. IRMapping writer_mapping; @@ -542,8 +611,9 @@ class TaskFuser { break; } } - if (!writer_outer_loop) + if (!writer_outer_loop) { return false; + } // Finds the reader's outermost affine.for. 
affine::AffineForOp reader_outer_loop = nullptr; @@ -553,8 +623,9 @@ class TaskFuser { break; } } - if (!reader_outer_loop) + if (!reader_outer_loop) { return false; + } // Clones the writer's entire loop nest. Operation *cloned_writer = @@ -571,8 +642,9 @@ class TaskFuser { break; } } - if (!nested) + if (!nested) { break; + } innermost_writer = nested; } @@ -599,15 +671,15 @@ class TaskFuser { if (auto store_op = dyn_cast(op)) { // Checks if this store writes to the intermediate memref's // block arg (mapped from the writer's original arg). - // The stored-to memref is the writer's block arg for the intermediate. + // The stored-to memref is the writer's block arg for the + // intermediate. store_value = store_op.getValueToStore(); store_to_intermediate = &op; } } if (!store_value) { - LLVM_DEBUG(llvm::dbgs() - << " No store to intermediate found in writer\n"); + llvm::errs() << " No store to intermediate found in writer\n"; return false; } @@ -623,8 +695,9 @@ class TaskFuser { break; } } - if (!nested) + if (!nested) { break; + } reader_innermost = nested; } @@ -635,13 +708,15 @@ class TaskFuser { // Positions before the affine.yield terminator if it exists. if (!innermost_writer.getBody()->empty()) { Operation *terminator = innermost_writer.getBody()->getTerminator(); - if (terminator) + if (terminator) { inner_builder.setInsertionPoint(terminator); + } } for (Operation &op : reader_innermost.getBody()->getOperations()) { - if (op.hasTrait()) + if (op.hasTrait()) { continue; + } if (auto load_op = dyn_cast(op)) { // Checks if this load reads from the intermediate memref. @@ -682,15 +757,17 @@ class TaskFuser { } // Removes the store to intermediate (no longer needed). - if (store_to_intermediate) + if (store_to_intermediate) { store_to_intermediate->erase(); + } // Step 5: Creates the yield for the fused task. // Yields the reader's output memrefs. // Remove any existing terminator (if the block already has one). 
if (fused_block->mightHaveTerminator()) { - if (auto *yield_point = fused_block->getTerminator()) + if (auto *yield_point = fused_block->getTerminator()) { yield_point->erase(); + } } OpBuilder yield_builder(fused_block, fused_block->end()); @@ -703,16 +780,18 @@ class TaskFuser { // Maps reader yield's memory results to fused block args. for (Value v : reader_yield.getMemoryResults()) { - if (reader_mapping.contains(v)) + if (reader_mapping.contains(v)) { yield_writes.push_back(reader_mapping.lookup(v)); - else + } else { yield_writes.push_back(v); + } } for (Value v : reader_yield.getValueResults()) { - if (reader_mapping.contains(v)) + if (reader_mapping.contains(v)) { yield_values.push_back(reader_mapping.lookup(v)); - else + } else { yield_values.push_back(v); + } } yield_builder.create(reader_op.getLoc(), @@ -721,7 +800,7 @@ class TaskFuser { return true; } - /// Maps a task's block args to the corresponding fused block args. + // Maps a task's block args to the corresponding fused block args. void mapBlockArgs(taskflow::TaskflowTaskOp task_op, Block &original_body, Block *fused_block, unsigned &fused_arg_idx, IRMapping &mapping, Value intermediate, @@ -789,31 +868,34 @@ class TaskFuser { } } - /// Finds the index of an outer value in the fused block's argument list. - /// Returns -1 if not found. + // Finds the index of an outer value in the fused block's argument list. + // Returns -1 if not found. int findInFusedArgs(Value outer_val, ArrayRef fused_reads, ArrayRef fused_writes, ArrayRef fused_values) { unsigned idx = 0; for (Value v : fused_reads) { - if (v == outer_val) + if (v == outer_val) { return idx; + } idx++; } for (Value v : fused_writes) { - if (v == outer_val) + if (v == outer_val) { return idx; + } idx++; } for (Value v : fused_values) { - if (v == outer_val) + if (v == outer_val) { return idx; + } idx++; } return -1; } - /// Gets the chain of nested affine.for ops starting from the outermost. 
+ // Gets the chain of nested affine.for ops starting from the outermost. SmallVector getLoopChain(affine::AffineForOp outermost) { SmallVector chain; auto current = outermost; @@ -831,8 +913,8 @@ class TaskFuser { return chain; } - /// Replaces uses of original tasks' results with fused task results - /// and erases original ops. + // Replaces uses of original tasks' results with fused task results + // and erases original ops. void replaceUsesAndCleanup(taskflow::TaskflowTaskOp writer_op, taskflow::TaskflowTaskOp reader_op, taskflow::TaskflowTaskOp fused_task, @@ -862,8 +944,9 @@ class TaskFuser { // Erases the intermediate memref allocation if it's now dead. if (auto alloc_op = intermediate.getDefiningOp()) { - if (alloc_op.getResult().use_empty()) + if (alloc_op.getResult().use_empty()) { alloc_op.erase(); + } } } @@ -875,18 +958,28 @@ class TaskFuser { //===----------------------------------------------------------------------===// struct MemoryAccessStreamingFusionPass - : public impl::MemoryAccessStreamingFusionBase< - MemoryAccessStreamingFusionPass> { + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemoryAccessStreamingFusionPass) + + StringRef getArgument() const final { + return "memory-access-streaming-fusion"; + } + + StringRef getDescription() const final { + return "Memory Access Streaming Fusion"; + } void runOnOperation() override { func::FuncOp func = getOperation(); - LLVM_DEBUG(llvm::dbgs() - << "Running MemoryAccessStreamingFusion on function: " - << func.getName() << "\n"); + llvm::errs() << "[MemoryAccessStreamingFusion] Running " + "MemoryAccessStreamingFusion on function: " + << func.getName() << "\n"; // Iterative fusion: re-analyze after each round to catch chains. - // e.g., A→B→C: first round fuses A+B, second round fuses (A+B)+C. + // e.g., Task A → Task B → Task C: first round fuses A+B, second round + // fuses (A+B)+C. 
unsigned total_fusions = 0; constexpr unsigned kMaxIterations = 100; @@ -896,18 +989,18 @@ struct MemoryAccessStreamingFusionPass MemoryDependencyAnalysis analysis(func); analysis.analyze(task_map); - LLVM_DEBUG(llvm::dbgs() << "Iteration " << iter << ": Found " - << task_map.size() << " tasks\n"); + llvm::errs() << "Iteration " << iter << ": Found " << task_map.size() + << " tasks\n"; // Identifies fusion candidates. FusionCandidateIdentifier identifier(task_map); auto candidates = identifier.identify(); - LLVM_DEBUG(llvm::dbgs() - << "Found " << candidates.size() << " fusion candidates\n"); + llvm::errs() << "Found " << candidates.size() << " fusion candidates\n"; - if (candidates.empty()) + if (candidates.empty()) { break; + } // Performs greedy fusion for this round. DenseSet fused_tasks; @@ -920,11 +1013,12 @@ struct MemoryAccessStreamingFusionPass // Skips if either task was already consumed by a previous fusion // in this round. - if (fused_tasks.count(writer_op) || fused_tasks.count(reader_op)) + if (fused_tasks.count(writer_op) || fused_tasks.count(reader_op)) { continue; + } - LLVM_DEBUG(llvm::dbgs() << "Attempting to fuse tasks (benefit: " - << candidate.fusion_benefit << ")\n"); + llvm::errs() << "Attempting to fuse tasks (benefit: " + << candidate.fusion_benefit << ")\n"; if (fuser.performFusion(candidate)) { fused_tasks.insert(writer_op); @@ -933,17 +1027,18 @@ struct MemoryAccessStreamingFusionPass } } - LLVM_DEBUG(llvm::dbgs() << "Round " << iter << ": fused " << round_fusions - << " task pairs\n"); + llvm::errs() << "Round " << iter << ": fused " << round_fusions + << " task pairs\n"; total_fusions += round_fusions; // If no fusions happened this round, we've converged. 
- if (round_fusions == 0) + if (round_fusions == 0) { break; + } } - LLVM_DEBUG(llvm::dbgs() << "Total fusions: " << total_fusions << "\n"); + llvm::errs() << "Total fusions: " << total_fusions << "\n"; } }; @@ -953,9 +1048,6 @@ struct MemoryAccessStreamingFusionPass // Pass Registration //===----------------------------------------------------------------------===// -std::unique_ptr createMemoryAccessStreamingFusionPass() { +std::unique_ptr mlir::taskflow::createMemoryAccessStreamingFusionPass() { return std::make_unique(); } - -} // namespace taskflow -} // namespace mlir From e32bc6d8f17b9642eb62fe4b5a92069b222ba067 Mon Sep 17 00:00:00 2001 From: ShangkunLI Date: Wed, 11 Feb 2026 23:54:33 +0800 Subject: [PATCH 3/3] add test --- .../MemoryAccessStreamingFusion.cpp | 17 +- .../taskflow/multi-nested/multi-nested.mlir | 65 ++++++ .../taskflow/resnet/simple_resnet_tosa.mlir | 189 ++++++++++++++++++ 3 files changed, 264 insertions(+), 7 deletions(-) diff --git a/lib/TaskflowDialect/Transforms/Optimizations/MemoryAccessStreamingFusion.cpp b/lib/TaskflowDialect/Transforms/Optimizations/MemoryAccessStreamingFusion.cpp index f05f68f9..f99747fc 100644 --- a/lib/TaskflowDialect/Transforms/Optimizations/MemoryAccessStreamingFusion.cpp +++ b/lib/TaskflowDialect/Transforms/Optimizations/MemoryAccessStreamingFusion.cpp @@ -989,14 +989,15 @@ struct MemoryAccessStreamingFusionPass MemoryDependencyAnalysis analysis(func); analysis.analyze(task_map); - llvm::errs() << "Iteration " << iter << ": Found " << task_map.size() - << " tasks\n"; + llvm::errs() << "[MemoryAccessStreamingFusion] Iteration " << iter + << ": Found " << task_map.size() << " tasks\n"; // Identifies fusion candidates. 
FusionCandidateIdentifier identifier(task_map); auto candidates = identifier.identify(); - llvm::errs() << "Found " << candidates.size() << " fusion candidates\n"; + llvm::errs() << "[MemoryAccessStreamingFusion] Found " + << candidates.size() << " fusion candidates\n"; if (candidates.empty()) { break; @@ -1017,7 +1018,8 @@ struct MemoryAccessStreamingFusionPass continue; } - llvm::errs() << "Attempting to fuse tasks (benefit: " + llvm::errs() << "[MemoryAccessStreamingFusion] Attempting to fuse " + "tasks (benefit: " << candidate.fusion_benefit << ")\n"; if (fuser.performFusion(candidate)) { @@ -1027,8 +1029,8 @@ struct MemoryAccessStreamingFusionPass } } - llvm::errs() << "Round " << iter << ": fused " << round_fusions - << " task pairs\n"; + llvm::errs() << "[MemoryAccessStreamingFusion] Round " << iter + << ": fused " << round_fusions << " task pairs\n"; total_fusions += round_fusions; @@ -1038,7 +1040,8 @@ struct MemoryAccessStreamingFusionPass } } - llvm::errs() << "Total fusions: " << total_fusions << "\n"; + llvm::errs() << "[MemoryAccessStreamingFusion] Total fusions: " + << total_fusions << "\n"; } }; diff --git a/test/multi-cgra/taskflow/multi-nested/multi-nested.mlir b/test/multi-cgra/taskflow/multi-nested/multi-nested.mlir index e6376f44..bcdbbe86 100644 --- a/test/multi-cgra/taskflow/multi-nested/multi-nested.mlir +++ b/test/multi-cgra/taskflow/multi-nested/multi-nested.mlir @@ -12,6 +12,12 @@ // RUN: -o %t.taskflow.mlir // RUN: FileCheck %s --input-file=%t.taskflow.mlir --check-prefixes=TASKFLOW +// RUN: mlir-neura-opt %s --affine-loop-tree-serialization \ +// RUN: --convert-affine-to-taskflow \ +// RUN: --memory-access-streaming-fusion \ +// RUN: -o %t.stream.mlir +// RUN: FileCheck %s --input-file=%t.stream.mlir --check-prefixes=STREAM + // RUN: mlir-neura-opt %s --affine-loop-tree-serialization \ // RUN: --convert-affine-to-taskflow \ // RUN: --construct-hyperblock-from-task \ @@ -245,6 +251,65 @@ module attributes {} { // TASKFLOW-NEXT: } // 
TASKFLOW-NEXT: } +// STREAM: module { +// STREAM-NEXT: func.func @_Z21pureNestedLoopExamplePA8_A6_iPA8_A5_iS4_PA7_iPA9_iPiS9_S9_S9_S9_(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: memref, %arg6: memref, %arg7: memref, %arg8: memref, %arg9: memref) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// STREAM-NEXT: %write_outputs = taskflow.task @Task_1 read_memrefs(%arg1, %arg2 : memref, memref) write_memrefs(%arg6 : memref) [original_read_memrefs(%arg1, %arg2 : memref, memref), original_write_memrefs(%arg6 : memref)] : (memref, memref, memref) -> (memref) { +// STREAM-NEXT: ^bb0(%arg10: memref, %arg11: memref, %arg12: memref): +// STREAM-NEXT: affine.for %arg13 = 0 to 4 { +// STREAM-NEXT: affine.for %arg14 = 0 to 8 { +// STREAM-NEXT: affine.for %arg15 = 0 to 5 { +// STREAM-NEXT: %1 = affine.load %arg10[%arg13, %arg14, %arg15] : memref +// STREAM-NEXT: %2 = affine.load %arg11[%arg13, %arg14, %arg15] : memref +// STREAM-NEXT: %3 = arith.addi %1, %2 : i32 +// STREAM-NEXT: affine.store %3, %arg12[%arg15] : memref +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield writes(%arg12 : memref) +// STREAM-NEXT: } +// STREAM-NEXT: %write_outputs_0 = taskflow.task @Task_0_Task_2_fused read_memrefs(%arg0, %write_outputs, %arg9 : memref, memref, memref) write_memrefs(%arg9 : memref) [original_read_memrefs(%arg0, %arg6, %arg9 : memref, memref, memref), original_write_memrefs(%arg9 : memref)] : (memref, memref, memref, memref) -> (memref) { +// STREAM-NEXT: ^bb0(%arg10: memref, %arg11: memref, %arg12: memref, %arg13: memref): +// STREAM-NEXT: affine.for %arg14 = 0 to 4 { +// STREAM-NEXT: affine.for %arg15 = 0 to 8 { +// STREAM-NEXT: affine.for %arg16 = 0 to 6 { +// STREAM-NEXT: %1 = affine.load %arg10[%arg14, %arg15, %arg16] : memref +// STREAM-NEXT: %2 = affine.load %arg11[%arg16] : memref +// STREAM-NEXT: %3 = arith.addi %1, %2 : i32 +// STREAM-NEXT: %4 = affine.load %arg12[0] : memref +// STREAM-NEXT: %5 
= arith.addi %4, %3 : i32 +// STREAM-NEXT: affine.store %5, %arg12[0] : memref +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield writes(%arg12 : memref) +// STREAM-NEXT: } +// STREAM-NEXT: %write_outputs_1 = taskflow.task @Task_3 read_memrefs(%arg3 : memref) write_memrefs(%arg7 : memref) [original_read_memrefs(%arg3 : memref), original_write_memrefs(%arg7 : memref)] : (memref, memref) -> (memref) { +// STREAM-NEXT: ^bb0(%arg10: memref, %arg11: memref): +// STREAM-NEXT: affine.for %arg12 = 0 to 4 { +// STREAM-NEXT: affine.for %arg13 = 0 to 7 { +// STREAM-NEXT: %1 = affine.load %arg10[%arg12, %arg13] : memref +// STREAM-NEXT: affine.store %1, %arg11[%arg13] : memref +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield writes(%arg11 : memref) +// STREAM-NEXT: } +// STREAM-NEXT: %write_outputs_2 = taskflow.task @Task_4 read_memrefs(%arg4, %write_outputs_1 : memref, memref) write_memrefs(%arg8 : memref) [original_read_memrefs(%arg4, %arg7 : memref, memref), original_write_memrefs(%arg8 : memref)] : (memref, memref, memref) -> (memref) { +// STREAM-NEXT: ^bb0(%arg10: memref, %arg11: memref, %arg12: memref): +// STREAM-NEXT: affine.for %arg13 = 0 to 4 { +// STREAM-NEXT: affine.for %arg14 = 0 to 9 { +// STREAM-NEXT: %1 = affine.load %arg10[%arg13, %arg14] : memref +// STREAM-NEXT: %2 = affine.load %arg11[%arg14] : memref +// STREAM-NEXT: %3 = arith.addi %1, %2 : i32 +// STREAM-NEXT: affine.store %3, %arg12[%arg14] : memref +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield writes(%arg12 : memref) +// STREAM-NEXT: } +// STREAM-NEXT: %0 = affine.load %write_outputs_0[0] : memref +// STREAM-NEXT: return %0 : i32 +// STREAM-NEXT: } +// STREAM-NEXT: } + // KERNEL: module { // KERNEL-NEXT: func.func @_Z21pureNestedLoopExamplePA8_A6_iPA8_A5_iS4_PA7_iPA9_iPiS9_S9_S9_S9_(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref, %arg4: memref, %arg5: memref, %arg6: memref, %arg7: memref, %arg8: memref, 
%arg9: memref) -> i32 attributes {llvm.linkage = #llvm.linkage} { // KERNEL-NEXT: %write_outputs = taskflow.task @Task_0 read_memrefs(%arg0 : memref) write_memrefs(%arg5 : memref) [original_read_memrefs(%arg0 : memref), original_write_memrefs(%arg5 : memref)] : (memref, memref) -> (memref) { diff --git a/test/multi-cgra/taskflow/resnet/simple_resnet_tosa.mlir b/test/multi-cgra/taskflow/resnet/simple_resnet_tosa.mlir index abc993f5..83dcb02a 100644 --- a/test/multi-cgra/taskflow/resnet/simple_resnet_tosa.mlir +++ b/test/multi-cgra/taskflow/resnet/simple_resnet_tosa.mlir @@ -7,6 +7,14 @@ // RUN: -o %t.kernel.mlir // RUN: FileCheck %s --input-file=%t.kernel.mlir --check-prefixes=KERNEL +// RUN: mlir-neura-opt %t.affine.mlir \ +// RUN: --affine-loop-tree-serialization \ +// RUN: --affine-loop-perfection \ +// RUN: --convert-affine-to-taskflow \ +// RUN: --memory-access-streaming-fusion \ +// RUN: -o %t.stream.mlir +// RUN: FileCheck %s --input-file=%t.stream.mlir --check-prefixes=STREAM + module attributes {torch.debug_module_name = "SimpleResNetBlock"} { func.func @forward(%arg0: tensor<1x64x8x8xf32>) -> tensor<1x64x8x8xf32> { %0 = "tosa.const"() <{value = dense<"0x7BEEA13C"> : tensor<64x64x3x3xf32>}> : () -> tensor<64x64x3x3xf32> @@ -486,3 +494,184 @@ module attributes {torch.debug_module_name = "SimpleResNetBlock"} { // KERNEL-NEXT: } // KERNEL-NEXT: } +// STREAM: module attributes {torch.debug_module_name = "SimpleResNetBlock"} { +// STREAM-NEXT: memref.global "private" constant @__constant_64xf32 : memref<64xf32> = dense<0.000000e+00> {alignment = 64 : i64} +// STREAM-NEXT: memref.global "private" constant @__constant_64x3x3x64xf32_0 : memref<64x3x3x64xf32> = dense<-0.0151730878> {alignment = 64 : i64} +// STREAM-NEXT: memref.global "private" constant @__constant_64x3x3x64xf32 : memref<64x3x3x64xf32> = dense<0.0197670367> {alignment = 64 : i64} +// STREAM-NEXT: func.func @forward(%arg0: memref<1x64x8x8xf32>) -> memref<1x64x8x8xf32> { +// STREAM-NEXT: %cst = 
arith.constant 0.0197670367 : f32 +// STREAM-NEXT: %cst_0 = arith.constant -0.0151730878 : f32 +// STREAM-NEXT: %cst_1 = arith.constant 3.40282347E+38 : f32 +// STREAM-NEXT: %cst_2 = arith.constant 0.000000e+00 : f32 +// STREAM-NEXT: %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x8x8x64xf32> +// STREAM-NEXT: %write_outputs = taskflow.task @Task_0 read_memrefs(%arg0 : memref<1x64x8x8xf32>) write_memrefs(%alloc : memref<1x8x8x64xf32>) [original_read_memrefs(%arg0 : memref<1x64x8x8xf32>), original_write_memrefs(%alloc : memref<1x8x8x64xf32>)] : (memref<1x64x8x8xf32>, memref<1x8x8x64xf32>) -> (memref<1x8x8x64xf32>) { +// STREAM-NEXT: ^bb0(%arg1: memref<1x64x8x8xf32>, %arg2: memref<1x8x8x64xf32>): +// STREAM-NEXT: affine.for %arg3 = 0 to 1 { +// STREAM-NEXT: affine.for %arg4 = 0 to 8 { +// STREAM-NEXT: affine.for %arg5 = 0 to 8 { +// STREAM-NEXT: affine.for %arg6 = 0 to 64 { +// STREAM-NEXT: %0 = affine.load %arg1[%arg3, %arg6, %arg4, %arg5] : memref<1x64x8x8xf32> +// STREAM-NEXT: affine.store %0, %arg2[%arg3, %arg4, %arg5, %arg6] : memref<1x8x8x64xf32> +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield writes(%arg2 : memref<1x8x8x64xf32>) +// STREAM-NEXT: } +// STREAM-NEXT: %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<1x10x10x64xf32> +// STREAM-NEXT: %write_outputs_4 = taskflow.task @Task_1 write_memrefs(%alloc_3 : memref<1x10x10x64xf32>) value_inputs(%cst_2 : f32) [original_write_memrefs(%alloc_3 : memref<1x10x10x64xf32>)] : (memref<1x10x10x64xf32>, f32) -> (memref<1x10x10x64xf32>) { +// STREAM-NEXT: ^bb0(%arg1: memref<1x10x10x64xf32>, %arg2: f32): +// STREAM-NEXT: affine.for %arg3 = 0 to 1 { +// STREAM-NEXT: affine.for %arg4 = 0 to 10 { +// STREAM-NEXT: affine.for %arg5 = 0 to 10 { +// STREAM-NEXT: affine.for %arg6 = 0 to 64 { +// STREAM-NEXT: affine.store %arg2, %arg1[%arg3, %arg4, %arg5, %arg6] : memref<1x10x10x64xf32> +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// 
STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield writes(%arg1 : memref<1x10x10x64xf32>) +// STREAM-NEXT: } +// STREAM-NEXT: %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<1x8x8x64xf32> +// STREAM-NEXT: %write_outputs_6 = taskflow.task @Task_2 write_memrefs(%alloc_5 : memref<1x8x8x64xf32>) value_inputs(%cst_2 : f32) [original_write_memrefs(%alloc_5 : memref<1x8x8x64xf32>)] : (memref<1x8x8x64xf32>, f32) -> (memref<1x8x8x64xf32>) { +// STREAM-NEXT: ^bb0(%arg1: memref<1x8x8x64xf32>, %arg2: f32): +// STREAM-NEXT: affine.for %arg3 = 0 to 1 { +// STREAM-NEXT: affine.for %arg4 = 0 to 8 { +// STREAM-NEXT: affine.for %arg5 = 0 to 8 { +// STREAM-NEXT: affine.for %arg6 = 0 to 64 { +// STREAM-NEXT: affine.store %arg2, %arg1[%arg3, %arg4, %arg5, %arg6] : memref<1x8x8x64xf32> +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield writes(%arg1 : memref<1x8x8x64xf32>) +// STREAM-NEXT: } +// STREAM-NEXT: %write_outputs_7 = taskflow.task @Task_3 read_memrefs(%write_outputs_4, %write_outputs_6 : memref<1x10x10x64xf32>, memref<1x8x8x64xf32>) write_memrefs(%write_outputs_6 : memref<1x8x8x64xf32>) value_inputs(%cst_0 : f32) [original_read_memrefs(%alloc_3, %alloc_5 : memref<1x10x10x64xf32>, memref<1x8x8x64xf32>), original_write_memrefs(%alloc_5 : memref<1x8x8x64xf32>)] : (memref<1x10x10x64xf32>, memref<1x8x8x64xf32>, memref<1x8x8x64xf32>, f32) -> (memref<1x8x8x64xf32>) { +// STREAM-NEXT: ^bb0(%arg1: memref<1x10x10x64xf32>, %arg2: memref<1x8x8x64xf32>, %arg3: memref<1x8x8x64xf32>, %arg4: f32): +// STREAM-NEXT: affine.for %arg5 = 0 to 1 { +// STREAM-NEXT: affine.for %arg6 = 0 to 8 { +// STREAM-NEXT: affine.for %arg7 = 0 to 8 { +// STREAM-NEXT: affine.for %arg8 = 0 to 64 { +// STREAM-NEXT: affine.for %arg9 = 0 to 3 { +// STREAM-NEXT: affine.for %arg10 = 0 to 3 { +// STREAM-NEXT: affine.for %arg11 = 0 to 64 { +// STREAM-NEXT: %0 = affine.load %arg1[%arg5, %arg6 + %arg9, %arg7 + %arg10, %arg11] : memref<1x10x10x64xf32> +// 
STREAM-NEXT: %1 = affine.load %arg3[%arg5, %arg6, %arg7, %arg8] : memref<1x8x8x64xf32> +// STREAM-NEXT: %2 = arith.mulf %0, %arg4 : f32 +// STREAM-NEXT: %3 = arith.addf %1, %2 : f32 +// STREAM-NEXT: affine.store %3, %arg3[%arg5, %arg6, %arg7, %arg8] : memref<1x8x8x64xf32> +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield writes(%arg3 : memref<1x8x8x64xf32>) +// STREAM-NEXT: } +// STREAM-NEXT: %alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<1x64x8x8xf32> +// STREAM-NEXT: %write_outputs_9 = taskflow.task @Task_4_Task_5_fused read_memrefs(%write_outputs_7 : memref<1x8x8x64xf32>) write_memrefs(%alloc_8 : memref<1x64x8x8xf32>) value_inputs(%cst_1, %cst_2 : f32, f32) [original_read_memrefs(%alloc_5 : memref<1x8x8x64xf32>), original_write_memrefs(%alloc_8 : memref<1x64x8x8xf32>)] : (memref<1x8x8x64xf32>, memref<1x64x8x8xf32>, f32, f32) -> (memref<1x64x8x8xf32>) { +// STREAM-NEXT: ^bb0(%arg1: memref<1x8x8x64xf32>, %arg2: memref<1x64x8x8xf32>, %arg3: f32, %arg4: f32): +// STREAM-NEXT: affine.for %arg5 = 0 to 1 { +// STREAM-NEXT: affine.for %arg6 = 0 to 64 { +// STREAM-NEXT: affine.for %arg7 = 0 to 8 { +// STREAM-NEXT: affine.for %arg8 = 0 to 8 { +// STREAM-NEXT: %0 = affine.load %arg1[%arg5, %arg7, %arg8, %arg6] : memref<1x8x8x64xf32> +// STREAM-NEXT: %1 = arith.minimumf %0, %arg3 : f32 +// STREAM-NEXT: %2 = arith.maximumf %1, %arg4 : f32 +// STREAM-NEXT: affine.store %2, %arg2[%arg5, %arg6, %arg7, %arg8] : memref<1x64x8x8xf32> +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield writes(%arg2 : memref<1x64x8x8xf32>) +// STREAM-NEXT: } +// STREAM-NEXT: %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<1x8x8x64xf32> +// STREAM-NEXT: %write_outputs_11 = taskflow.task @Task_6 read_memrefs(%write_outputs_9 : memref<1x64x8x8xf32>) write_memrefs(%alloc_10 : memref<1x8x8x64xf32>) 
[original_read_memrefs(%alloc_8 : memref<1x64x8x8xf32>), original_write_memrefs(%alloc_10 : memref<1x8x8x64xf32>)] : (memref<1x64x8x8xf32>, memref<1x8x8x64xf32>) -> (memref<1x8x8x64xf32>) { +// STREAM-NEXT: ^bb0(%arg1: memref<1x64x8x8xf32>, %arg2: memref<1x8x8x64xf32>): +// STREAM-NEXT: affine.for %arg3 = 0 to 1 { +// STREAM-NEXT: affine.for %arg4 = 0 to 8 { +// STREAM-NEXT: affine.for %arg5 = 0 to 8 { +// STREAM-NEXT: affine.for %arg6 = 0 to 64 { +// STREAM-NEXT: %0 = affine.load %arg1[%arg3, %arg6, %arg4, %arg5] : memref<1x64x8x8xf32> +// STREAM-NEXT: affine.store %0, %arg2[%arg3, %arg4, %arg5, %arg6] : memref<1x8x8x64xf32> +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield writes(%arg2 : memref<1x8x8x64xf32>) +// STREAM-NEXT: } +// STREAM-NEXT: %alloc_12 = memref.alloc() {alignment = 64 : i64} : memref<1x10x10x64xf32> +// STREAM-NEXT: %write_outputs_13 = taskflow.task @Task_7 write_memrefs(%alloc_12 : memref<1x10x10x64xf32>) value_inputs(%cst_2 : f32) [original_write_memrefs(%alloc_12 : memref<1x10x10x64xf32>)] : (memref<1x10x10x64xf32>, f32) -> (memref<1x10x10x64xf32>) { +// STREAM-NEXT: ^bb0(%arg1: memref<1x10x10x64xf32>, %arg2: f32): +// STREAM-NEXT: affine.for %arg3 = 0 to 1 { +// STREAM-NEXT: affine.for %arg4 = 0 to 10 { +// STREAM-NEXT: affine.for %arg5 = 0 to 10 { +// STREAM-NEXT: affine.for %arg6 = 0 to 64 { +// STREAM-NEXT: affine.store %arg2, %arg1[%arg3, %arg4, %arg5, %arg6] : memref<1x10x10x64xf32> +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield writes(%arg1 : memref<1x10x10x64xf32>) +// STREAM-NEXT: } +// STREAM-NEXT: %alloc_14 = memref.alloc() {alignment = 64 : i64} : memref<1x8x8x64xf32> +// STREAM-NEXT: %write_outputs_15 = taskflow.task @Task_8 write_memrefs(%alloc_14 : memref<1x8x8x64xf32>) value_inputs(%cst_2 : f32) [original_write_memrefs(%alloc_14 : memref<1x8x8x64xf32>)] : (memref<1x8x8x64xf32>, f32) -> (memref<1x8x8x64xf32>) 
{ +// STREAM-NEXT: ^bb0(%arg1: memref<1x8x8x64xf32>, %arg2: f32): +// STREAM-NEXT: affine.for %arg3 = 0 to 1 { +// STREAM-NEXT: affine.for %arg4 = 0 to 8 { +// STREAM-NEXT: affine.for %arg5 = 0 to 8 { +// STREAM-NEXT: affine.for %arg6 = 0 to 64 { +// STREAM-NEXT: affine.store %arg2, %arg1[%arg3, %arg4, %arg5, %arg6] : memref<1x8x8x64xf32> +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield writes(%arg1 : memref<1x8x8x64xf32>) +// STREAM-NEXT: } +// STREAM-NEXT: %write_outputs_16 = taskflow.task @Task_9 read_memrefs(%write_outputs_13, %write_outputs_15 : memref<1x10x10x64xf32>, memref<1x8x8x64xf32>) write_memrefs(%write_outputs_15 : memref<1x8x8x64xf32>) value_inputs(%cst : f32) [original_read_memrefs(%alloc_12, %alloc_14 : memref<1x10x10x64xf32>, memref<1x8x8x64xf32>), original_write_memrefs(%alloc_14 : memref<1x8x8x64xf32>)] : (memref<1x10x10x64xf32>, memref<1x8x8x64xf32>, memref<1x8x8x64xf32>, f32) -> (memref<1x8x8x64xf32>) { +// STREAM-NEXT: ^bb0(%arg1: memref<1x10x10x64xf32>, %arg2: memref<1x8x8x64xf32>, %arg3: memref<1x8x8x64xf32>, %arg4: f32): +// STREAM-NEXT: affine.for %arg5 = 0 to 1 { +// STREAM-NEXT: affine.for %arg6 = 0 to 8 { +// STREAM-NEXT: affine.for %arg7 = 0 to 8 { +// STREAM-NEXT: affine.for %arg8 = 0 to 64 { +// STREAM-NEXT: affine.for %arg9 = 0 to 3 { +// STREAM-NEXT: affine.for %arg10 = 0 to 3 { +// STREAM-NEXT: affine.for %arg11 = 0 to 64 { +// STREAM-NEXT: %0 = affine.load %arg1[%arg5, %arg6 + %arg9, %arg7 + %arg10, %arg11] : memref<1x10x10x64xf32> +// STREAM-NEXT: %1 = affine.load %arg3[%arg5, %arg6, %arg7, %arg8] : memref<1x8x8x64xf32> +// STREAM-NEXT: %2 = arith.mulf %0, %arg4 : f32 +// STREAM-NEXT: %3 = arith.addf %1, %2 : f32 +// STREAM-NEXT: affine.store %3, %arg3[%arg5, %arg6, %arg7, %arg8] : memref<1x8x8x64xf32> +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield 
writes(%arg3 : memref<1x8x8x64xf32>) +// STREAM-NEXT: } +// STREAM-NEXT: %alloc_17 = memref.alloc() {alignment = 64 : i64} : memref<1x64x8x8xf32> +// STREAM-NEXT: %write_outputs_18 = taskflow.task @Task_10_Task_11_Task_12_fused_fused read_memrefs(%write_outputs_16, %arg0 : memref<1x8x8x64xf32>, memref<1x64x8x8xf32>) write_memrefs(%alloc_17 : memref<1x64x8x8xf32>) value_inputs(%cst_1, %cst_2 : f32, f32) [original_read_memrefs(%alloc_14, %arg0 : memref<1x8x8x64xf32>, memref<1x64x8x8xf32>), original_write_memrefs(%alloc_17 : memref<1x64x8x8xf32>)] : (memref<1x8x8x64xf32>, memref<1x64x8x8xf32>, memref<1x64x8x8xf32>, f32, f32) -> (memref<1x64x8x8xf32>) { +// STREAM-NEXT: ^bb0(%arg1: memref<1x8x8x64xf32>, %arg2: memref<1x64x8x8xf32>, %arg3: memref<1x64x8x8xf32>, %arg4: f32, %arg5: f32): +// STREAM-NEXT: affine.for %arg6 = 0 to 1 { +// STREAM-NEXT: affine.for %arg7 = 0 to 64 { +// STREAM-NEXT: affine.for %arg8 = 0 to 8 { +// STREAM-NEXT: affine.for %arg9 = 0 to 8 { +// STREAM-NEXT: %0 = affine.load %arg1[%arg6, %arg8, %arg9, %arg7] : memref<1x8x8x64xf32> +// STREAM-NEXT: %1 = affine.load %arg2[%arg6, %arg7, %arg8, %arg9] : memref<1x64x8x8xf32> +// STREAM-NEXT: %2 = arith.addf %0, %1 : f32 +// STREAM-NEXT: %3 = arith.minimumf %2, %arg4 : f32 +// STREAM-NEXT: %4 = arith.maximumf %3, %arg5 : f32 +// STREAM-NEXT: affine.store %4, %arg3[%arg6, %arg7, %arg8, %arg9] : memref<1x64x8x8xf32> +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: } +// STREAM-NEXT: taskflow.yield writes(%arg3 : memref<1x64x8x8xf32>) +// STREAM-NEXT: } +// STREAM-NEXT: return %write_outputs_18 : memref<1x64x8x8xf32> +// STREAM-NEXT: } +// STREAM-NEXT: } +