Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/Conversion/ConversionPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace mlir {
std::unique_ptr<mlir::Pass> createLowerArithToNeuraPass();
std::unique_ptr<mlir::Pass> createLowerLlvmToNeuraPass();
std::unique_ptr<mlir::Pass> createLowerMemRefToNeuraPass();
std::unique_ptr<mlir::Pass> createLowerBuiltinToNeuraPass();

#define GEN_PASS_REGISTRATION
#include "Conversion/ConversionPasses.h.inc"
Expand Down
6 changes: 6 additions & 0 deletions include/Conversion/ConversionPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,10 @@ def LowerMemRefToNeura : Pass<"lower-memref-to-neura", "ModuleOp">{
let constructor = "mlir::createLowerMemRefToNeuraPass()";
}

// Registers a module-level pass that lowers builtin dialect operations
// (e.g. unrealized_conversion_cast) into the Neura dialect.
def LowerBuiltinToNeura : Pass<"lower-builtin-to-neura", "ModuleOp">{
let summary = "Lower Builtin to Neura dialect";
let description = [{Lower Builtin operations to Neura dialect operations.}];
let constructor = "mlir::createLowerBuiltinToNeuraPass()";
}

#endif // CONVERSION_PASSES_TD
32 changes: 30 additions & 2 deletions include/NeuraDialect/NeuraOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,34 @@ def Neura_StoreOp : Op<NeuraDialect, "store"> {
// let assemblyFormat = "$value `,` $addr `,` $predicate attr-dict";
}

// Defines a load operation with integrated address calculation.
def Neura_LoadIndexedOp: Op<NeuraDialect, "load_indexed", [AttrSizedOperandSegments]>{
let summary = "Load with integrated address calculation for multi-dimensional arrays";
let description = [{
Calculates the address using the base address and indices.
Load the value at the calculated address.
Example (matches the declarative assembly format below — the index types
appear inside the brackets and the base memref type follows them):
%value = neura.load_indexed %base [%i, %j : index, index] memref<4x4xf32> : f32
}];
// $base is the memref being indexed; $indices supply one offset per
// dimension; $predicate optionally gates the access.
let arguments = (ins Arg<AnyMemRef, "the base memref to load from">:$base, Variadic<AnyType>:$indices, Optional<AnyType>:$predicate);
let results = (outs AnyType:$result);
let assemblyFormat = "$base `[` $indices `:` type($indices) `]` type($base) ($predicate^ `:` type($predicate))? attr-dict `:` type($result)";
}

// Defines a store operation with integrated address calculation.
def Neura_StoreIndexedOp: Op<NeuraDialect, "store_indexed", [AttrSizedOperandSegments]> {
let summary = "Store with integrated address calculation for multi-dimensional arrays";
let description = [{
Calculates the address using the base address and indices.
Store the value at the calculated address.
Example (matches the declarative assembly format below — the stored value
is joined to the base with the `to` keyword, not a comma):
neura.store_indexed %value to %base [%i, %j : index, index] memref<4x4xf32> : f32
}];
// $value is the datum written; $base is the memref being indexed; $indices
// supply one offset per dimension; $predicate optionally gates the access.
let arguments = (ins AnyType:$value, Arg<AnyMemRef, "the base memref to store into">:$base, Variadic<AnyType>:$indices, Optional<AnyType>:$predicate);
let results = (outs);
let assemblyFormat = "$value `to` $base `[` $indices `:` type($indices) `]` type($base) ($predicate^ `:` type($predicate))? attr-dict `:` type($value)";
}

// Defines a pointer computation operation.
def Neura_GEP : Op<NeuraDialect, "gep"> {
let summary = "Pointer computation using offset indices";
Expand All @@ -131,14 +159,14 @@ def Neura_CondBr : Op<NeuraDialect, "cond_br", [Terminator, AttrSizedOperandSegm
Variadic<AnyType>:$trueArgs,
Variadic<AnyType>:$falseArgs);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
let assemblyFormat = "$condition `:` type($condition) ($predicate^ `:` type($predicate))? `then` ($trueArgs^)? `:` type($trueArgs) `to` $trueDest `else` ($falseArgs^)? `:` type($falseArgs) `to` $falseDest attr-dict";
let assemblyFormat = "$condition `:` type($condition) ($predicate^ `:` type($predicate))? `then` ($trueArgs^ `:` type($trueArgs))? `to` $trueDest `else` ($falseArgs^ `:` type($falseArgs))? `to` $falseDest attr-dict";
}

// Defines an unconditional branch operation.
def Neura_Br : Op<NeuraDialect, "br", [Terminator]> {
let arguments = (ins Variadic<AnyType>:$args);
let successors = (successor AnySuccessor:$dest);
let assemblyFormat = "($args^)? `:` type($args) `to` $dest attr-dict";
let assemblyFormat = "($args^ `:` type($args))? `to` $dest attr-dict";
}

def Neura_SelOp : Op<NeuraDialect, "sel"> {
Expand Down
1 change: 0 additions & 1 deletion include/NeuraDialect/NeuraPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,4 @@ def MapToAccelerator : Pass<"map-to-accelerator", "ModuleOp"> {
}];
let constructor = "neura::createMapToAcceleratorPass()";
}

