Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitignore

This file was deleted.

21 changes: 19 additions & 2 deletions include/NeuraDialect/NeuraOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,23 @@ def Neura_SubOp : Op<NeuraDialect, "sub"> {
let traits = [SameOperandsAndResultElementType];
}

// Defines an integer multiplication operation.
def Neura_MulOp : Op<NeuraDialect, "mul"> {
  let summary = "Integer multiplication operation";
  // NOTE: the op mnemonic is already supplied by the Op<> template argument
  // ("mul"), which initializes the derived opName field; an explicit
  // `let opName` override is redundant and none of the sibling defs set it.
  let arguments = (ins AnyType:$lhs, AnyType:$rhs, Optional<AnyType>:$predicate);
  let results = (outs AnyType:$result);
  // let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)";
  let traits = [SameOperandsAndResultElementType];
}

// Defines an integer division operation.
def Neura_DivOp : Op<NeuraDialect, "div"> {
  let summary = "Integer division operation";
  // NOTE(review): operands are AnyType (not AnyInteger) and carry a trailing
  // optional $predicate operand, matching Neura_MulOp above — presumably so
  // predicated values can flow through; confirm intended operand types.
  let arguments = (ins AnyType:$lhs, AnyType:$rhs, Optional<AnyType>:$predicate);
  let results = (outs AnyType:$result);
  // let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)";
  let traits = [SameOperandsAndResultElementType];
}

// Defines a floating-point addition operation.
def Neura_FAddOp : Op<NeuraDialect, "fadd"> {
let summary = "Floating addition operation";
Expand Down Expand Up @@ -147,7 +164,7 @@ def Neura_StoreIndexedOp: Op<NeuraDialect, "store_indexed", [AttrSizedOperandSeg
// Defines a pointer computation operation.
def Neura_GEP : Op<NeuraDialect, "gep"> {
let summary = "Pointer computation using offset indices";
let arguments = (ins AnyType:$base, Variadic<AnyInteger>:$indicesAndPredicate);
let arguments = (ins AnyType:$base, Variadic<AnyType>:$indicesAndPredicate);
let results = (outs AnyType:$result);
// let assemblyFormat = "$base `[` $indicesAndPredicate `]` `,` $predicate attr-dict";
}
Expand All @@ -170,7 +187,7 @@ def Neura_Br : Op<NeuraDialect, "br", [Terminator]> {
}

def Neura_SelOp : Op<NeuraDialect, "sel"> {
let arguments = (ins AnyType:$ifTrue, AnyType:$ifFalse, I1:$cond);
let arguments = (ins AnyType:$ifTrue, AnyType:$ifFalse, AnyType:$cond);
let results = (outs AnyType:$result);
// let assemblyFormat = "$ifTrue `,` $ifFalse `,` $cond attr-dict `:` type($ifTrue)";
}
Expand Down
1 change: 1 addition & 0 deletions include/NeuraDialect/NeuraPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ std::unique_ptr<mlir::Pass> createTransformCtrlToDataFlowPass();
std::unique_ptr<mlir::Pass> createLeveragePredicatedValuePass();
std::unique_ptr<mlir::Pass> createMapToAcceleratorPass();
std::unique_ptr<mlir::Pass> createGenerateCodePass();
std::unique_ptr<mlir::Pass> createFuseControlFlowPass();

#define GEN_PASS_REGISTRATION
#include "NeuraDialect/NeuraPasses.h.inc"
Expand Down
8 changes: 8 additions & 0 deletions include/NeuraDialect/NeuraPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,12 @@ def GenerateCode : Pass<"generate-code", "ModuleOp"> {
let constructor = "neura::createGenerateCodePass()";
}

// Declares the control-flow fusion pass (spacing matches the sibling pass
// defs in this file, e.g. GenerateCode).
def FuseControlFlow : Pass<"fuse-control-flow", "ModuleOp"> {
  let summary = "Fuses control flow operations in the Neura dialect";
  let description = [{
    This pass fuses control flow operations to optimize the Neura dialect.
  }];
  let constructor = "neura::createFuseControlFlowPass()";
}

#endif // NEURA_PASSES_TD
72 changes: 64 additions & 8 deletions lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,22 @@ struct ArithSubFToNeuraFSub : public OpRewritePattern<mlir::arith::SubFOp> {
}
};

// Lowers arith.muli to neura.mul. The optional predicate operand of
// neura.mul is left unset (nullptr), i.e. the op is unpredicated.
struct ArithMulIToNeuraMul : public OpRewritePattern<mlir::arith::MulIOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::MulIOp op,
                                PatternRewriter &rewriter) const override {
    // Rebuild the multiply as a Neura op over the same operands and type.
    rewriter.replaceOpWithNewOp<neura::MulOp>(
        op, op.getType(), op.getLhs(), op.getRhs(), /*predicate=*/nullptr);
    return success();
  }
};

