Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions include/NeuraDialect/NeuraOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ def Neura_FMinOp : Op<NeuraDialect, "fmin"> {
}

// Defines a bitwise OR operation.
def Neura_AndOp : Op<NeuraDialect, "and"> {
let summary = "Bitwise AND operation";
let arguments = (ins AnyType:$lhs, AnyType:$rhs);
let results = (outs AnyType:$result);
// let assemblyFormat = "$lhs `,` $rhs `,` attr-dict `:` type($result)";
let traits = [SameOperandsAndResultElementType];
}

def Neura_OrOp : Op<NeuraDialect, "or"> {
let summary = "Bitwise OR operation";
let arguments = (ins AnyType:$lhs, AnyType:$rhs);
Expand Down Expand Up @@ -245,9 +253,9 @@ def Neura_Br : Op<NeuraDialect, "br", [Terminator]> {
}

def Neura_SelOp : Op<NeuraDialect, "sel"> {
let arguments = (ins AnyType:$ifTrue, AnyType:$ifFalse, AnyType:$cond);
let arguments = (ins AnyType:$cond, AnyType:$ifTrue, AnyType:$ifFalse);
let results = (outs AnyType:$result);
// let assemblyFormat = "$ifTrue `,` $ifFalse `,` $cond attr-dict `:` type($ifTrue)";
// let assemblyFormat = "$cond `,` $ifTrue `,` $ifFalse attr-dict `:` type($result)";
}

def Neura_NotOp : Op<NeuraDialect, "not"> {
Expand Down
6 changes: 3 additions & 3 deletions lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ struct ArithSelectToNeuraSel : public OpRewritePattern<mlir::arith::SelectOp> {
Value false_value = op.getFalseValue();
Type result_type = op.getType();

// Converts arith SelectOp to Neura SelOp.
rewriter.replaceOpWithNewOp<neura::SelOp>(op, result_type, true_value,
false_value, condition);
// Converts arith SelectOp to Neura SelOp with consistent order: (cond, ifTrue, ifFalse).
rewriter.replaceOpWithNewOp<neura::SelOp>(op, result_type, condition,
true_value, false_value);
return success();
}
};
Expand Down
30 changes: 30 additions & 0 deletions lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ struct LlvmFSubToNeuraFSub : public OpRewritePattern<mlir::LLVM::FSubOp> {
}
};

struct LlvmAndToNeuraAnd : public OpRewritePattern<mlir::LLVM::AndOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::LLVM::AndOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<neura::AndOp>(op, op.getType(), op.getLhs(),
op.getRhs());
return success();
}
};

struct LlvmOrToNeuraOr : public OpRewritePattern<mlir::LLVM::OrOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -243,6 +254,23 @@ struct LlvmFPToSIToNeuraCast : public OpRewritePattern<mlir::LLVM::FPToSIOp> {
}
};

struct LlvmSelectToNeuraSel : public OpRewritePattern<LLVM::SelectOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(LLVM::SelectOp op,
PatternRewriter &rewriter) const override {
Value cond = op.getCondition();
Value true_value = op.getTrueValue();
Value false_value = op.getFalseValue();
Type result_type = op.getType();

// neura.sel now follows the same order as llvm.select: (cond, ifTrue, ifFalse)
rewriter.replaceOpWithNewOp<neura::SelOp>(op, result_type,
cond, true_value, false_value);
return success();
}
};

struct LlvmFMulAddToNeuraFMulFAdd : public OpRewritePattern<mlir::LLVM::FMulAddOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -723,6 +751,7 @@ struct LowerLlvmToNeuraPass
patterns.insert<LlvmVectorReduceAddToNeuraVectorReduceAdd>(&getContext());
// Scalar operations
patterns.add<LlvmAddToNeuraAdd>(&getContext());
patterns.add<LlvmAndToNeuraAnd>(&getContext());
patterns.add<LlvmOrToNeuraOr>(&getContext());
patterns.add<LlvmFAddToNeuraFAdd>(&getContext());
patterns.add<LlvmFMulToNeuraFMul>(&getContext());
Expand Down Expand Up @@ -752,6 +781,7 @@ struct LowerLlvmToNeuraPass
patterns.add<LlvmFDivToNeuraFDiv>(&getContext());
patterns.add<LlvmFPToSIToNeuraCast>(&getContext());
patterns.add<LlvmFMulAddToNeuraFMulFAdd>(&getContext());
patterns.add<LlvmSelectToNeuraSel>(&getContext());

FrozenRewritePatternSet frozen(std::move(patterns));

Expand Down
Loading