Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/NeuraDialect/NeuraOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,13 @@ def Neura_AddOp : Op<NeuraDialect, "add"> {
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
// let traits = [Pure];
}

// Defines a move operation for data communication.
def Neura_MovOp : Op<NeuraDialect, "mov"> {
let summary = "Move operation";
let opName = "mov";
let arguments = (ins AnyType:$lhs);
let results = (outs AnyType:$result);
let assemblyFormat = "$lhs attr-dict `:` type($lhs) `->` type($result)";
// let traits = [Pure];
}
13 changes: 13 additions & 0 deletions include/Transforms/InsertMovPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef NEURA_TRANSFORMS_INSERTMOVPASS_H
#define NEURA_TRANSFORMS_INSERTMOVPASS_H

#include "mlir/Pass/Pass.h"

namespace mlir {
namespace neura {
std::unique_ptr<mlir::Pass> createInsertMovPass();
} // namespace neura
} // namespace mlir

#endif // NEURA_TRANSFORMS_INSERTMOVPASS_H

1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
add_subdirectory(NeuraDialect)
add_subdirectory(Conversion/ArithToNeura)
add_subdirectory(Transforms)
7 changes: 0 additions & 7 deletions lib/Conversion/ArithToNeura/ArithToNeura.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ struct ArithAddFOpLowering : public OpRewritePattern<arith::AddFOp> {

LogicalResult matchAndRewrite(arith::AddFOp op,
PatternRewriter &rewriter) const override {
llvm::errs() << "[cheng] step into matchAndRewriter()";
rewriter.replaceOpWithNewOp<neura::AddOp>(op, op.getType(), op.getLhs(), op.getRhs());

llvm::errs() << "[cheng] Matched arith.addf: ";
// op.dump();

return success();
}
};
Expand All @@ -40,8 +35,6 @@ struct LowerArithToNeuraPass
}

void runOnOperation() override {
// getContext().loadDialect<mlir::neura::NeuraDialect>();

RewritePatternSet patterns(&getContext());
llvm::errs() << "[cheng] check runOnOperation: ";
getOperation().dump();
Expand Down
11 changes: 11 additions & 0 deletions lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
add_mlir_library(NeuraTransforms
InsertMovPass.cpp

LINK_LIBS PUBLIC
MLIRIR
MLIRFuncDialect
MLIRPass
MLIRSupport
MLIRTransformUtils
NeuraDialect
)
86 changes: 86 additions & 0 deletions lib/Transforms/InsertMovPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#include "NeuraDialect/NeuraDialect.h"
#include "NeuraDialect/NeuraOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

using namespace mlir;

namespace {
struct InsertMovForNeuraOps : public RewritePattern {
InsertMovForNeuraOps(MLIRContext *context)
: RewritePattern(/*matchAnyOpTypeTag=*/MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
if (op->getDialect()->getNamespace() != "neura" ||
isa<neura::MovOp>(op)) {
return failure();
}

// Skips ops that already being inserted mov on the operands.
bool allInputsAreMov = llvm::all_of(op->getOperands(), [](Value v) {
return isa_and_nonnull<neura::MovOp>(v.getDefiningOp());
});
if (allInputsAreMov) {
return failure();
}

// Makes sure none of the operand has being processed.
bool hasAnyMovInput = llvm::any_of(op->getOperands(), [](Value v) {
return isa_and_nonnull<neura::MovOp>(v.getDefiningOp());
});
assert(!hasAnyMovInput && "Unexpected: operand already wrapped in neura.mov");

Location loc = op->getLoc();

// Wraps operands in mov.
SmallVector<Value> newOperands;
for (Value operand : op->getOperands()) {
auto mov = rewriter.create<neura::MovOp>(loc, operand.getType(), operand);
newOperands.push_back(mov);
}

// Clones op with new operands.
OperationState state(loc, op->getName());
state.addOperands(newOperands);
state.addTypes(op->getResultTypes());
state.addAttributes(op->getAttrs());

Operation *newOp = rewriter.create(state);
rewriter.replaceOp(op, newOp->getResults());
return success();
}
};

struct InsertMovPass
: public PassWrapper<InsertMovPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertMovPass)

StringRef getArgument() const override { return "insert-mov"; }
StringRef getDescription() const override {
return "Insert neura.mov before and after all neura dialect operations.";
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mlir::neura::NeuraDialect>();
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<InsertMovForNeuraOps>(&getContext());
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};
} // namespace

namespace mlir {
namespace neura {

std::unique_ptr<Pass> createInsertMovPass() {
return std::make_unique<InsertMovPass>();
}

} // namespace neura
} // namespace mlir
10 changes: 10 additions & 0 deletions test/neura/add.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: mlir-neura-opt --lower-arith-to-neura --insert-mov %s | FileCheck %s

func.func @test(%a: f32) -> f32 {
%b = arith.constant 2.0 : f32
%res = arith.addf %a, %b : f32
// CHECK: neura.mov %arg0 : f32 -> f32
// CHECK: neura.mov %cst : f32 -> f32
// CHECK: neura.add
return %res : f32
}
1 change: 1 addition & 0 deletions tools/mlir-neura-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ target_link_libraries(mlir-neura-opt PRIVATE
MLIRFuncDialect # Builtin dialect required by custom dialect
MLIRArithDialect
NeuraArithToNeura
NeuraTransforms
)

# Includes directories.
Expand Down
4 changes: 4 additions & 0 deletions tools/mlir-neura-opt/mlir-neura-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "Conversion/ArithToNeura/ArithToNeura.h"
#include "NeuraDialect/NeuraDialect.h"
#include "Transforms/InsertMovPass.h"

int main(int argc, char **argv) {
// Registers MLIR dialects.
Expand All @@ -18,6 +19,9 @@ int main(int argc, char **argv) {
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return mlir::neura::createLowerArithToNeuraPass();
});
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return mlir::neura::createInsertMovPass();
});

// Runs the MLIR optimizer.
return mlir::asMainReturnCode(
Expand Down