struct ArithMulFToNeuraFMul : public OpRewritePattern<mlir::arith::MulFOp> {
using OpRewritePattern::OpRewritePattern;

Expand All @@ -124,6 +140,21 @@ struct ArithMulFToNeuraFMul : public OpRewritePattern<mlir::arith::MulFOp> {
}
};

// Lowers arith.divsi to neura.div. The optional predicate operand of
// neura.div is left unset (nullptr), i.e. the op is unpredicated.
struct ArithDivSIToNeuraDiv : public OpRewritePattern<mlir::arith::DivSIOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::DivSIOp op,
                                PatternRewriter &rewriter) const override {
    Value dividend = op.getLhs();
    Value divisor = op.getRhs();
    Type out_type = op.getType();
    // Swap the arith op for its Neura counterpart in place.
    rewriter.replaceOpWithNewOp<neura::DivOp>(op, out_type, dividend, divisor,
                                              /*predicate=*/nullptr);
    return success();
  }
};

struct ArithFDivToNeuraFDiv : public OpRewritePattern<mlir::arith::DivFOp> {
using OpRewritePattern::OpRewritePattern;

Expand All @@ -139,6 +170,30 @@ struct ArithFDivToNeuraFDiv : public OpRewritePattern<mlir::arith::DivFOp> {
return success();
}
};

// Lowers arith.remsi by expanding it into existing Neura ops, since the
// dialect has no dedicated remainder operation:
//   rem = lhs - rhs * (lhs / rhs)
// NOTE(review): this identity reproduces arith.remsi semantics only if
// neura.div performs truncated (round-toward-zero) signed division — confirm
// the target's division semantics.
struct ArithRemSIToNeuraOp : public OpRewritePattern<mlir::arith::RemSIOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::RemSIOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Type result_type = op.getType();
    Location loc = op.getLoc();
    // Converts arith RemSIOp to basic Neura Op.
    // Optional predicate: default to null.
    // q = lhs / rhs
    Value div =
        rewriter.create<neura::DivOp>(loc, result_type, lhs, rhs, nullptr);
    // rhs * q
    Value mul =
        rewriter.create<neura::MulOp>(loc, result_type, rhs, div, nullptr);
    // lhs - rhs * q
    Value rem =
        rewriter.create<neura::SubOp>(loc, result_type, lhs, mul, nullptr);

    rewriter.replaceOp(op, rem);
    return success();
  }
};

struct ArithCmpiToNeuraICmp : public OpRewritePattern<mlir::arith::CmpIOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -252,8 +307,8 @@ struct ArithIndexCastToNeuraCast
Type in_type = input.getType();
StringRef cast_string;

