Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitignore

This file was deleted.

53 changes: 51 additions & 2 deletions include/NeuraDialect/NeuraOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,23 @@ def Neura_SubOp : Op<NeuraDialect, "sub"> {
let traits = [SameOperandsAndResultElementType];
}

// Defines an integer multiplication operation.
def Neura_MulOp : Op<NeuraDialect, "mul"> {
  let summary = "Integer multiplication operation";
  // NOTE: the op mnemonic is already set by the "mul" template argument
  // above, so an explicit `let opName = "mul";` (present in an earlier
  // revision) is redundant and omitted for consistency with sibling ops.
  // The trailing predicate operand is optional (may be absent).
  let arguments = (ins AnyType:$lhs, AnyType:$rhs, Optional<AnyType>:$predicate);
  let results = (outs AnyType:$result);
  // let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)";
  let traits = [SameOperandsAndResultElementType];
}

// Defines an integer division operation.
def Neura_DivOp : Op<NeuraDialect, "div"> {
  let summary = "Integer division operation";
  // The trailing predicate operand is optional (may be absent).
  let arguments = (ins AnyType:$lhs, AnyType:$rhs, Optional<AnyType>:$predicate);
  let results = (outs AnyType:$result);
  // let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)";
  let traits = [SameOperandsAndResultElementType];
}

// Defines a floating-point addition operation.
def Neura_FAddOp : Op<NeuraDialect, "fadd"> {
let summary = "Floating addition operation";
Expand Down Expand Up @@ -147,7 +164,7 @@ def Neura_StoreIndexedOp: Op<NeuraDialect, "store_indexed", [AttrSizedOperandSeg
// Defines a pointer computation operation.
def Neura_GEP : Op<NeuraDialect, "gep"> {
let summary = "Pointer computation using offset indices";
let arguments = (ins AnyType:$base, Variadic<AnyInteger>:$indicesAndPredicate);
let arguments = (ins AnyType:$base, Variadic<AnyType>:$indicesAndPredicate);
let results = (outs AnyType:$result);
// let assemblyFormat = "$base `[` $indicesAndPredicate `]` `,` $predicate attr-dict";
}
Expand All @@ -170,7 +187,7 @@ def Neura_Br : Op<NeuraDialect, "br", [Terminator]> {
}

def Neura_SelOp : Op<NeuraDialect, "sel"> {
let arguments = (ins AnyType:$ifTrue, AnyType:$ifFalse, I1:$cond);
let arguments = (ins AnyType:$ifTrue, AnyType:$ifFalse, AnyType:$cond);
let results = (outs AnyType:$result);
// let assemblyFormat = "$ifTrue `,` $ifFalse `,` $cond attr-dict `:` type($ifTrue)";
}
Expand Down Expand Up @@ -351,4 +368,36 @@ def Neura_GrantAlwaysOp : Op<NeuraDialect, "grant_always"> {
let results = (outs AnyType:$result);

// let assemblyFormat = "$value attr-dict `:` type($value) `->` type($result)";
}

// ----------------------------------------------------
// Defines fused control flow operations.

def Neura_LoopControllerOp : Op<NeuraDialect, "loop_controller"> {
  // Fixed typo in summary: "indicies" -> "indices".
  let summary = "Generates loop indices and valid predicates.";
  let description = [{
    Manages a single level of loop execution based on cycle counting.
    Each loop_controller outputs a current index value and a valid predicate.

    The loop_controller uses dynamic loop bounds (start, end, step),
    allowing for variable-length loops and runtime-determined bounds.

    The execution is conditioned on the parent_valid input, creating an
    efficient hierarchical structure for nested loops.
  }];

  let arguments = (ins
    AnyType:$parent_valid, // Valid predicate from the parent loop
    AnyType:$start,        // Start index of the loop
    AnyType:$end,          // End index of the loop
    AnyType:$step          // Step size for the loop
  );

  let results = (outs
    AnyType:$index,        // Current loop index
    AnyType:$valid         // Valid predicate for the current index
  );

  let assemblyFormat =
    "$parent_valid `(` $start `,` $end `,` $step `)` attr-dict `:` type($parent_valid) `,` type($start) `,` type($end) `,` type($step) `->` type($index) `,` type($valid)";
}
1 change: 1 addition & 0 deletions include/NeuraDialect/NeuraPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ std::unique_ptr<mlir::Pass> createTransformCtrlToDataFlowPass();
std::unique_ptr<mlir::Pass> createLeveragePredicatedValuePass();
std::unique_ptr<mlir::Pass> createMapToAcceleratorPass();
std::unique_ptr<mlir::Pass> createGenerateCodePass();
std::unique_ptr<mlir::Pass> createFuseControlFlowPass();

#define GEN_PASS_REGISTRATION
#include "NeuraDialect/NeuraPasses.h.inc"
Expand Down
8 changes: 8 additions & 0 deletions include/NeuraDialect/NeuraPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,12 @@ def GenerateCode : Pass<"generate-code", "ModuleOp"> {
let constructor = "neura::createGenerateCodePass()";
}

// Spacing normalized to match sibling defs (e.g. `def GenerateCode : Pass<...> {`).
def FuseControlFlow : Pass<"fuse-control-flow", "ModuleOp"> {
  let summary = "Fuses control flow operations in the Neura dialect";
  let description = [{
    This pass fuses control flow operations into dedicated fused Neura
    control operations (e.g., neura.loop_controller).
  }];
  let constructor = "neura::createFuseControlFlowPass()";
}

#endif // NEURA_PASSES_TD
72 changes: 64 additions & 8 deletions lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,22 @@ struct ArithSubFToNeuraFSub : public OpRewritePattern<mlir::arith::SubFOp> {
}
};

// Lowers arith.muli to neura.mul, leaving the optional predicate unset.
struct ArithMulIToNeuraMul : public OpRewritePattern<mlir::arith::MulIOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::MulIOp op,
                                PatternRewriter &rewriter) const override {
    // neura.mul takes (lhs, rhs, optional predicate); no predicate is
    // attached during this straight one-to-one lowering.
    rewriter.replaceOpWithNewOp<neura::MulOp>(op, op.getType(), op.getLhs(),
                                              op.getRhs(),
                                              /*predicate=*/nullptr);
    return success();
  }
};

