Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/Conversion/ConversionPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace mlir {
std::unique_ptr<mlir::Pass> createLowerArithToNeuraPass();
std::unique_ptr<mlir::Pass> createLowerLlvmToNeuraPass();
std::unique_ptr<mlir::Pass> createLowerMemRefToNeuraPass();
std::unique_ptr<mlir::Pass> createLowerBuiltinToNeuraPass();

#define GEN_PASS_REGISTRATION
#include "Conversion/ConversionPasses.h.inc"
Expand Down
6 changes: 6 additions & 0 deletions include/Conversion/ConversionPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,10 @@ def LowerMemRefToNeura : Pass<"lower-memref-to-neura", "ModuleOp">{
let constructor = "mlir::createLowerMemRefToNeuraPass()";
}

def LowerBuiltinToNeura : Pass<"lower-builtin-to-neura", "ModuleOp">{
let summary = "Lower Builtin to Neura dialect";
let description = [{Lower Builtin operations to Neura dialect operations.}];
let constructor = "mlir::createLowerBuiltinToNeuraPass()";
}

#endif // CONVERSION_PASSES_TD
32 changes: 30 additions & 2 deletions include/NeuraDialect/NeuraOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,34 @@ def Neura_StoreOp : Op<NeuraDialect, "store"> {
// let assemblyFormat = "$value `,` $addr `,` $predicate attr-dict";
}

// Defines a load operation with integrated address calculation.
def Neura_LoadIndexedOp: Op<NeuraDialect, "load_indexed", [AttrSizedOperandSegments]>{
let summary = "Load with integrated address calculation for multi-dimensional arrays";
let description = [{
Calculates the address using the base address and indices.
Load the value at the calculated address.
Example:
%value = neura.load_indexed %base [%arg1, %arg2] : f32
}];
let arguments = (ins Arg<AnyMemRef, "the load operation">:$base, Variadic<AnyType>:$indices, Optional<AnyType>:$predicate);
let results = (outs AnyType:$result);
let assemblyFormat = "$base `[` $indices `:` type($indices) `]` type($base) ($predicate^ `:` type($predicate))? attr-dict `:` type($result)";
}

//Defines a store operation with integrated address calculation.
def Neura_StoreIndexedOp: Op<NeuraDialect, "store_indexed", [AttrSizedOperandSegments]> {
let summary = "Store with integrated address calculation for multi-dimensional arrays";
let description = [{
Calculates the address using the base address and indices.
Store the value at the calculated address.
Example:
neura.store_indexed %value, %base [%arg1, %arg2] : f32
}];
let arguments = (ins AnyType:$value, Arg<AnyMemRef, "the store operation">:$base, Variadic<AnyType>:$indices, Optional<AnyType>:$predicate);
let results = (outs);
let assemblyFormat = "$value `to` $base `[` $indices `:` type($indices) `]` type($base) ($predicate^ `:` type($predicate))? attr-dict `:` type($value)";
}

// Defines a pointer computation operation.
def Neura_GEP : Op<NeuraDialect, "gep"> {
let summary = "Pointer computation using offset indices";
Expand All @@ -131,14 +159,14 @@ def Neura_CondBr : Op<NeuraDialect, "cond_br", [Terminator, AttrSizedOperandSegm
Variadic<AnyType>:$trueArgs,
Variadic<AnyType>:$falseArgs);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
let assemblyFormat = "$condition `:` type($condition) ($predicate^ `:` type($predicate))? `then` ($trueArgs^)? `:` type($trueArgs) `to` $trueDest `else` ($falseArgs^)? `:` type($falseArgs) `to` $falseDest attr-dict";
let assemblyFormat = "$condition `:` type($condition) ($predicate^ `:` type($predicate))? `then` ($trueArgs^ `:` type($trueArgs))? `to` $trueDest `else` ($falseArgs^ `:` type($falseArgs))? `to` $falseDest attr-dict";
}

// Defines an unconditional branch operation.
def Neura_Br : Op<NeuraDialect, "br", [Terminator]> {
let arguments = (ins Variadic<AnyType>:$args);
let successors = (successor AnySuccessor:$dest);
let assemblyFormat = "($args^)? `:` type($args) `to` $dest attr-dict";
let assemblyFormat = "($args^ `:` type($args))? `to` $dest attr-dict";
}

def Neura_SelOp : Op<NeuraDialect, "sel"> {
Expand Down
1 change: 0 additions & 1 deletion include/NeuraDialect/NeuraPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,4 @@ def MapToAccelerator : Pass<"map-to-accelerator", "ModuleOp"> {
}];
let constructor = "neura::createMapToAcceleratorPass()";
}

#endif // NEURA_PASSES_TD
85 changes: 54 additions & 31 deletions lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "Common/AcceleratorAttrs.h"
#include "Conversion/ConversionPasses.h"
#include "NeuraDialect/NeuraDialect.h"
#include "NeuraDialect/NeuraOps.h"
Expand All @@ -8,6 +9,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"

