From 37b396c52b0703d8b4ccc233461ea95973d2a96b Mon Sep 17 00:00:00 2001 From: tancheng Date: Sun, 15 Jun 2025 06:41:58 +0000 Subject: [PATCH 1/4] Handle block without argument but directly leveraging previous ops as live-ins --- .../TransformCtrlToDataFlowPass.cpp | 78 +++++++++++-------- 1 file changed, 44 insertions(+), 34 deletions(-) diff --git a/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp b/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp index f66ddaa1..b07f90af 100644 --- a/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp +++ b/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp @@ -49,7 +49,7 @@ void createPhiNodesForBlock(Block *block, OpBuilder &builder, break; } } - assert(found_in_block_argument && "Live-in value defined outside the block must be passed as a block argument"); + // assert(found_in_block_argument && "Live-in value defined outside the block must be passed as a block argument"); live_ins.push_back(operand); } @@ -75,49 +75,59 @@ void createPhiNodesForBlock(Block *block, OpBuilder &builder, block->front().getLoc(); SmallVector phi_operands; - - // Finds index of live_in in block arguments. - auto arg = dyn_cast(live_in); - assert(arg && "Expected live_in to be a block argument"); - unsigned arg_index = arg.getArgNumber(); - - for (Block *pred : block->getPredecessors()) { - Value incoming; - Operation *term = pred->getTerminator(); - - // If it's a branch or cond_br, get the value passed into this block argument - if (auto br = dyn_cast(term)) { - auto args = br.getArgs(); - assert(arg_index < args.size()); - incoming = args[arg_index]; - } else if (auto condBr = dyn_cast(term)) { - if (condBr.getTrueDest() == block) { - auto trueArgs = condBr.getTrueArgs(); - assert(arg_index < trueArgs.size()); - incoming = trueArgs[arg_index]; - } else if (condBr.getFalseDest() == block) { - auto falseArgs = condBr.getFalseArgs(); - assert(arg_index < falseArgs.size()); - incoming = falseArgs[arg_index]; + BlockArgument arg = dyn_cast(live_in); + // Handles the case where live_in is not a block argument. + if (!arg) { + phi_operands.push_back(live_in); + } else { + // Finds index of live_in in block arguments. + unsigned arg_index = arg.getArgNumber(); + for (Block *pred : block->getPredecessors()) { + Value incoming; + Operation *term = pred->getTerminator(); + + // If it's a branch or cond_br, get the value passed into this block argument + if (auto br = dyn_cast(term)) { + auto args = br.getArgs(); + assert(arg_index < args.size()); + incoming = args[arg_index]; + } else if (auto condBr = dyn_cast(term)) { + if (condBr.getTrueDest() == block) { + auto trueArgs = condBr.getTrueArgs(); + assert(arg_index < trueArgs.size()); + incoming = trueArgs[arg_index]; + } else if (condBr.getFalseDest() == block) { + auto falseArgs = condBr.getFalseArgs(); + assert(arg_index < falseArgs.size()); + incoming = falseArgs[arg_index]; + } else { + llvm::errs() << "cond_br does not target block:\n" << *block << "\n"; + continue; + } } else { - llvm::errs() << "cond_br does not target block:\n" << *block << "\n"; + llvm::errs() << "Unknown branch terminator in block: " << *pred << "\n"; continue; } - } else { - llvm::errs() << "Unknown branch terminator in block: " << *pred << "\n"; - continue; + phi_operands.push_back(incoming); } - phi_operands.push_back(incoming); } - // Use default value if no incoming values found assert(!phi_operands.empty()); - // Create the phi node with dynamic number of operands + // Creates the phi node with dynamic number of operands. auto phi_op = builder.create(loc, predicated_type, phi_operands); - // Replace block argument use with the phi result - arg.replaceAllUsesWith(phi_op.getResult()); + // Saves users to be replaced *after* phi is constructed. + SmallVector uses_to_be_replaced; + for (OpOperand &use : live_in.getUses()) { + if (use.getOwner() != phi_op) { + uses_to_be_replaced.push_back(&use); + } + } + // Replaces block argument use with the phi result. + for (OpOperand *use : uses_to_be_replaced) { + use->set(phi_op.getResult()); + } value_map[live_in] = phi_op.getResult(); } } From f58a132f3e41b62a90afcf987a5f9ce689c44c5e Mon Sep 17 00:00:00 2001 From: tancheng Date: Sun, 15 Jun 2025 06:44:07 +0000 Subject: [PATCH 2/4] [test] Include test --- test/neura/ctrl/branch_no_arg.mlir | 68 ++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 test/neura/ctrl/branch_no_arg.mlir diff --git a/test/neura/ctrl/branch_no_arg.mlir b/test/neura/ctrl/branch_no_arg.mlir new file mode 100644 index 00000000..7ed73175 --- /dev/null +++ b/test/neura/ctrl/branch_no_arg.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-neura-opt %s \ +// RUN: --assign-accelerator \ +// RUN: --lower-llvm-to-neura \ +// RUN: --leverage-predicated-value \ +// RUN: | FileCheck %s + +// RUN: mlir-neura-opt %s \ +// RUN: --assign-accelerator \ +// RUN: --lower-llvm-to-neura \ +// RUN: --leverage-predicated-value \ +// RUN: --transform-ctrl-to-data-flow \ +// RUN: | FileCheck %s -check-prefix=CTRL2DATA + +func.func @test(%in: i64) -> f32 { + %c0 = llvm.mlir.constant(0 : i64) : i64 + %c1 = llvm.mlir.constant(1.0 : f32) : f32 + %c2 = llvm.mlir.constant(2.0 : f32) : f32 + %c3 = llvm.mlir.constant(3.0 : f32) : f32 + %c4 = llvm.mlir.constant(4.0 : f32) : f32 + %cond = llvm.icmp "eq" %in, %c0 : i64 + llvm.cond_br %cond, ^bb2(%c3, %c4 : f32, f32), ^bb1 + +^bb1: + %a = llvm.fadd %c1, %c2 : f32 + llvm.br ^bb3(%a : f32) + +^bb2(%cc: f32, %cd: f32): + %b = llvm.fmul %cc, %cd : f32 + llvm.br ^bb3(%b : f32) + +^bb3(%v: f32): + return %v : f32 +} + +// CHECK: func.func @test(%arg0: i64) -> f32 attributes {accelerator = "neura"} { +// CHECK-NEXT: %0 = "neura.constant"() <{predicate = true, value = 0 : i64}> : () -> !neura.data +// CHECK-NEXT: %1 = "neura.constant"() <{predicate = true, value = 1.000000e+00 : f32}> : () -> !neura.data +// CHECK-NEXT: %2 = "neura.constant"() <{predicate = true, value = 2.000000e+00 : f32}> : () -> !neura.data +// CHECK-NEXT: %3 = "neura.constant"() <{predicate = true, value = 3.000000e+00 : f32}> : () -> !neura.data +// CHECK-NEXT: %4 = "neura.constant"() <{predicate = true, value = 4.000000e+00 : f32}> : () -> !neura.data +// CHECK-NEXT: %5 = "neura.icmp"(%arg0, %0) <{cmpType = "eq"}> : (i64, !neura.data) -> !neura.data +// CHECK-NEXT: neura.cond_br %5 : !neura.data then %3, %4 : !neura.data, !neura.data to ^bb2 else : to ^bb1 +// CHECK-NEXT: ^bb1: // pred: ^bb0 +// CHECK-NEXT: %6 = "neura.fadd"(%1, %2) : (!neura.data, !neura.data) -> !neura.data +// CHECK-NEXT: neura.br %6 : !neura.data to ^bb3 +// CHECK-NEXT: ^bb2(%7: !neura.data, %8: !neura.data): // pred: ^bb0 +// CHECK-NEXT: %9 = "neura.fmul"(%7, %8) : (!neura.data, !neura.data) -> !neura.data +// CHECK-NEXT: neura.br %9 : !neura.data to ^bb3 +// CHECK-NEXT: ^bb3(%10: !neura.data): // 2 preds: ^bb1, ^bb2 +// CHECK-NEXT: "neura.return"(%10) : (!neura.data) -> () +// CHECK-NEXT: } + +// CTRL2DATA: func.func @test(%arg0: i64) -> f32 attributes {accelerator = "neura"} { +// CTRL2DATA-NEXT: %0 = "neura.constant"() <{predicate = true, value = 0 : i64}> : () -> !neura.data +// CTRL2DATA-NEXT: %1 = "neura.constant"() <{predicate = true, value = 1.000000e+00 : f32}> : () -> !neura.data +// CTRL2DATA-NEXT: %2 = "neura.constant"() <{predicate = true, value = 2.000000e+00 : f32}> : () -> !neura.data +// CTRL2DATA-NEXT: %3 = "neura.constant"() <{predicate = true, value = 3.000000e+00 : f32}> : () -> !neura.data +// CTRL2DATA-NEXT: %4 = "neura.constant"() <{predicate = true, value = 4.000000e+00 : f32}> : () -> !neura.data +// CTRL2DATA-NEXT: %5 = "neura.icmp"(%arg0, %0) <{cmpType = "eq"}> : (i64, !neura.data) -> !neura.data +// CTRL2DATA-NEXT: %6 = "neura.phi"(%1) : (!neura.data) -> !neura.data +// CTRL2DATA-NEXT: %7 = "neura.phi"(%2) : (!neura.data) -> !neura.data +// CTRL2DATA-NEXT: %8 = "neura.fadd"(%6, %7) : (!neura.data, !neura.data) -> !neura.data +// CTRL2DATA-NEXT: %9 = "neura.phi"(%3) : (!neura.data) -> !neura.data +// CTRL2DATA-NEXT: %10 = "neura.phi"(%4) : (!neura.data) -> !neura.data +// CTRL2DATA-NEXT: %11 = "neura.fmul"(%9, %10) : (!neura.data, !neura.data) -> !neura.data +// CTRL2DATA-NEXT: %12 = "neura.phi"(%8, %11) : (!neura.data, !neura.data) -> !neura.data +// CTRL2DATA-NEXT: "neura.return"(%12) : (!neura.data) -> () +// CTRL2DATA-NEXT: } \ No newline at end of file From 6151ea7a54280210ef9d23b816bfc1c2c3521a71 Mon Sep 17 00:00:00 2001 From: tancheng Date: Sun, 15 Jun 2025 06:45:17 +0000 Subject: [PATCH 3/4] [cleanup] Format and comment --- lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp | 1 - test/neura/ctrl/branch_no_arg.mlir | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp b/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp index b07f90af..e6fc78c4 100644 --- a/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp +++ b/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp @@ -49,7 +49,6 @@ void createPhiNodesForBlock(Block *block, OpBuilder &builder, break; } } - // assert(found_in_block_argument && "Live-in value defined outside the block must be passed as a block argument"); live_ins.push_back(operand); } diff --git a/test/neura/ctrl/branch_no_arg.mlir b/test/neura/ctrl/branch_no_arg.mlir index 7ed73175..26165db5 100644 --- a/test/neura/ctrl/branch_no_arg.mlir +++ b/test/neura/ctrl/branch_no_arg.mlir @@ -65,4 +65,4 @@ func.func @test(%in: i64) -> f32 { // CTRL2DATA-NEXT: %11 = "neura.fmul"(%9, %10) : (!neura.data, !neura.data) -> !neura.data // CTRL2DATA-NEXT: %12 = "neura.phi"(%8, %11) : (!neura.data, !neura.data) -> !neura.data // CTRL2DATA-NEXT: "neura.return"(%12) : (!neura.data) -> () -// CTRL2DATA-NEXT: } \ No newline at end of file +// CTRL2DATA-NEXT: } From 2f9ac7e64fb82ff050072fc42920a501f460e375 Mon Sep 17 00:00:00 2001 From: tancheng Date: Sun, 15 Jun 2025 06:47:46 +0000 Subject: [PATCH 4/4] [cleanup] Update comments --- lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp b/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp index e6fc78c4..0db4c778 100644 --- a/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp +++ b/lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp @@ -123,7 +123,7 @@ void createPhiNodesForBlock(Block *block, OpBuilder &builder, uses_to_be_replaced.push_back(&use); } } - // Replaces block argument use with the phi result. + // Replaces live-in uses with the phi result. for (OpOperand *use : uses_to_be_replaced) { use->set(phi_op.getResult()); }