Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/NeuraDialect/NeuraOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,11 @@ def Neura_PhiOp : Op<NeuraDialect, "phi"> {
neura.ctrl_mov %next to %v // Connect next iteration
}];

let arguments = (ins AnyType:$init_val, AnyType:$loop_val);
let arguments = (ins Variadic<AnyType>:$inputs);
let results = (outs AnyType:$result);

// Explicitly specify types for operands in the assembly format
let assemblyFormat = "$init_val `:` type($init_val) `,` $loop_val `:` type($loop_val) attr-dict `:` type($result)";
// let assemblyFormat = "$init_val `:` type($init_val) `,` $loop_val `:` type($loop_val) attr-dict `,` type($result)";
}

// Control movement extending base move but with different signature.
Expand Down
36 changes: 24 additions & 12 deletions lib/NeuraDialect/Transforms/LeveragePredicatedValuePass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,25 @@ struct applyPredicatedDataType : public RewritePattern {
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
llvm::errs() << "Processing op: " << *op << "\n";

// Skips if not a Neura op or already using predicated values.
if (op->getDialect()->getNamespace() != "neura") {
llvm::errs() << "Skipping non-Neura op\n";
return failure();
}

if (llvm::any_of(op->getResultTypes(),
[](Type t) { return mlir::isa<mlir::neura::PredicatedValue>(t); })) {
llvm::errs() << "Skipping already predicated op\n";
return failure();
}

// Converts result types to predicated form.
SmallVector<Type> newResults;
for (Type t : op->getResultTypes()) {
auto predicatedTy = mlir::neura::PredicatedValue::get(
auto predicated_type = mlir::neura::PredicatedValue::get(
op->getContext(),
t,
rewriter.getI1Type());
newResults.push_back(predicatedTy);
newResults.push_back(predicated_type);
}

// Clones the operation with new result types.
Expand All @@ -51,7 +48,6 @@ struct applyPredicatedDataType : public RewritePattern {

// Replaces the old op with the new one.
rewriter.replaceOp(op, newOp->getResults());
llvm::errs() << "Converted op to predicated form: " << *newOp << "\n";
if (!newResults.empty()) {
assert(false);
}
Expand All @@ -77,7 +73,26 @@ struct LeveragePredicatedValuePass

// Processes each function.
module.walk([&](func::FuncOp func) {
// Get operations in topological order (operands before users)
// Converts block argument types to predicated values.
func.walk([&](Block *block) {
// skips the entry (first) block of the function.
if (block == &block->getParent()->front()) {
return;
}
for (BlockArgument arg : block->getArguments()) {
Type origType = arg.getType();

// Avoid double-wrapping if already predicated
if (llvm::isa<neura::PredicatedValue>(origType))
continue;

auto predicated_type = neura::PredicatedValue::get(
func.getContext(), origType, IntegerType::get(func.getContext(), 1));
arg.setType(predicated_type);
}
});

// Gets operations in topological order (operands before users).
SmallVector<Operation*> orderedOps;
getOperationsInTopologicalOrder(func, orderedOps);

Expand Down Expand Up @@ -122,11 +137,8 @@ struct LeveragePredicatedValuePass

// Converts a single operation to use predicated values.
LogicalResult applyPredicatedDataType(Operation *op) {
llvm::errs() << "Processing op: " << *op << "\n";

// Skips if not a Neura op.
if (op->getDialect()->getNamespace() != "neura") {
llvm::errs() << "Skipping non-Neura op\n";
return success();
}

Expand All @@ -141,11 +153,11 @@ struct LeveragePredicatedValuePass
OpBuilder builder(op);
SmallVector<Type> newResults;
for (Type t : op->getResultTypes()) {
auto predicatedTy = mlir::neura::PredicatedValue::get(
auto predicated_type = mlir::neura::PredicatedValue::get(
op->getContext(),
t,
builder.getI1Type());
newResults.push_back(predicatedTy);
newResults.push_back(predicated_type);
}

// Clones with new result types.
Expand Down
Loading