Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 43 additions & 34 deletions lib/NeuraDialect/Transforms/TransformCtrlToDataFlowPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ void createPhiNodesForBlock(Block *block, OpBuilder &builder,
break;
}
}
assert(found_in_block_argument && "Live-in value defined outside the block must be passed as a block argument");
live_ins.push_back(operand);
}

Expand All @@ -75,49 +74,59 @@ void createPhiNodesForBlock(Block *block, OpBuilder &builder,
block->front().getLoc();

SmallVector<Value> phi_operands;

// Finds index of live_in in block arguments.
auto arg = dyn_cast<BlockArgument>(live_in);
assert(arg && "Expected live_in to be a block argument");
unsigned arg_index = arg.getArgNumber();

for (Block *pred : block->getPredecessors()) {
Value incoming;
Operation *term = pred->getTerminator();

// If it's a branch or cond_br, get the value passed into this block argument
if (auto br = dyn_cast<neura::Br>(term)) {
auto args = br.getArgs();
assert(arg_index < args.size());
incoming = args[arg_index];
} else if (auto condBr = dyn_cast<neura::CondBr>(term)) {
if (condBr.getTrueDest() == block) {
auto trueArgs = condBr.getTrueArgs();
assert(arg_index < trueArgs.size());
incoming = trueArgs[arg_index];
} else if (condBr.getFalseDest() == block) {
auto falseArgs = condBr.getFalseArgs();
assert(arg_index < falseArgs.size());
incoming = falseArgs[arg_index];
BlockArgument arg = dyn_cast<BlockArgument>(live_in);
// Handles the case where live_in is not a block argument.
if (!arg) {
phi_operands.push_back(live_in);
} else {
// Finds index of live_in in block arguments.
unsigned arg_index = arg.getArgNumber();
for (Block *pred : block->getPredecessors()) {
Value incoming;
Operation *term = pred->getTerminator();

// If it's a branch or cond_br, get the value passed into this block argument
if (auto br = dyn_cast<neura::Br>(term)) {
auto args = br.getArgs();
assert(arg_index < args.size());
incoming = args[arg_index];
} else if (auto condBr = dyn_cast<neura::CondBr>(term)) {
if (condBr.getTrueDest() == block) {
auto trueArgs = condBr.getTrueArgs();
assert(arg_index < trueArgs.size());
incoming = trueArgs[arg_index];
} else if (condBr.getFalseDest() == block) {
auto falseArgs = condBr.getFalseArgs();
assert(arg_index < falseArgs.size());
incoming = falseArgs[arg_index];
} else {
llvm::errs() << "cond_br does not target block:\n" << *block << "\n";
continue;
}
} else {
llvm::errs() << "cond_br does not target block:\n" << *block << "\n";
llvm::errs() << "Unknown branch terminator in block: " << *pred << "\n";
continue;
}
} else {
llvm::errs() << "Unknown branch terminator in block: " << *pred << "\n";
continue;
phi_operands.push_back(incoming);
}
phi_operands.push_back(incoming);
}

// Use default value if no incoming values found
assert(!phi_operands.empty());

// Create the phi node with dynamic number of operands
// Creates the phi node with dynamic number of operands.
auto phi_op = builder.create<neura::PhiOp>(loc, predicated_type, phi_operands);

// Replace block argument use with the phi result
arg.replaceAllUsesWith(phi_op.getResult());
// Saves users to be replaced *after* phi is constructed.
SmallVector<OpOperand *> uses_to_be_replaced;
for (OpOperand &use : live_in.getUses()) {
if (use.getOwner() != phi_op) {
uses_to_be_replaced.push_back(&use);
}
}
// Replaces live-in uses with the phi result.
for (OpOperand *use : uses_to_be_replaced) {
use->set(phi_op.getResult());
}
value_map[live_in] = phi_op.getResult();
}
}
Expand Down
68 changes: 68 additions & 0 deletions test/neura/ctrl/branch_no_arg.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// RUN: mlir-neura-opt %s \
// RUN: --assign-accelerator \
// RUN: --lower-llvm-to-neura \
// RUN: --leverage-predicated-value \
// RUN: | FileCheck %s

// RUN: mlir-neura-opt %s \
// RUN: --assign-accelerator \
// RUN: --lower-llvm-to-neura \
// RUN: --leverage-predicated-value \
// RUN: --transform-ctrl-to-data-flow \
// RUN: | FileCheck %s -check-prefix=CTRL2DATA

func.func @test(%in: i64) -> f32 {
%c0 = llvm.mlir.constant(0 : i64) : i64
%c1 = llvm.mlir.constant(1.0 : f32) : f32
%c2 = llvm.mlir.constant(2.0 : f32) : f32
%c3 = llvm.mlir.constant(3.0 : f32) : f32
%c4 = llvm.mlir.constant(4.0 : f32) : f32
%cond = llvm.icmp "eq" %in, %c0 : i64
llvm.cond_br %cond, ^bb2(%c3, %c4 : f32, f32), ^bb1

^bb1:
%a = llvm.fadd %c1, %c2 : f32
llvm.br ^bb3(%a : f32)

^bb2(%cc: f32, %cd: f32):
%b = llvm.fmul %cc, %cd : f32
llvm.br ^bb3(%b : f32)

^bb3(%v: f32):
return %v : f32
}

// CHECK: func.func @test(%arg0: i64) -> f32 attributes {accelerator = "neura"} {
// CHECK-NEXT: %0 = "neura.constant"() <{predicate = true, value = 0 : i64}> : () -> !neura.data<i64, i1>
// CHECK-NEXT: %1 = "neura.constant"() <{predicate = true, value = 1.000000e+00 : f32}> : () -> !neura.data<f32, i1>
// CHECK-NEXT: %2 = "neura.constant"() <{predicate = true, value = 2.000000e+00 : f32}> : () -> !neura.data<f32, i1>
// CHECK-NEXT: %3 = "neura.constant"() <{predicate = true, value = 3.000000e+00 : f32}> : () -> !neura.data<f32, i1>
// CHECK-NEXT: %4 = "neura.constant"() <{predicate = true, value = 4.000000e+00 : f32}> : () -> !neura.data<f32, i1>
// CHECK-NEXT: %5 = "neura.icmp"(%arg0, %0) <{cmpType = "eq"}> : (i64, !neura.data<i64, i1>) -> !neura.data<i1, i1>
// CHECK-NEXT: neura.cond_br %5 : !neura.data<i1, i1> then %3, %4 : !neura.data<f32, i1>, !neura.data<f32, i1> to ^bb2 else : to ^bb1
// CHECK-NEXT: ^bb1: // pred: ^bb0
// CHECK-NEXT: %6 = "neura.fadd"(%1, %2) : (!neura.data<f32, i1>, !neura.data<f32, i1>) -> !neura.data<f32, i1>
// CHECK-NEXT: neura.br %6 : !neura.data<f32, i1> to ^bb3
// CHECK-NEXT: ^bb2(%7: !neura.data<f32, i1>, %8: !neura.data<f32, i1>): // pred: ^bb0
// CHECK-NEXT: %9 = "neura.fmul"(%7, %8) : (!neura.data<f32, i1>, !neura.data<f32, i1>) -> !neura.data<f32, i1>
// CHECK-NEXT: neura.br %9 : !neura.data<f32, i1> to ^bb3
// CHECK-NEXT: ^bb3(%10: !neura.data<f32, i1>): // 2 preds: ^bb1, ^bb2
// CHECK-NEXT: "neura.return"(%10) : (!neura.data<f32, i1>) -> ()
// CHECK-NEXT: }

// CTRL2DATA: func.func @test(%arg0: i64) -> f32 attributes {accelerator = "neura"} {
// CTRL2DATA-NEXT: %0 = "neura.constant"() <{predicate = true, value = 0 : i64}> : () -> !neura.data<i64, i1>
// CTRL2DATA-NEXT: %1 = "neura.constant"() <{predicate = true, value = 1.000000e+00 : f32}> : () -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %2 = "neura.constant"() <{predicate = true, value = 2.000000e+00 : f32}> : () -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %3 = "neura.constant"() <{predicate = true, value = 3.000000e+00 : f32}> : () -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %4 = "neura.constant"() <{predicate = true, value = 4.000000e+00 : f32}> : () -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %5 = "neura.icmp"(%arg0, %0) <{cmpType = "eq"}> : (i64, !neura.data<i64, i1>) -> !neura.data<i1, i1>
// CTRL2DATA-NEXT: %6 = "neura.phi"(%1) : (!neura.data<f32, i1>) -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %7 = "neura.phi"(%2) : (!neura.data<f32, i1>) -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %8 = "neura.fadd"(%6, %7) : (!neura.data<f32, i1>, !neura.data<f32, i1>) -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %9 = "neura.phi"(%3) : (!neura.data<f32, i1>) -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %10 = "neura.phi"(%4) : (!neura.data<f32, i1>) -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %11 = "neura.fmul"(%9, %10) : (!neura.data<f32, i1>, !neura.data<f32, i1>) -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: %12 = "neura.phi"(%8, %11) : (!neura.data<f32, i1>, !neura.data<f32, i1>) -> !neura.data<f32, i1>
// CTRL2DATA-NEXT: "neura.return"(%12) : (!neura.data<f32, i1>) -> ()
// CTRL2DATA-NEXT: }