namespace mlir {
namespace neura {
Expand All @@ -24,9 +26,6 @@ using namespace mlir;
using namespace mlir::func;
using namespace mlir::neura;

#define GEN_PASS_DEF_LOWERARITHTONEURA
#include "NeuraDialect/NeuraPasses.h.inc"

namespace {

struct ArithConstantToNeuraConstant
Expand All @@ -35,10 +34,10 @@ struct ArithConstantToNeuraConstant

LogicalResult matchAndRewrite(arith::ConstantOp op,
PatternRewriter &rewriter) const override {
// Converts arith constant to Neura constant
// Converts arith constant to Neura constant.
Type result_type = op.getType();
Attribute value = op.getValue();
// Optional predicate parameter can be null
// Optional predicate parameter can be null.
rewriter.replaceOpWithNewOp<neura::ConstantOp>(op, result_type, value,
nullptr);
return success();
Expand All @@ -54,7 +53,7 @@ struct ArithAddIToNeuraAdd : public OpRewritePattern<mlir::arith::AddIOp> {
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::AddOp>(op, result_type, lhs, rhs,
nullptr);
return success();
Expand All @@ -70,7 +69,7 @@ struct ArithFAddToNeuraFAdd : public OpRewritePattern<mlir::arith::AddFOp> {
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::FAddOp>(op, result_type, lhs, rhs,
nullptr);
return success();
Expand All @@ -86,7 +85,7 @@ struct ArithSubIToNeuraSub : public OpRewritePattern<mlir::arith::SubIOp> {
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::SubOp>(op, result_type, lhs, rhs,
nullptr);
return success();
Expand All @@ -102,7 +101,7 @@ struct ArithSubFToNeuraFSub : public OpRewritePattern<mlir::arith::SubFOp> {
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::FSubOp>(op, result_type, lhs, rhs,
nullptr);
return success();
Expand All @@ -118,7 +117,7 @@ struct ArithMulFToNeuraFMul : public OpRewritePattern<mlir::arith::MulFOp> {
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::FMulOp>(op, result_type, lhs, rhs,
nullptr);
return success();
Expand All @@ -134,7 +133,7 @@ struct ArithFDivToNeuraFDiv : public OpRewritePattern<mlir::arith::DivFOp> {
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::FDivOp>(op, result_type, lhs, rhs,
nullptr);
return success();
Expand Down Expand Up @@ -185,8 +184,8 @@ struct ArithCmpiToNeuraICmp : public OpRewritePattern<mlir::arith::CmpIOp> {
return rewriter.notifyMatchFailure(op, "Unsupported arith CmpIOp type");
}

// Convert arith CmpIOp to Neura ICmpOp
// Optional predicate: default to null
// Converts arith CmpIOp to Neura ICmpOp.
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::ICmpOp>(
op, result_type, lhs, rhs, nullptr, rewriter.getStringAttr(cmp_type));
return success();
Expand All @@ -203,7 +202,7 @@ struct ArithSelectToNeuraSel : public OpRewritePattern<mlir::arith::SelectOp> {
Value false_value = op.getFalseValue();
Type result_type = op.getType();

// Convert arith SelectOp to Neura SelOp
// Converts arith SelectOp to Neura SelOp.
rewriter.replaceOpWithNewOp<neura::SelOp>(op, result_type, true_value,
false_value, condition);
return success();
Expand All @@ -218,8 +217,8 @@ struct ArithExtUIToNeuraCast : public OpRewritePattern<mlir::arith::ExtUIOp> {
Value input = op.getIn();
Type result_type = op.getType();

// Convert arith ExtUIOp to Neura cast operation
// Optional predicate: default to null
// Converts arith ExtUIOp to Neura cast operation.
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::CastOp>(
op, result_type, input, rewriter.getStringAttr("extui"), nullptr);
return success();
Expand All @@ -234,8 +233,8 @@ struct ArithExtfToNeuraCast : public OpRewritePattern<mlir::arith::ExtFOp> {
Value input = op.getIn();
Type result_type = op.getType();

// Convert arith ExtFOp to Neura cast operation
// Optional predicate: default to null
// Converts arith ExtFOp to Neura cast operation.
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::CastOp>(
op, result_type, input, rewriter.getStringAttr("extf"), nullptr);
return success();
Expand All @@ -250,11 +249,23 @@ struct ArithIndexCastToNeuraCast
PatternRewriter &rewriter) const override {
Value input = op.getIn();
Type result_type = op.getType();
Type in_type = input.getType();
StringRef cast_string;

// The isa<IntegerType> check is generic and handles any integer bit width.
// (e.g., i32, i64).
if (in_type.isIndex() && isa<IntegerType>(result_type)) {
cast_string = "index_to_int";
} else if (isa<IntegerType>(in_type) && result_type.isIndex()) {
cast_string = "int_to_index";
} else {
return rewriter.notifyMatchFailure(op, "index_cast");
}

// Convert arith IndexCastOp to Neura cast operation
// Optional predicate: default to null
// Converts arith IndexCastOp to Neura cast operation.
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::CastOp>(
op, result_type, input, rewriter.getStringAttr("indexCast"), nullptr);
op, result_type, input, rewriter.getStringAttr(cast_string), nullptr);
return success();
}
};
Expand All @@ -274,16 +285,28 @@ struct LowerArithToNeuraPass
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
mlir::neura::arith2neura::populateWithGenerated(patterns);
patterns
.add<ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp, ArithSelectToNeuraSel,
ArithExtUIToNeuraCast, ArithIndexCastToNeuraCast,
ArithFDivToNeuraFDiv, ArithExtfToNeuraCast, ArithMulFToNeuraFMul, ArithSubIToNeuraSub, ArithSubFToNeuraFSub>(&getContext());
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
ModuleOp module_op = getOperation();
MLIRContext *context = &getContext();
module_op.walk([&](func::FuncOp func_op) {
if (func_op->hasAttr(mlir::accel::kAcceleratorAttr)) {
auto target =
func_op->getAttrOfType<StringAttr>(mlir::accel::kAcceleratorAttr);
if (target && target.getValue() == mlir::accel::kNeuraTarget) {
RewritePatternSet patterns(&getContext());
mlir::neura::arith2neura::populateWithGenerated(patterns);
patterns.add<ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp,
ArithSelectToNeuraSel, ArithExtUIToNeuraCast,
ArithIndexCastToNeuraCast, ArithFDivToNeuraFDiv,
ArithExtfToNeuraCast, ArithMulFToNeuraFMul,
ArithSubIToNeuraSub, ArithSubFToNeuraFSub>(context);
if (failed(
applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
}
}
});
}
};
} // namespace
Expand Down
88 changes: 88 additions & 0 deletions lib/Conversion/BuiltinToNeura/BuiltinToNeuraPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include "Common/AcceleratorAttrs.h"
#include "Conversion/ConversionPasses.h"
#include "NeuraDialect/NeuraDialect.h"
#include "NeuraDialect/NeuraOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::neura;

namespace {

struct BuiltinUnrealizedConversionCastToNeuraCast
: public OpRewritePattern<mlir::UnrealizedConversionCastOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::UnrealizedConversionCastOp op,
PatternRewriter &rewriter) const override {
// Only handles simple 1:1 casts.
// TODO: Handle more complex casts if needed.
if (op.getInputs().size() == 1 && op.getResults().size() == 1) {
Value input = op.getInputs()[0];
Type result_type = op.getResults()[0].getType();
Type input_type = input.getType();

StringRef cast_type;
if (input_type.isIndex() && isa<IntegerType>(result_type)) {
cast_type = "index_to_int";
} else if (isa<IntegerType>(input_type) && result_type.isIndex()) {
cast_type = "int_to_index";
} else {
return rewriter.notifyMatchFailure(op, "unsupported cast");
}

// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::CastOp>(
op, result_type, input, rewriter.getStringAttr(cast_type), nullptr);
return success();
}
return failure();
}
};

struct LowerBuiltinToNeuraPass
: public PassWrapper<LowerBuiltinToNeuraPass, OperationPass<ModuleOp>> {

MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerBuiltinToNeuraPass)

StringRef getArgument() const override { return "lower-builtin-to-neura"; }
StringRef getDescription() const override {
return "Lower Builtin operations to Neura dialect operations";
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mlir::neura::NeuraDialect>();
}

void runOnOperation() override {
ModuleOp module_op = getOperation();
MLIRContext *context = &getContext();
RewritePatternSet patterns(&getContext());
patterns.add<BuiltinUnrealizedConversionCastToNeuraCast>(context);
module_op.walk([&](func::FuncOp func_op) {
if (func_op->hasAttr(mlir::accel::kAcceleratorAttr)) {
auto target =
func_op->getAttrOfType<StringAttr>(mlir::accel::kAcceleratorAttr);
if (target && target.getValue() == mlir::accel::kNeuraTarget) {
if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) {
return signalPassFailure();
}
}
}
});
}
};
} // namespace

std::unique_ptr<Pass> mlir::createLowerBuiltinToNeuraPass() {
return std::make_unique<LowerBuiltinToNeuraPass>();
}
18 changes: 18 additions & 0 deletions lib/Conversion/BuiltinToNeura/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
include_directories(${CMAKE_CURRENT_BINARY_DIR})

add_mlir_conversion_library(MLIRNeuraBuiltinToNeuraPass
BuiltinToNeuraPass.cpp

DEPENDS
MLIRConversionIncGen

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRFuncDialect
MLIRLLVMDialect
MLIRIR
MLIRPass
MLIRTransforms
MLIRNeura
MLIRSupport
)
Loading