|
| 1 | +#include "Conversion/LlvmToNeura/LlvmToNeura.h" |
| 2 | +#include "NeuraDialect/NeuraDialect.h" |
| 3 | +#include "NeuraDialect/NeuraOps.h" |
| 4 | +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
| 5 | +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 6 | +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| 7 | +#include "mlir/Dialect/Func/IR/FuncOps.h" |
| 8 | +#include "mlir/IR/PatternMatch.h" |
| 9 | +#include "mlir/Pass/Pass.h" |
| 10 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 11 | + |
| 12 | +using namespace mlir; |
| 13 | + |
| 14 | +namespace { |
| 15 | +struct LlvmAddFOpLowering : public OpRewritePattern<mlir::LLVM::FAddOp> { |
| 16 | + using OpRewritePattern::OpRewritePattern; |
| 17 | + |
| 18 | + LogicalResult matchAndRewrite(mlir::LLVM::FAddOp op, |
| 19 | + PatternRewriter &rewriter) const override { |
| 20 | + rewriter.replaceOpWithNewOp<neura::AddOp>(op, op.getType(), op.getLhs(), op.getRhs()); |
| 21 | + return success(); |
| 22 | + } |
| 23 | +}; |
| 24 | + |
| 25 | +struct LowerLlvmToNeuraPass |
| 26 | + : public PassWrapper<LowerLlvmToNeuraPass, OperationPass<func::FuncOp>> { |
| 27 | + |
| 28 | + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerLlvmToNeuraPass) |
| 29 | + |
| 30 | + StringRef getArgument() const override { return "lower-llvm-to-neura"; } |
| 31 | + StringRef getDescription() const override { |
| 32 | + return "Lower LLVM operations to Neura dialect operations"; |
| 33 | + } |
| 34 | + |
| 35 | + void getDependentDialects(DialectRegistry ®istry) const override { |
| 36 | + registry.insert<mlir::neura::NeuraDialect>(); |
| 37 | + } |
| 38 | + |
| 39 | + void runOnOperation() override { |
| 40 | + RewritePatternSet patterns(&getContext()); |
| 41 | + patterns.add<LlvmAddFOpLowering>(&getContext()); |
| 42 | + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { |
| 43 | + signalPassFailure(); |
| 44 | + } |
| 45 | + } |
| 46 | +}; |
| 47 | +} // namespace |
| 48 | + |
| 49 | +std::unique_ptr<Pass> mlir::neura::createLowerLlvmToNeuraPass() { |
| 50 | + return std::make_unique<LowerLlvmToNeuraPass>(); |
| 51 | +} |
0 commit comments