Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/Conversion/ConversionPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace mlir {
// Conversion passes.
std::unique_ptr<mlir::Pass> createLowerArithToNeuraPass();
std::unique_ptr<mlir::Pass> createLowerLlvmToNeuraPass();
std::unique_ptr<mlir::Pass> createLowerMemRefToNeuraPass();

#define GEN_PASS_REGISTRATION
#include "Conversion/ConversionPasses.h.inc"
Expand Down
6 changes: 6 additions & 0 deletions include/Conversion/ConversionPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,10 @@ def LowerLlvmToNeura : Pass<"lower-llvm-to-neura", "ModuleOp">{
let constructor = "mlir::createLowerLlvmToNeuraPass()";
}

def LowerMemRefToNeura : Pass<"lower-memref-to-neura", "ModuleOp">{
let summary = "Lower MemRef to Neura dialect";
let description = [{Lower MemRef operations to Neura dialect operations.}];
let constructor = "mlir::createLowerMemRefToNeuraPass()";
}

#endif // CONVERSION_PASSES_TD
25 changes: 24 additions & 1 deletion include/NeuraDialect/NeuraOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ def Neura_AddOp : Op<NeuraDialect, "add"> {
let traits = [SameOperandsAndResultElementType];
}

def Neura_SubOp : Op<NeuraDialect, "sub"> {
let summary = "Integer substraction operation";
let arguments = (ins AnyType:$lhs, AnyType:$rhs, Optional<AnyType>:$predicate);
let results = (outs AnyType:$result);
// let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)";
let traits = [SameOperandsAndResultElementType];
}

// Defines a floating-point addition operation.
def Neura_FAddOp : Op<NeuraDialect, "fadd"> {
let summary = "Floating addition operation";
Expand All @@ -38,7 +46,7 @@ def Neura_FAddOp : Op<NeuraDialect, "fadd"> {
def Neura_FSubOp: Op<NeuraDialect, "fsub"> {
let summary = "Floating substraction operation";
let opName = "fsub";
let arguments = (ins AnyFloat:$lhs, AnyFloat:$rhs);
let arguments = (ins AnyFloat:$lhs, AnyFloat:$rhs, Optional<AnyType>:$predicate);
let results = (outs AnyFloat:$result);
// let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
let traits = [SameOperandsAndResultElementType];
Expand All @@ -54,6 +62,13 @@ def Neura_FMulOp : Op<NeuraDialect, "fmul"> {
// let traits = [SameOperandsAndResultElementType];
}

def Neura_FDivOp : Op<NeuraDialect, "fdiv"> {
let summary = "Floating division operation";
let arguments = (ins AnyFloat:$lhs, AnyFloat:$rhs, Optional<AnyType>:$predicate);
let results = (outs AnyFloat:$result);
// let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)";
}

// Defines a bitwise OR operation.
def Neura_OrOp : Op<NeuraDialect, "or"> {
let summary = "Bitwise OR operation";
Expand Down Expand Up @@ -144,6 +159,14 @@ def Neura_ReturnOp : Op<NeuraDialect, "return", [Terminator]> {
// let assemblyFormat = "($values^)? `,` $predicate attr-dict";
}

// Defines a cast operation for type conversion.
def Neura_CastOp : Op<NeuraDialect, "cast">{
let summary = "Generic type conversion operation";
let arguments = (ins AnyType:$input, StrAttr:$cast_type, Optional<AnyType>:$predicate);
let results = (outs AnyType:$result);
// let assemblyFormat = "$input type($input) `->` type($output) `,` $predicate attr-dict";
}

// ----------------------------------------------------
// Defines vector operations.

Expand Down
236 changes: 228 additions & 8 deletions lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#include "Conversion/ConversionPasses.h"
#include "NeuraDialect/NeuraDialect.h"
#include "NeuraDialect/NeuraOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "NeuraDialect/NeuraPasses.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "Conversion/ConversionPasses.h"

namespace mlir {
namespace neura {
Expand All @@ -26,7 +27,39 @@ using namespace mlir::neura;
#define GEN_PASS_DEF_LOWERARITHTONEURA
#include "NeuraDialect/NeuraPasses.h.inc"

namespace{
namespace {

struct ArithConstantToNeuraConstant
: public OpRewritePattern<mlir::arith::ConstantOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::ConstantOp op,
PatternRewriter &rewriter) const override {
// Converts arith constant to Neura constant
Type result_type = op.getType();
Attribute value = op.getValue();
// Optional predicate parameter can be null
rewriter.replaceOpWithNewOp<neura::ConstantOp>(op, result_type, value,
nullptr);
return success();
}
};

struct ArithAddIToNeuraAdd : public OpRewritePattern<mlir::arith::AddIOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::AddIOp op,
PatternRewriter &rewriter) const override {
Value lhs = op.getLhs();
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
rewriter.replaceOpWithNewOp<neura::AddOp>(op, result_type, lhs, rhs,
nullptr);
return success();
}
};

struct ArithFAddToNeuraFAdd : public OpRewritePattern<mlir::arith::AddFOp> {
using OpRewritePattern::OpRewritePattern;
Expand All @@ -35,16 +68,199 @@ struct ArithFAddToNeuraFAdd : public OpRewritePattern<mlir::arith::AddFOp> {
PatternRewriter &rewriter) const override {
Value lhs = op.getLhs();
Value rhs = op.getRhs();
Type resultType = op.getType();
Type result_type = op.getType();

// Optional predicate: default to null
rewriter.replaceOpWithNewOp<neura::FAddOp>(op, result_type, lhs, rhs,
nullptr);
return success();
}
};

struct ArithSubIToNeuraSub : public OpRewritePattern<mlir::arith::SubIOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::SubIOp op,
PatternRewriter &rewriter) const override {
Value lhs = op.getLhs();
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
rewriter.replaceOpWithNewOp<neura::SubOp>(op, result_type, lhs, rhs,
nullptr);
return success();
}
};

struct ArithSubFToNeuraFSub : public OpRewritePattern<mlir::arith::SubFOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::SubFOp op,
PatternRewriter &rewriter) const override {
Value lhs = op.getLhs();
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
rewriter.replaceOpWithNewOp<neura::FSubOp>(op, result_type, lhs, rhs,
nullptr);
return success();
}
};

struct ArithMulFToNeuraFMul : public OpRewritePattern<mlir::arith::MulFOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::MulFOp op,
PatternRewriter &rewriter) const override {
Value lhs = op.getLhs();
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
rewriter.replaceOpWithNewOp<neura::FMulOp>(op, result_type, lhs, rhs,
nullptr);
return success();
}
};