#endif // NEURA_PASSES_TD
85 changes: 54 additions & 31 deletions lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "Common/AcceleratorAttrs.h"
#include "Conversion/ConversionPasses.h"
#include "NeuraDialect/NeuraDialect.h"
#include "NeuraDialect/NeuraOps.h"
Expand All @@ -8,6 +9,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"

namespace mlir {
namespace neura {
Expand All @@ -24,9 +26,6 @@ using namespace mlir;
using namespace mlir::func;
using namespace mlir::neura;

#define GEN_PASS_DEF_LOWERARITHTONEURA
#include "NeuraDialect/NeuraPasses.h.inc"

namespace {

struct ArithConstantToNeuraConstant
Expand All @@ -35,10 +34,10 @@ struct ArithConstantToNeuraConstant

LogicalResult matchAndRewrite(arith::ConstantOp op,
PatternRewriter &rewriter) const override {
// Converts arith constant to Neura constant
// Converts arith constant to Neura constant/
Type result_type = op.getType();
Attribute value = op.getValue();
// Optional predicate parameter can be null
// Optional predicate parameter can be null/
rewriter.replaceOpWithNewOp<neura::ConstantOp>(op, result_type, value,
nullptr);
return success();
Expand All @@ -54,7 +53,7 @@ struct ArithAddIToNeuraAdd : public OpRewritePattern<mlir::arith::AddIOp> {
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
// Optional predicate: default to null/
rewriter.replaceOpWithNewOp<neura::AddOp>(op, result_type, lhs, rhs,
nullptr);
return success();
Expand All @@ -70,7 +69,7 @@ struct ArithFAddToNeuraFAdd : public OpRewritePattern<mlir::arith::AddFOp> {
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
// Optional predicate: default to null/
rewriter.replaceOpWithNewOp<neura::FAddOp>(op, result_type, lhs, rhs,
nullptr);
return success();
Expand All @@ -86,7 +85,7 @@ struct ArithSubIToNeuraSub : public OpRewritePattern<mlir::arith::SubIOp> {
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
// Optional predicate: default to null/
rewriter.replaceOpWithNewOp<neura::SubOp>(op, result_type, lhs, rhs,
nullptr);
return success();
Expand All @@ -102,7 +101,7 @@ struct ArithSubFToNeuraFSub : public OpRewritePattern<mlir::arith::SubFOp> {
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
// Optional predicate: default to null/
rewriter.replaceOpWithNewOp<neura::FSubOp>(op, result_type, lhs, rhs,
nullptr);
return success();
Expand All @@ -118,7 +117,7 @@ struct ArithMulFToNeuraFMul : public OpRewritePattern<mlir::arith::MulFOp> {
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
// Optional predicate: default to null/
rewriter.replaceOpWithNewOp<neura::FMulOp>(op, result_type, lhs, rhs,
nullptr);
return success();
Expand All @@ -134,7 +133,7 @@ struct ArithFDivToNeuraFDiv : public OpRewritePattern<mlir::arith::DivFOp> {
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::FDivOp>(op, result_type, lhs, rhs,
nullptr);
return success();
Expand Down Expand Up @@ -185,8 +184,8 @@ struct ArithCmpiToNeuraICmp : public OpRewritePattern<mlir::arith::CmpIOp> {
return rewriter.notifyMatchFailure(op, "Unsupported arith CmpIOp type");
}

// Convert arith CmpIOp to Neura ICmpOp
// Optional predicate: default to null
// Converts arith CmpIOp to Neura ICmpOp.
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::ICmpOp>(
op, result_type, lhs, rhs, nullptr, rewriter.getStringAttr(cmp_type));
return success();
Expand All @@ -203,7 +202,7 @@ struct ArithSelectToNeuraSel : public OpRewritePattern<mlir::arith::SelectOp> {
Value false_value = op.getFalseValue();
Type result_type = op.getType();

// Convert arith SelectOp to Neura SelOp
// Converts arith SelectOp to Neura SelOp.
rewriter.replaceOpWithNewOp<neura::SelOp>(op, result_type, true_value,
false_value, condition);
return success();
Expand All @@ -218,8 +217,8 @@ struct ArithExtUIToNeuraCast : public OpRewritePattern<mlir::arith::ExtUIOp> {
Value input = op.getIn();
Type result_type = op.getType();

// Convert arith ExtUIOp to Neura cast operation
// Optional predicate: default to null
// Converts arith ExtUIOp to Neura cast operation.
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::CastOp>(
op, result_type, input, rewriter.getStringAttr("extui"), nullptr);
return success();
Expand All @@ -234,8 +233,8 @@ struct ArithExtfToNeuraCast : public OpRewritePattern<mlir::arith::ExtFOp> {
Value input = op.getIn();
Type result_type = op.getType();

// Convert arith ExtFOp to Neura cast operation
// Optional predicate: default to null
// Converts arith ExtFOp to Neura cast operation.
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::CastOp>(
op, result_type, input, rewriter.getStringAttr("extf"), nullptr);
return success();
Expand All @@ -250,11 +249,23 @@ struct ArithIndexCastToNeuraCast
PatternRewriter &rewriter) const override {
Value input = op.getIn();
Type result_type = op.getType();
Type in_type = input.getType();
StringRef cast_string;

// The isa<IntegerType> check is generic and handles any integer bit width.
// (e.g., i32, i64).
if (in_type.isIndex() && isa<IntegerType>(result_type)) {
cast_string = "index_to_int";
} else if (isa<IntegerType>(in_type) && result_type.isIndex()) {
cast_string = "int_to_index";
} else {
return rewriter.notifyMatchFailure(op, "index_cast");
}

// Convert arith IndexCastOp to Neura cast operation
// Optional predicate: default to null
// Converts arith IndexCastOp to Neura cast operation.
// Optional predicate: default to null.
rewriter.replaceOpWithNewOp<neura::CastOp>(
op, result_type, input, rewriter.getStringAttr("indexCast"), nullptr);
op, result_type, input, rewriter.getStringAttr(cast_string), nullptr);
return success();
}
};
Expand All @@ -274,16 +285,28 @@ struct LowerArithToNeuraPass
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
mlir::neura::arith2neura::populateWithGenerated(patterns);
patterns
.add<ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp, ArithSelectToNeuraSel,
ArithExtUIToNeuraCast, ArithIndexCastToNeuraCast,
ArithFDivToNeuraFDiv, ArithExtfToNeuraCast, ArithMulFToNeuraFMul, ArithSubIToNeuraSub, ArithSubFToNeuraFSub>(&getContext());
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
ModuleOp module_op = getOperation();
MLIRContext *context = &getContext();
module_op.walk([&](func::FuncOp func_op) {
if (func_op->hasAttr(mlir::accel::kAcceleratorAttr)) {
auto target =
func_op->getAttrOfType<StringAttr>(mlir::accel::kAcceleratorAttr);
if (target && target.getValue() == mlir::accel::kNeuraTarget) {
RewritePatternSet patterns(&getContext());
mlir::neura::arith2neura::populateWithGenerated(patterns);
patterns.add<ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp,
ArithSelectToNeuraSel, ArithExtUIToNeuraCast,
ArithIndexCastToNeuraCast, ArithFDivToNeuraFDiv,
ArithExtfToNeuraCast, ArithMulFToNeuraFMul,
ArithSubIToNeuraSub, ArithSubFToNeuraFSub>(context);
if (failed(
applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
}
}
});
}
};
} // namespace
Expand Down
88 changes: 88 additions & 0 deletions lib/Conversion/BuiltinToNeura/BuiltinToNeuraPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include "Common/AcceleratorAttrs.h"
#include "Conversion/ConversionPasses.h"
#include "NeuraDialect/NeuraDialect.h"
#include "NeuraDialect/NeuraOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::neura;

namespace {

// Rewrites a 1:1 builtin.unrealized_conversion_cast between `index` and an
// integer type into the equivalent neura.cast operation.
struct BuiltinUnrealizedConversionCastToNeuraCast
    : public OpRewritePattern<mlir::UnrealizedConversionCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::UnrealizedConversionCastOp op,
                                PatternRewriter &rewriter) const override {
    // Only handles simple 1:1 casts.
    // TODO: Handle more complex casts if needed.
    if (op.getInputs().size() != 1 || op.getResults().size() != 1)
      return failure();

    Value source = op.getInputs().front();
    Type src_type = source.getType();
    Type dst_type = op.getResults().front().getType();

    // Picks the neura.cast kind from the direction of the index<->int cast;
    // any other type pair is left for a different pattern to handle.
    StringRef cast_kind;
    if (src_type.isIndex() && isa<IntegerType>(dst_type)) {
      cast_kind = "index_to_int";
    } else if (isa<IntegerType>(src_type) && dst_type.isIndex()) {
      cast_kind = "int_to_index";
    } else {
      return rewriter.notifyMatchFailure(op, "unsupported cast");
    }

    // Optional predicate: default to null.
    rewriter.replaceOpWithNewOp<neura::CastOp>(
        op, dst_type, source, rewriter.getStringAttr(cast_kind), nullptr);
    return success();
  }
};

// Module pass that lowers builtin dialect operations (currently
// unrealized_conversion_cast) to Neura dialect operations, but only inside
// functions explicitly tagged with the Neura accelerator attribute.
struct LowerBuiltinToNeuraPass
    : public PassWrapper<LowerBuiltinToNeuraPass, OperationPass<ModuleOp>> {

  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerBuiltinToNeuraPass)

  StringRef getArgument() const override { return "lower-builtin-to-neura"; }
  StringRef getDescription() const override {
    return "Lower Builtin operations to Neura dialect operations";
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mlir::neura::NeuraDialect>();
  }

  void runOnOperation() override {
    ModuleOp module_op = getOperation();
    MLIRContext *context = &getContext();
    module_op.walk([&](func::FuncOp func_op) {
      // Skips functions that are not targeted at the Neura accelerator.
      if (!func_op->hasAttr(mlir::accel::kAcceleratorAttr))
        return;
      auto target =
          func_op->getAttrOfType<StringAttr>(mlir::accel::kAcceleratorAttr);
      if (!target || target.getValue() != mlir::accel::kNeuraTarget)
        return;
      // Builds a fresh pattern set per matching function:
      // applyPatternsGreedily consumes the set via std::move, so a single
      // set hoisted out of the walk would be moved-from (and thus invalid)
      // on the second Neura-targeted function in the module.
      RewritePatternSet patterns(context);
      patterns.add<BuiltinUnrealizedConversionCastToNeuraCast>(context);
      if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) {
        signalPassFailure();
      }
    });
  }
};
} // namespace

// Factory for the builtin-to-Neura lowering pass, used by pass registration.
std::unique_ptr<Pass> mlir::createLowerBuiltinToNeuraPass() {
  auto pass = std::make_unique<LowerBuiltinToNeuraPass>();
  return pass;
}
18 changes: 18 additions & 0 deletions lib/Conversion/BuiltinToNeura/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Makes the tablegen-generated pass declarations (ConversionPasses.h.inc)
# visible to this library's sources.
include_directories(${CMAKE_CURRENT_BINARY_DIR})

# Conversion library implementing the lower-builtin-to-neura pass.
add_mlir_conversion_library(MLIRNeuraBuiltinToNeuraPass
BuiltinToNeuraPass.cpp

# Ensures the conversion pass .inc files are generated before compiling.
DEPENDS
MLIRConversionIncGen

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRFuncDialect
MLIRLLVMDialect
MLIRIR
MLIRPass
MLIRTransforms
MLIRNeura
MLIRSupport
)
Loading