Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/NeuraDialect/NeuraOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,13 @@ def Neura_AddOp : Op<NeuraDialect, "add"> {
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
// let traits = [Pure];
}

// Defines a move operation for data communication.
def Neura_MovOp : Op<NeuraDialect, "mov"> {
let summary = "Move operation";
let opName = "mov";
let arguments = (ins AnyType:$lhs);
let results = (outs AnyType:$result);
let assemblyFormat = "$lhs attr-dict `:` type($lhs) `->` type($result)";
// let traits = [Pure];
}
13 changes: 13 additions & 0 deletions include/Transforms/InsertMovPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef NEURA_TRANSFORMS_INSERTMOVPASS_H
#define NEURA_TRANSFORMS_INSERTMOVPASS_H

#include "mlir/Pass/Pass.h"

namespace mlir {
namespace neura {
std::unique_ptr<mlir::Pass> createInsertMovPass();
} // namespace neura
} // namespace mlir

#endif // NEURA_TRANSFORMS_INSERTMOVPASS_H

1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(NeuraDialect)
add_subdirectory(Conversion/ArithToNeura)
add_subdirectory(Transforms)
7 changes: 0 additions & 7 deletions lib/Conversion/ArithToNeura/ArithToNeura.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ struct ArithAddFOpLowering : public OpRewritePattern<arith::AddFOp> {

LogicalResult matchAndRewrite(arith::AddFOp op,
PatternRewriter &rewriter) const override {
llvm::errs() << "[cheng] step into matchAndRewriter()";
rewriter.replaceOpWithNewOp<neura::AddOp>(op, op.getType(), op.getLhs(), op.getRhs());

llvm::errs() << "[cheng] Matched arith.addf: ";
// op.dump();

return success();
}
};
Expand All @@ -40,8 +35,6 @@ struct LowerArithToNeuraPass
}

void runOnOperation() override {
// getContext().loadDialect<mlir::neura::NeuraDialect>();

RewritePatternSet patterns(&getContext());
llvm::errs() << "[cheng] check runOnOperation: ";
getOperation().dump();
Expand Down
11 changes: 11 additions & 0 deletions lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
add_mlir_library(NeuraTransforms
InsertMovPass.cpp

LINK_LIBS PUBLIC
MLIRIR
MLIRFuncDialect
MLIRPass
MLIRSupport
MLIRTransformUtils
NeuraDialect
)
82 changes: 82 additions & 0 deletions lib/Transforms/InsertMovPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include "NeuraDialect/NeuraDialect.h"
#include "NeuraDialect/NeuraOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

using namespace mlir;

namespace {
struct InsertMovForNeuraOps : public RewritePattern {
InsertMovForNeuraOps(MLIRContext *context)
: RewritePattern(/*matchAnyOpTypeTag=*/MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

// using RewritePattern::RewritePattern;

LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
if (op->getDialect()->getNamespace() != "neura" ||
isa<neura::MovOp>(op)) {
return failure();
}

llvm::errs() << "[cheng] step into matching and rewrite";
Location loc = op->getLoc();

// Wrap operands in mov
SmallVector<Value> newOperands;
for (Value operand : op->getOperands()) {
auto mov = rewriter.create<neura::MovOp>(loc, operand.getType(), operand);
newOperands.push_back(mov);
}

// Clone op with new operands
OperationState state(loc, op->getName());
state.addOperands(newOperands);
state.addTypes(op->getResultTypes());
state.addAttributes(op->getAttrs());
Operation *newOp = rewriter.create(state);

// Wrap each result in a mov
SmallVector<Value> newResults;
for (Value result : newOp->getResults()) {
auto mov = rewriter.create<neura::MovOp>(loc, result.getType(), result);
newResults.push_back(mov);
}

rewriter.replaceOp(op, newResults);
return success();
}
};

struct InsertMovPass
: public PassWrapper<InsertMovPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertMovPass)

StringRef getArgument() const override { return "insert-mov"; }
StringRef getDescription() const override {
return "Insert neura.mov before and after all neura dialect operations.";
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mlir::neura::NeuraDialect>();
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<InsertMovForNeuraOps>(&getContext());
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};
} // namespace

namespace mlir {
namespace neura {

std::unique_ptr<Pass> createInsertMovPass() {
return std::make_unique<InsertMovPass>();
}

} // namespace neura
} // namespace mlir
9 changes: 9 additions & 0 deletions test/neura/add.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RN: mlir-neura-opt --lower-arith-to-neura --insert-mov %s | FileCheck %s
// RUN: mlir-neura-opt --insert-mov %s | FileCheck %s

func.func @test(%a: f32) -> f32 {
%b = arith.constant 2.0 : f32
%res = arith.addf %a, %b : f32
// CHECK: neura.add
return %res : f32
}
1 change: 1 addition & 0 deletions tools/mlir-neura-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ target_link_libraries(mlir-neura-opt PRIVATE
MLIRFuncDialect # Builtin dialect required by custom dialect
MLIRArithDialect
NeuraArithToNeura
NeuraTransforms
)

# Includes directories.
Expand Down
4 changes: 4 additions & 0 deletions tools/mlir-neura-opt/mlir-neura-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "Conversion/ArithToNeura/ArithToNeura.h"
#include "NeuraDialect/NeuraDialect.h"
#include "Transforms/InsertMovPass.h"

int main(int argc, char **argv) {
// Registers MLIR dialects.
Expand All @@ -18,6 +19,9 @@ int main(int argc, char **argv) {
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return mlir::neura::createLowerArithToNeuraPass();
});
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return mlir::neura::createInsertMovPass();
});

// Runs the MLIR optimizer.
return mlir::asMainReturnCode(
Expand Down