struct ArithFDivToNeuraFDiv : public OpRewritePattern<mlir::arith::DivFOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::DivFOp op,
PatternRewriter &rewriter) const override {
Value lhs = op.getLhs();
Value rhs = op.getRhs();
Type result_type = op.getType();

// Optional predicate: default to null
rewriter.replaceOpWithNewOp<neura::FDivOp>(op, result_type, lhs, rhs,
nullptr);
return success();
}
};
struct ArithCmpiToNeuraICmp : public OpRewritePattern<mlir::arith::CmpIOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::CmpIOp op,
PatternRewriter &rewriter) const override {
Value lhs = op.getLhs();
Value rhs = op.getRhs();
Type result_type = op.getType();
arith::CmpIPredicate arith_cmp_type = op.getPredicate();
StringRef cmp_type;
switch (arith_cmp_type) {
case arith::CmpIPredicate::eq:
cmp_type = "eq"; // ==
break;
case arith::CmpIPredicate::ne:
cmp_type = "ne"; // !=
break;
case arith::CmpIPredicate::slt:
cmp_type = "slt"; // <
break;
case arith::CmpIPredicate::sle:
cmp_type = "sle"; // <=
break;
case arith::CmpIPredicate::sgt:
cmp_type = "sgt"; // >
break;
case arith::CmpIPredicate::sge:
cmp_type = "sge"; // >=
break;
case arith::CmpIPredicate::ult:
cmp_type = "ult"; // unsigned <
break;
case arith::CmpIPredicate::ule:
cmp_type = "ule"; // unsigned <=
break;
case arith::CmpIPredicate::ugt:
cmp_type = "ugt"; // unsigned >
break;
case arith::CmpIPredicate::uge:
cmp_type = "uge"; // unsigned >=
break;
default:
return rewriter.notifyMatchFailure(op, "Unsupported arith CmpIOp type");
}

// Convert arith CmpIOp to Neura ICmpOp
// Optional predicate: default to null
rewriter.replaceOpWithNewOp<neura::ICmpOp>(
op, result_type, lhs, rhs, nullptr, rewriter.getStringAttr(cmp_type));
return success();
}
};