// The isa<IntegerType> check is generic and handles any integer bit width.
// (e.g., i32, i64).
// The isa<IntegerType> check is generic and handles any integer bit
// width. (e.g., i32, i64).
if (in_type.isIndex() && isa<IntegerType>(result_type)) {
cast_string = "index_to_int";
} else if (isa<IntegerType>(in_type) && result_type.isIndex()) {
Expand Down Expand Up @@ -294,12 +349,13 @@ struct LowerArithToNeuraPass
if (target && target.getValue() == mlir::accel::kNeuraTarget) {
RewritePatternSet patterns(&getContext());
mlir::neura::arith2neura::populateWithGenerated(patterns);
patterns.add<ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp,
ArithSelectToNeuraSel, ArithExtUIToNeuraCast,
ArithIndexCastToNeuraCast, ArithFDivToNeuraFDiv,
ArithExtfToNeuraCast, ArithMulFToNeuraFMul,
ArithSubIToNeuraSub, ArithSubFToNeuraFSub>(context);
patterns.add<
ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp, ArithSelectToNeuraSel,
ArithExtUIToNeuraCast, ArithIndexCastToNeuraCast,
ArithFDivToNeuraFDiv, ArithExtfToNeuraCast, ArithMulFToNeuraFMul,
ArithSubIToNeuraSub, ArithSubFToNeuraFSub, ArithMulIToNeuraMul,
ArithDivSIToNeuraDiv, ArithRemSIToNeuraOp>(context);
if (failed(
applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
Expand Down
48 changes: 28 additions & 20 deletions lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#include "Common/AcceleratorAttrs.h"
#include "Conversion/ConversionPasses.h"
#include "NeuraDialect/NeuraDialect.h"
#include "NeuraDialect/NeuraOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "Conversion/ConversionPasses.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
Expand All @@ -35,7 +35,8 @@ struct LlvmAddToNeuraAdd : public OpRewritePattern<mlir::LLVM::AddOp> {

LogicalResult matchAndRewrite(mlir::LLVM::AddOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<neura::AddOp>(op, op.getType(), op.getLhs(), op.getRhs(), Value());
rewriter.replaceOpWithNewOp<neura::AddOp>(op, op.getType(), op.getLhs(),
op.getRhs(), Value());
return success();
}
};
Expand All @@ -54,7 +55,8 @@ struct LlvmFAddToNeuraFAdd : public OpRewritePattern<mlir::LLVM::FAddOp> {
return failure();

// Optional predicate: default to 'none'
rewriter.replaceOpWithNewOp<neura::FAddOp>(op, result_type, lhs, rhs, Value());
rewriter.replaceOpWithNewOp<neura::FAddOp>(op, result_type, lhs, rhs,
Value());
return success();
}
};
Expand All @@ -69,12 +71,13 @@ struct LlvmFSubToNeuraFSub : public OpRewritePattern<mlir::LLVM::FSubOp> {
Type result_type = op->getResult(0).getType();

// Only matches scalar float.
if (!mlir::isa<FloatType>(result_type)){
if (!mlir::isa<FloatType>(result_type)) {
return failure();
}

// Optional predicate: default to 'none'
rewriter.replaceOpWithNewOp<neura::FSubOp>(op, result_type, lhs, rhs, Value());
rewriter.replaceOpWithNewOp<neura::FSubOp>(op, result_type, lhs, rhs,
Value());
return success();
}
};
Expand All @@ -84,7 +87,8 @@ struct LlvmOrToNeuraOr : public OpRewritePattern<mlir::LLVM::OrOp> {

LogicalResult matchAndRewrite(mlir::LLVM::OrOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<neura::OrOp>(op, op.getType(), op.getLhs(), op.getRhs(), Value());
rewriter.replaceOpWithNewOp<neura::OrOp>(op, op.getType(), op.getLhs(),
op.getRhs(), Value());
return success();
}
};
Expand All @@ -102,12 +106,13 @@ struct LlvmFMulToNeuraFMul : public OpRewritePattern<mlir::LLVM::FMulOp> {
if (!mlir::isa<FloatType>(result_type))
return failure();

rewriter.replaceOpWithNewOp<neura::FMulOp>(op, result_type, lhs, rhs, Value());
rewriter.replaceOpWithNewOp<neura::FMulOp>(op, result_type, lhs, rhs,
Value());
return success();
}
};

struct LlvmVFMulToNeuraVFMul: public OpRewritePattern<mlir::LLVM::FMulOp> {
struct LlvmVFMulToNeuraVFMul : public OpRewritePattern<mlir::LLVM::FMulOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::LLVM::FMulOp op,
Expand All @@ -121,7 +126,8 @@ struct LlvmVFMulToNeuraVFMul: public OpRewritePattern<mlir::LLVM::FMulOp> {
if (!vecTy || !mlir::isa<FloatType>(vecTy.getElementType()))
return failure();

rewriter.replaceOpWithNewOp<neura::VFMulOp>(op, result_type, lhs, rhs, Value());
rewriter.replaceOpWithNewOp<neura::VFMulOp>(op, result_type, lhs, rhs,
Value());
return success();
}
};
Expand Down Expand Up @@ -173,7 +179,8 @@ struct LlvmGEPToNeuraGEP : public OpRewritePattern<mlir::LLVM::GEPOp> {
indexValues.push_back(val);
} else if (auto intAttr = gepIndex.dyn_cast<IntegerAttr>()) {
// Create constant operation state manually
OperationState state(op.getLoc(), neura::ConstantOp::getOperationName());
OperationState state(op.getLoc(),
neura::ConstantOp::getOperationName());
state.addAttribute("value", intAttr);
state.addAttribute("predicate", rewriter.getBoolAttr(true));
state.addTypes(rewriter.getIndexType());
Expand All @@ -184,7 +191,8 @@ struct LlvmGEPToNeuraGEP : public OpRewritePattern<mlir::LLVM::GEPOp> {
}
}

rewriter.replaceOpWithNewOp<neura::GEP>(op, op.getType(), base, indexValues);
rewriter.replaceOpWithNewOp<neura::GEP>(op, op.getType(), base,
indexValues);
return success();
}
};
Expand All @@ -194,7 +202,7 @@ struct LlvmLoadToNeuraLoad : public OpRewritePattern<mlir::LLVM::LoadOp> {

LogicalResult matchAndRewrite(mlir::LLVM::LoadOp op,
PatternRewriter &rewriter) const override {
Value ptr = op.getAddr(); // getPointer() is deprecated
Value ptr = op.getAddr(); // getPointer() is deprecated
Type resultType = op.getResult().getType();
rewriter.replaceOpWithNewOp<neura::LoadOp>(op, resultType, ptr, Value());
return success();
Expand All @@ -207,7 +215,7 @@ struct LlvmStoreToNeuraStore : public OpRewritePattern<mlir::LLVM::StoreOp> {
LogicalResult matchAndRewrite(mlir::LLVM::StoreOp op,
PatternRewriter &rewriter) const override {
Value value = op.getValue();
Value addr = op.getAddr(); // getPointer() is deprecated
Value addr = op.getAddr(); // getPointer() is deprecated
rewriter.replaceOpWithNewOp<neura::StoreOp>(op, value, addr, Value());
return success();
}
Expand Down Expand Up @@ -253,8 +261,7 @@ struct LlvmBrToNeuraBr : public OpRewritePattern<LLVM::BrOp> {
ValueRange destOperands = op.getDestOperands();

// Create the new Neura_Br operation
rewriter.replaceOpWithNewOp<neura::Br>(
op, destOperands, dest);
rewriter.replaceOpWithNewOp<neura::Br>(op, destOperands, dest);

return success();
}
Expand Down Expand Up @@ -284,15 +291,15 @@ struct LlvmConstantToNeuraConstant : public OpRewritePattern<LLVM::ConstantOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(LLVM::ConstantOp op,
PatternRewriter &rewriter) const override {
PatternRewriter &rewriter) const override {
auto attr = op.getValue();

// Create operation state manually
OperationState state(op.getLoc(), neura::ConstantOp::getOperationName());
state.addAttribute("value", attr);
state.addAttribute("predicate", rewriter.getBoolAttr(true));
state.addTypes(op.getType());

// Create the operation and replace
Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
Expand Down Expand Up @@ -343,7 +350,8 @@ struct LowerLlvmToNeuraPass
// e.g., mlir func or llvm func).
module_op.walk([&](FunctionOpInterface func) {
if (func->hasAttr(mlir::accel::kAcceleratorAttr)) {
auto target = func->getAttrOfType<StringAttr>(mlir::accel::kAcceleratorAttr);
auto target =
func->getAttrOfType<StringAttr>(mlir::accel::kAcceleratorAttr);
if (target && target.getValue() == mlir::accel::kNeuraTarget) {
for (Region &region : func->getRegions()) {
if (failed(applyPatternsGreedily(region, frozen))) {
Expand Down
1 change: 1 addition & 0 deletions lib/NeuraDialect/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ add_mlir_library(
LeveragePredicatedValuePass.cpp
MapToAcceleratorPass.cpp
GenerateCodePass.cpp
FuseControlFlowPass.cpp

DEPENDS
MLIRNeuraTransformsIncGen
Expand Down
40 changes: 40 additions & 0 deletions lib/NeuraDialect/Transforms/FuseControlFlowPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#include "NeuraDialect/NeuraOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

#define GEN_PASS_DEF_FUSECONTROLFLOW
#include "NeuraDialect/NeuraPasses.h.inc"

namespace {
// A class to hold loop information for the control flow fusion pass.
class LoopInfo {
public:
  // TODO: Adds necessary fields and methods to store loop information.
};

// Skeleton pass that will fuse control-flow operations in the Neura dialect.
// Registered as "fuse-control-flow"; the fusion logic is not implemented yet.
struct FuseControlFlowPass
    : public PassWrapper<FuseControlFlowPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuseControlFlowPass)

  StringRef getArgument() const override { return "fuse-control-flow"; }
  StringRef getDescription() const override {
    return "Fuses control flow operations into optimized neura dialect "
           "operations";
  }

  void runOnOperation() override {
    // Root module for the (not-yet-implemented) fusion rewrite. Marked
    // [[maybe_unused]] to keep the build warning-free until the TODO below
    // is implemented.
    [[maybe_unused]] ModuleOp module_op = getOperation();
    // TODO: Adds the logic to fuse determined control flow operations.
  }
};
} // namespace

namespace mlir::neura {
// Factory used by pass registration (declared in NeuraPasses.h).
std::unique_ptr<Pass> createFuseControlFlowPass() {
  return std::make_unique<FuseControlFlowPass>();
}
} // namespace mlir::neura
Loading