Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/NeuraDialect/NeuraOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ def Neura_FMinOp : Op<NeuraDialect, "fmin"> {
}

// Defines a bitwise OR operation.
def Neura_AndOp : Op<NeuraDialect, "and"> {
let summary = "Bitwise AND operation";
let arguments = (ins AnyType:$lhs, AnyType:$rhs);
let results = (outs AnyType:$result);
// let assemblyFormat = "$lhs `,` $rhs `,` attr-dict `:` type($result)";
let traits = [SameOperandsAndResultElementType];
}

def Neura_OrOp : Op<NeuraDialect, "or"> {
let summary = "Bitwise OR operation";
let arguments = (ins AnyType:$lhs, AnyType:$rhs);
Expand Down
30 changes: 30 additions & 0 deletions lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ struct LlvmFSubToNeuraFSub : public OpRewritePattern<mlir::LLVM::FSubOp> {
}
};

struct LlvmAndToNeuraAnd : public OpRewritePattern<mlir::LLVM::AndOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::LLVM::AndOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<neura::AndOp>(op, op.getType(), op.getLhs(),
op.getRhs());
return success();
}
};

struct LlvmOrToNeuraOr : public OpRewritePattern<mlir::LLVM::OrOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -243,6 +254,23 @@ struct LlvmFPToSIToNeuraCast : public OpRewritePattern<mlir::LLVM::FPToSIOp> {
}
};

struct LlvmSelectToNeuraSel : public OpRewritePattern<LLVM::SelectOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(LLVM::SelectOp op,
PatternRewriter &rewriter) const override {
Value cond = op.getCondition();
Value true_value = op.getTrueValue();
Value false_value = op.getFalseValue();
Type result_type = op.getType();

// Note: neura.sel has different argument order: (ifTrue, ifFalse, cond)
rewriter.replaceOpWithNewOp<neura::SelOp>(op, result_type,
true_value, false_value, cond);
return success();
}
};

struct LlvmFMulAddToNeuraFMulFAdd : public OpRewritePattern<mlir::LLVM::FMulAddOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -723,6 +751,7 @@ struct LowerLlvmToNeuraPass
patterns.insert<LlvmVectorReduceAddToNeuraVectorReduceAdd>(&getContext());
// Scalar operations
patterns.add<LlvmAddToNeuraAdd>(&getContext());
patterns.add<LlvmAndToNeuraAnd>(&getContext());
patterns.add<LlvmOrToNeuraOr>(&getContext());
patterns.add<LlvmFAddToNeuraFAdd>(&getContext());
patterns.add<LlvmFMulToNeuraFMul>(&getContext());
Expand Down Expand Up @@ -752,6 +781,7 @@ struct LowerLlvmToNeuraPass
patterns.add<LlvmFDivToNeuraFDiv>(&getContext());
patterns.add<LlvmFPToSIToNeuraCast>(&getContext());
patterns.add<LlvmFMulAddToNeuraFMulFAdd>(&getContext());
patterns.add<LlvmSelectToNeuraSel>(&getContext());

FrozenRewritePatternSet frozen(std::move(patterns));

Expand Down
Loading