struct ArithSelectToNeuraSel : public OpRewritePattern<mlir::arith::SelectOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::SelectOp op,
PatternRewriter &rewriter) const override {
Value condition = op.getCondition();
Value true_value = op.getTrueValue();
Value false_value = op.getFalseValue();
Type result_type = op.getType();

// Convert arith SelectOp to Neura SelOp
rewriter.replaceOpWithNewOp<neura::SelOp>(op, result_type, true_value,
false_value, condition);
return success();
}
};

struct ArithExtUIToNeuraCast : public OpRewritePattern<mlir::arith::ExtUIOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::ExtUIOp op,
PatternRewriter &rewriter) const override {
Value input = op.getIn();
Type result_type = op.getType();

// Convert arith ExtUIOp to Neura cast operation
// Optional predicate: default to null
rewriter.replaceOpWithNewOp<neura::CastOp>(
op, result_type, input, rewriter.getStringAttr("extui"), nullptr);
return success();
}
};

struct ArithExtfToNeuraCast : public OpRewritePattern<mlir::arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::ExtFOp op,
PatternRewriter &rewriter) const override {
Value input = op.getIn();
Type result_type = op.getType();

// Convert arith ExtFOp to Neura cast operation
// Optional predicate: default to null
rewriter.replaceOpWithNewOp<neura::CastOp>(
op, result_type, input, rewriter.getStringAttr("extf"), nullptr);
return success();
}
};

struct ArithIndexCastToNeuraCast
: public OpRewritePattern<mlir::arith::IndexCastOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(arith::IndexCastOp op,
PatternRewriter &rewriter) const override {
Value input = op.getIn();
Type result_type = op.getType();

// Optional predicate: default to 'none'
rewriter.replaceOpWithNewOp<neura::FAddOp>(op, resultType, lhs, rhs, Value());
// Convert arith IndexCastOp to Neura cast operation
// Optional predicate: default to null
rewriter.replaceOpWithNewOp<neura::CastOp>(
op, result_type, input, rewriter.getStringAttr("indexCast"), nullptr);
return success();
}
};

struct LowerArithToNeuraPass
: public PassWrapper<LowerArithToNeuraPass, OperationPass<func::FuncOp>> {
: public PassWrapper<LowerArithToNeuraPass, OperationPass<ModuleOp>> {

MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerArithToNeuraPass)

Expand All @@ -60,7 +276,11 @@ struct LowerArithToNeuraPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
mlir::neura::arith2neura::populateWithGenerated(patterns);
patterns.add<ArithFAddToNeuraFAdd>(&getContext());
patterns
.add<ArithFAddToNeuraFAdd, ArithConstantToNeuraConstant,
ArithAddIToNeuraAdd, ArithCmpiToNeuraICmp, ArithSelectToNeuraSel,
ArithExtUIToNeuraCast, ArithIndexCastToNeuraCast,
ArithFDivToNeuraFDiv, ArithExtfToNeuraCast, ArithMulFToNeuraFMul, ArithSubIToNeuraSub, ArithSubFToNeuraFSub>(&getContext());
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)

add_subdirectory(ArithToNeura)
add_subdirectory(LlvmToNeura)
add_subdirectory(MemRefToNeura)

# add_mlir_library(
# MLIRNeuraConversion
Expand Down
18 changes: 18 additions & 0 deletions lib/Conversion/MemRefToNeura/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
include_directories(${CMAKE_CURRENT_BINARY_DIR})

add_mlir_conversion_library(MLIRNeuraMemRefToNeuraPass
MemRefToNeuraPass.cpp

DEPENDS
MLIRConversionIncGen

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRFuncDialect
MLIRLLVMDialect
MLIRIR
MLIRPass
MLIRTransforms
MLIRNeura
MLIRSupport
)
Loading
Loading