From 380bb441e7465052762e083da5d5c230205684f0 Mon Sep 17 00:00:00 2001 From: tancheng Date: Tue, 13 May 2025 16:43:11 +0000 Subject: [PATCH 1/2] [feature] Declare and define mov --- include/NeuraDialect/NeuraOps.td | 10 +++ include/Transforms/InsertMovPass.h | 13 ++++ lib/CMakeLists.txt | 1 + lib/Conversion/ArithToNeura/ArithToNeura.cpp | 7 -- lib/Transforms/CMakeLists.txt | 11 +++ lib/Transforms/InsertMovPass.cpp | 82 ++++++++++++++++++++ test/neura/add.mlir | 9 +++ tools/mlir-neura-opt/CMakeLists.txt | 1 + tools/mlir-neura-opt/mlir-neura-opt.cpp | 4 + 9 files changed, 131 insertions(+), 7 deletions(-) create mode 100644 include/Transforms/InsertMovPass.h create mode 100644 lib/Transforms/CMakeLists.txt create mode 100644 lib/Transforms/InsertMovPass.cpp create mode 100644 test/neura/add.mlir diff --git a/include/NeuraDialect/NeuraOps.td b/include/NeuraDialect/NeuraOps.td index 0e90ae17..ac83aa1e 100644 --- a/include/NeuraDialect/NeuraOps.td +++ b/include/NeuraDialect/NeuraOps.td @@ -11,3 +11,13 @@ def Neura_AddOp : Op { let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; // let traits = [Pure]; } + +// Defines a move operation for data communication. 
+def Neura_MovOp : Op { + let summary = "Move operation"; + let opName = "mov"; + let arguments = (ins AnyType:$lhs); + let results = (outs AnyType:$result); + let assemblyFormat = "$lhs attr-dict `:` type($lhs) `->` type($result)"; + // let traits = [Pure]; +} diff --git a/include/Transforms/InsertMovPass.h b/include/Transforms/InsertMovPass.h new file mode 100644 index 00000000..06a0befa --- /dev/null +++ b/include/Transforms/InsertMovPass.h @@ -0,0 +1,13 @@ +#ifndef NEURA_TRANSFORMS_INSERTMOVPASS_H +#define NEURA_TRANSFORMS_INSERTMOVPASS_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace neura { + std::unique_ptr createInsertMovPass(); +} // namespace neura +} // namespace mlir + +#endif // NEURA_TRANSFORMS_INSERTMOVPASS_H + diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 30d5c055..5487d736 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(NeuraDialect) add_subdirectory(Conversion/ArithToNeura) +add_subdirectory(Transforms) diff --git a/lib/Conversion/ArithToNeura/ArithToNeura.cpp b/lib/Conversion/ArithToNeura/ArithToNeura.cpp index cfb8c216..319ec5fe 100644 --- a/lib/Conversion/ArithToNeura/ArithToNeura.cpp +++ b/lib/Conversion/ArithToNeura/ArithToNeura.cpp @@ -15,12 +15,7 @@ struct ArithAddFOpLowering : public OpRewritePattern { LogicalResult matchAndRewrite(arith::AddFOp op, PatternRewriter &rewriter) const override { -llvm::errs() << "[cheng] step into matchAndRewriter()"; rewriter.replaceOpWithNewOp(op, op.getType(), op.getLhs(), op.getRhs()); - -llvm::errs() << "[cheng] Matched arith.addf: "; -// op.dump(); - return success(); } }; @@ -40,8 +35,6 @@ struct LowerArithToNeuraPass } void runOnOperation() override { - // getContext().loadDialect(); - RewritePatternSet patterns(&getContext()); llvm::errs() << "[cheng] check runOnOperation: "; getOperation().dump(); diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt new file mode 100644 index 00000000..b6448528 --- /dev/null 
+++ b/lib/Transforms/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_library(NeuraTransforms + InsertMovPass.cpp + + LINK_LIBS PUBLIC + MLIRIR + MLIRFuncDialect + MLIRPass + MLIRSupport + MLIRTransformUtils + NeuraDialect +) diff --git a/lib/Transforms/InsertMovPass.cpp b/lib/Transforms/InsertMovPass.cpp new file mode 100644 index 00000000..467e3387 --- /dev/null +++ b/lib/Transforms/InsertMovPass.cpp @@ -0,0 +1,82 @@ +#include "NeuraDialect/NeuraDialect.h" +#include "NeuraDialect/NeuraOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" + +using namespace mlir; + +namespace { +struct InsertMovForNeuraOps : public RewritePattern { + InsertMovForNeuraOps(MLIRContext *context) + : RewritePattern(/*matchAnyOpTypeTag=*/MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + // using RewritePattern::RewritePattern; + + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { + if (op->getDialect()->getNamespace() != "neura" || + isa(op)) { + return failure(); + } + + llvm::errs() << "[cheng] step into matching and rewrite"; + Location loc = op->getLoc(); + + // Wrap operands in mov + SmallVector newOperands; + for (Value operand : op->getOperands()) { + auto mov = rewriter.create(loc, operand.getType(), operand); + newOperands.push_back(mov); + } + + // Clone op with new operands + OperationState state(loc, op->getName()); + state.addOperands(newOperands); + state.addTypes(op->getResultTypes()); + state.addAttributes(op->getAttrs()); + Operation *newOp = rewriter.create(state); + + // Wrap each result in a mov + SmallVector newResults; + for (Value result : newOp->getResults()) { + auto mov = rewriter.create(loc, result.getType(), result); + newResults.push_back(mov); + } + + rewriter.replaceOp(op, newResults); + return success(); + } +}; + +struct InsertMovPass + : public PassWrapper> { + 
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertMovPass) + + StringRef getArgument() const override { return "insert-mov"; } + StringRef getDescription() const override { + return "Insert neura.mov before and after all neura dialect operations."; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +namespace mlir { +namespace neura { + +std::unique_ptr createInsertMovPass() { + return std::make_unique(); +} + +} // namespace neura +} // namespace mlir diff --git a/test/neura/add.mlir b/test/neura/add.mlir new file mode 100644 index 00000000..a4564240 --- /dev/null +++ b/test/neura/add.mlir @@ -0,0 +1,9 @@ +// RN: mlir-neura-opt --lower-arith-to-neura --insert-mov %s | FileCheck %s +// RUN: mlir-neura-opt --insert-mov %s | FileCheck %s + +func.func @test(%a: f32) -> f32 { + %b = arith.constant 2.0 : f32 + %res = arith.addf %a, %b : f32 + // CHECK: neura.add + return %res : f32 +} diff --git a/tools/mlir-neura-opt/CMakeLists.txt b/tools/mlir-neura-opt/CMakeLists.txt index 8f617649..08fa6db8 100644 --- a/tools/mlir-neura-opt/CMakeLists.txt +++ b/tools/mlir-neura-opt/CMakeLists.txt @@ -15,6 +15,7 @@ target_link_libraries(mlir-neura-opt PRIVATE MLIRFuncDialect # Builtin dialect required by custom dialect MLIRArithDialect NeuraArithToNeura + NeuraTransforms ) # Includes directories. 
diff --git a/tools/mlir-neura-opt/mlir-neura-opt.cpp b/tools/mlir-neura-opt/mlir-neura-opt.cpp index 85df2f1e..b10cba16 100644 --- a/tools/mlir-neura-opt/mlir-neura-opt.cpp +++ b/tools/mlir-neura-opt/mlir-neura-opt.cpp @@ -7,6 +7,7 @@ #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "Conversion/ArithToNeura/ArithToNeura.h" #include "NeuraDialect/NeuraDialect.h" +#include "Transforms/InsertMovPass.h" int main(int argc, char **argv) { // Registers MLIR dialects. @@ -18,6 +19,9 @@ int main(int argc, char **argv) { mlir::registerPass([]() -> std::unique_ptr { return mlir::neura::createLowerArithToNeuraPass(); }); + mlir::registerPass([]() -> std::unique_ptr { + return mlir::neura::createInsertMovPass(); + }); // Runs the MLIR optimizer. return mlir::asMainReturnCode( From 35992b2f3c8c1f017e7f09a96a802796cb22e6db Mon Sep 17 00:00:00 2001 From: tancheng Date: Wed, 14 May 2025 05:42:49 +0000 Subject: [PATCH 2/2] [update] Avoid adding mov on op's result & fix hanging --- lib/Transforms/InsertMovPass.cpp | 32 ++++++++++++++++++-------------- test/neura/add.mlir | 5 +++-- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/lib/Transforms/InsertMovPass.cpp b/lib/Transforms/InsertMovPass.cpp index 467e3387..0ff05656 100644 --- a/lib/Transforms/InsertMovPass.cpp +++ b/lib/Transforms/InsertMovPass.cpp @@ -12,39 +12,43 @@ struct InsertMovForNeuraOps : public RewritePattern { InsertMovForNeuraOps(MLIRContext *context) : RewritePattern(/*matchAnyOpTypeTag=*/MatchAnyOpTypeTag(), /*benefit=*/1, context) {} - // using RewritePattern::RewritePattern; - LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (op->getDialect()->getNamespace() != "neura" || isa(op)) { return failure(); } - llvm::errs() << "[cheng] step into matching and rewrite"; + // Skips ops whose operands have already been wrapped in mov. 
+ bool allInputsAreMov = llvm::all_of(op->getOperands(), [](Value v) { + return isa_and_nonnull(v.getDefiningOp()); + }); + if (allInputsAreMov) { + return failure(); + } + + // Makes sure none of the operands has been processed yet. + bool hasAnyMovInput = llvm::any_of(op->getOperands(), [](Value v) { + return isa_and_nonnull(v.getDefiningOp()); + }); + assert(!hasAnyMovInput && "Unexpected: operand already wrapped in neura.mov"); + Location loc = op->getLoc(); - // Wrap operands in mov + // Wraps operands in mov. SmallVector newOperands; for (Value operand : op->getOperands()) { auto mov = rewriter.create(loc, operand.getType(), operand); newOperands.push_back(mov); } - // Clone op with new operands + // Clones op with new operands. OperationState state(loc, op->getName()); state.addOperands(newOperands); state.addTypes(op->getResultTypes()); state.addAttributes(op->getAttrs()); - Operation *newOp = rewriter.create(state); - - // Wrap each result in a mov - SmallVector newResults; - for (Value result : newOp->getResults()) { - auto mov = rewriter.create(loc, result.getType(), result); - newResults.push_back(mov); - } - rewriter.replaceOp(op, newResults); + Operation *newOp = rewriter.create(state); + rewriter.replaceOp(op, newOp->getResults()); return success(); } }; diff --git a/test/neura/add.mlir b/test/neura/add.mlir index a4564240..df477461 100644 --- a/test/neura/add.mlir +++ b/test/neura/add.mlir @@ -1,9 +1,10 @@ -// RN: mlir-neura-opt --lower-arith-to-neura --insert-mov %s | FileCheck %s -// RUN: mlir-neura-opt --insert-mov %s | FileCheck %s +// RUN: mlir-neura-opt --lower-arith-to-neura --insert-mov %s | FileCheck %s func.func @test(%a: f32) -> f32 { %b = arith.constant 2.0 : f32 %res = arith.addf %a, %b : f32 + // CHECK: neura.mov %arg0 : f32 -> f32 + // CHECK: neura.mov %cst : f32 -> f32 // CHECK: neura.add return %res : f32 }