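// FusePatternsPass.cpp: greedily fuses chained Neura floating-point ops
// (fadd+fadd and fmul+fadd) into single fused ops via rewrite patterns.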
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "NeuraDialect/NeuraOps.h"
using namespace mlir;
namespace {
// Fuses a chain of two fadds into a single three-input fadd_fadd:
//   fadd(fadd(a, b), c) or fadd(c, fadd(a, b))  ==>  fadd_fadd(a, b, c)
struct FuseFAddFAddPattern : public OpRewritePattern<neura::FAddOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(neura::FAddOp second,
                                PatternRewriter &rewriter) const override {
    Value lhs = second.getLhs();
    Value rhs = second.getRhs();
    auto lhs_op = lhs.getDefiningOp<neura::FAddOp>();
    auto rhs_op = rhs.getDefiningOp<neura::FAddOp>();

    neura::FAddOp first = nullptr;
    Value tail;
    if (lhs_op) {
      // Case 1: LHS is another fadd.
      first = lhs_op;
      tail = rhs;
    } else if (rhs_op) {
      // Case 2: RHS is another fadd.
      first = rhs_op;
      tail = lhs;
    }
    if (!first)
      return failure();

    // Only fuses if the first fadd is not reused elsewhere.
    if (!first->hasOneUse())
      return failure();

    Location loc = second.getLoc();
    Type type = second.getType();
    auto fused = rewriter.create<neura::FAddFAddOp>(
        loc, type, first.getLhs(), first.getRhs(), tail);
    rewriter.replaceOp(second, fused.getResult());
    // The first fadd is now dead: its only use was `second`.
    rewriter.eraseOp(first);
    return success();
  }
};
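
// Illustrative effect of FuseFAddFAddPattern, shown in generic op syntax (the
// dialect's pretty-printed form, if it defines one, may differ):
//
//   %t = "neura.fadd"(%a, %b) : (f32, f32) -> f32
//   %r = "neura.fadd"(%t, %c) : (f32, f32) -> f32
// becomes
//   %r = "neura.fadd_fadd"(%a, %b, %c) : (f32, f32, f32) -> f32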

// Fuses a multiply feeding an add into a single fused multiply-add:
//   fadd(fmul(a, b), c) or fadd(c, fmul(a, b))  ==>  fmul_fadd(a, b, c)
struct FuseFMulFAddPattern : public OpRewritePattern<neura::FAddOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(neura::FAddOp add,
                                PatternRewriter &rewriter) const override {
    auto lhs_op = add.getLhs().getDefiningOp<neura::FMulOp>();
    auto rhs_op = add.getRhs().getDefiningOp<neura::FMulOp>();

    neura::FMulOp fmul = nullptr;
    Value other;
    if (lhs_op) {
      // Case 1: fmul is on the LHS.
      fmul = lhs_op;
      other = add.getRhs();
    } else if (rhs_op) {
      // Case 2: fmul is on the RHS.
      fmul = rhs_op;
      other = add.getLhs();
    }
    if (!fmul)
      return failure();

    // Only fuses if the fmul has a single use; otherwise it must be kept
    // alive for its other consumers.
    if (!fmul->hasOneUse())
      return failure();

    Location loc = add.getLoc();
    Type type = add.getType();
    auto fused = rewriter.create<neura::FMulFAddOp>(
        loc, type, fmul.getLhs(), fmul.getRhs(), other);
    rewriter.replaceOp(add, fused.getResult());
    // The fmul is now dead: its only use was `add`.
    rewriter.eraseOp(fmul);
    return success();
  }
};
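
// Illustrative effect of FuseFMulFAddPattern, in the same generic syntax:
//
//   %p = "neura.fmul"(%a, %b) : (f32, f32) -> f32
//   %r = "neura.fadd"(%p, %c) : (f32, f32) -> f32
// becomes
//   %r = "neura.fmul_fadd"(%a, %b, %c) : (f32, f32, f32) -> f32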

struct FusePatternsPass
    : public PassWrapper<FusePatternsPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FusePatternsPass)

  StringRef getArgument() const override { return "fuse-patterns"; }
  StringRef getDescription() const override {
    return "Apply Neura fusion patterns.";
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    // The higher benefit makes fmul+fadd fusion preferred over fadd+fadd
    // whenever both patterns match the same fadd.
    patterns.add<FuseFAddFAddPattern>(&getContext(), /*benefit=*/2);
    patterns.add<FuseFMulFAddPattern>(&getContext(), /*benefit=*/3);
    FrozenRewritePatternSet frozen(std::move(patterns));

    ModuleOp module_op = getOperation();
    // Applies the patterns to every region inside the module, regardless of
    // the enclosing function kind (e.g., func.func or llvm.func).
    module_op.walk([&](Operation *op) {
      for (Region &region : op->getRegions()) {
        if (failed(applyPatternsAndFoldGreedily(region, frozen))) {
          signalPassFailure();
        }
      }
    });
  }
};

} // namespace

namespace mlir::neura {

std::unique_ptr<Pass> createFusePatternsPass() {
  return std::make_unique<FusePatternsPass>();
}

} // namespace mlir::neura
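
// Minimal registration/usage sketch. `mlir::registerPass` taking a pass
// allocator lambda is standard MLIR API; the `neura-opt` driver name is a
// hypothetical placeholder for whatever opt-style tool this repo builds:
//
//   mlir::registerPass([] { return mlir::neura::createFusePatternsPass(); });
//
// Once registered, the pass is selected on the command line by its argument
// string, e.g.: neura-opt --fuse-patterns input.mlir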