struct ArithMulFToNeuraFMul : public OpRewritePattern<mlir::arith::MulFOp> {
using OpRewritePattern::OpRewritePattern;

Expand All @@ -124,6 +140,21 @@ struct ArithMulFToNeuraFMul : public OpRewritePattern<mlir::arith::MulFOp> {
}
};

// Lowers arith.divsi to neura.div, leaving the optional predicate unset.
struct ArithDivSIToNeuraDiv : public OpRewritePattern<mlir::arith::DivSIOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::DivSIOp op,
                                PatternRewriter &rewriter) const override {
    // neura.div takes (lhs, rhs, optional predicate); no predicate is
    // attached during this straight one-to-one lowering.
    rewriter.replaceOpWithNewOp<neura::DivOp>(op, op.getType(), op.getLhs(),
                                              op.getRhs(),
                                              /*predicate=*/nullptr);
    return success();
  }
};

struct ArithFDivToNeuraFDiv : public OpRewritePattern<mlir::arith::DivFOp> {
using OpRewritePattern::OpRewritePattern;

Expand All @@ -139,6 +170,30 @@ struct ArithFDivToNeuraFDiv : public OpRewritePattern<mlir::arith::DivFOp> {
return success();
}
};

// Expands arith.remsi into existing Neura primitives using the identity
//   rem = lhs - (lhs / rhs) * rhs
// NOTE(review): this matches arith.remsi only if neura.div truncates toward
// zero (C-style signed division) -- confirm against the DivOp semantics.
struct ArithRemSIToNeuraOp : public OpRewritePattern<mlir::arith::RemSIOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::RemSIOp op,
                                PatternRewriter &rewriter) const override {
    const Location loc = op.getLoc();
    const Type type = op.getType();

    // All three ops leave the optional predicate operand unset (nullptr).
    Value quotient = rewriter.create<neura::DivOp>(loc, type, op.getLhs(),
                                                   op.getRhs(), nullptr);
    Value product = rewriter.create<neura::MulOp>(loc, type, op.getRhs(),
                                                  quotient, nullptr);
    Value remainder = rewriter.create<neura::SubOp>(loc, type, op.getLhs(),
                                                    product, nullptr);

    rewriter.replaceOp(op, remainder);
    return success();
  }
};

struct ArithCmpiToNeuraICmp : public OpRewritePattern<mlir::arith::CmpIOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -252,8 +307,8 @@ struct ArithIndexCastToNeuraCast
Type in_type = input.getType();
StringRef cast_string;

// The isa<IntegerType> check is generic and handles any integer bit width.
// (e.g., i32, i64).
// The isa<IntegerType> check is generic and handles any integer bit
// width (e.g., i32, i64).
if (in_type.isIndex() && isa<IntegerType>(result_type)) {
cast_string = "index_to_int";
} else if (isa<IntegerType>(in_type) && result_type.isIndex()) {
Expand Down Expand Up @@ -294,12 +349,13 @@ struct LowerArithToNeuraPass
if (target && target.getValue() == mlir::accel::kNeuraTarget) {
RewritePatternSet patterns(&getContext());
mlir::neura::arith2neura::populateWithGenerated(patterns);
patterns.add<ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp,
ArithSelectToNeuraSel, ArithExtUIToNeuraCast,
ArithIndexCastToNeuraCast, ArithFDivToNeuraFDiv,
ArithExtfToNeuraCast, ArithMulFToNeuraFMul,
ArithSubIToNeuraSub, ArithSubFToNeuraFSub>(context);
patterns.add<
ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp, ArithSelectToNeuraSel,
ArithExtUIToNeuraCast, ArithIndexCastToNeuraCast,
ArithFDivToNeuraFDiv, ArithExtfToNeuraCast, ArithMulFToNeuraFMul,
ArithSubIToNeuraSub, ArithSubFToNeuraFSub, ArithMulIToNeuraMul,
ArithDivSIToNeuraDiv, ArithRemSIToNeuraOp>(context);
if (failed(
applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
Expand Down
Loading