@@ -8,40 +8,116 @@ using namespace mlir;
namespace {
11- struct FuseFAddFAddPattern : public RewritePattern {
12- FuseFAddFAddPattern (MLIRContext *ctx)
13- : RewritePattern(" neura.fadd" , /* benefit=*/ 1 , ctx) {}
11+ struct FuseFAddFAddPattern : public OpRewritePattern <neura::FAddOp> {
12+ using OpRewritePattern::OpRewritePattern;
1413
15- LogicalResult matchAndRewrite (Operation *op, PatternRewriter &rewriter) const override {
16- auto first = dyn_cast<neura::FAddOp>(op);
17- if (!first || !first->hasOneUse ()) return failure ();
14+ LogicalResult matchAndRewrite (neura::FAddOp second,
15+ PatternRewriter &rewriter) const override {
16+ Value lhs = second.getLhs ();
17+ Value rhs = second.getRhs ();
1818
19- auto user = dyn_cast <neura::FAddOp>(*first-> getUsers (). begin () );
20- if (!user) return failure ();
19+ auto lhs_op = lhs. getDefiningOp <neura::FAddOp>();
20+ auto rhs_op = rhs. getDefiningOp <neura::FAddOp> ();
2121
22- Location loc = user. getLoc () ;
23- Type type = user. getType () ;
22+ neura::FAddOp first = nullptr ;
23+ Value tail ;
2424
25- auto fused = rewriter.create <neura::FAddFAddOp>(loc, type,
26- first.getLhs (), first.getRhs (), user.getRhs ());
25+ // Case 1: LHS is another fadd.
26+ if (lhs_op && second.getRhs ()) {
27+ first = lhs_op;
28+ tail = second.getRhs ();
29+ }
30+ // Case 2: RHS is another fadd.
31+ else if (rhs_op && second.getLhs ()) {
32+ first = rhs_op;
33+ tail = second.getLhs ();
34+ }
2735
28- rewriter.replaceOp (user, fused.getResult ());
36+ if (!first)
37+ return failure ();
38+
39+ // Only fuses if the first fadd is not reused elsewhere.
40+ if (!first->hasOneUse ())
41+ return failure ();
42+
43+ Location loc = second.getLoc ();
44+ Type type = second.getType ();
45+
46+ auto fused = rewriter.create <neura::FAddFAddOp>(
47+ loc, type, first.getLhs (), first.getRhs (), tail);
48+
49+ rewriter.replaceOp (second, fused.getResult ());
2950 rewriter.eraseOp (first);
3051 return success ();
3152 }
3253};
3354
34- struct FusePatternsPass : public PassWrapper <FusePatternsPass, OperationPass<func::FuncOp>> {
55+ struct FuseFMulFAddPattern : public OpRewritePattern <neura::FAddOp> {
56+ using OpRewritePattern::OpRewritePattern;
57+
58+ LogicalResult matchAndRewrite (neura::FAddOp add,
59+ PatternRewriter &rewriter) const override {
60+ auto lhs_op = add.getLhs ().getDefiningOp <neura::FMulOp>();
61+ auto rhs_op = add.getRhs ().getDefiningOp <neura::FMulOp>();
62+
63+ neura::FMulOp fmul = nullptr ;
64+ Value other;
65+
66+ // Case 1: fmul is on the LHS.
67+ if (lhs_op && add.getRhs ()) {
68+ fmul = lhs_op;
69+ other = add.getRhs ();
70+ }
71+ // Case 2: fmul is on the RHS.
72+ else if (rhs_op && add.getLhs ()) {
73+ fmul = rhs_op;
74+ other = add.getLhs ();
75+ }
76+
77+ if (!fmul)
78+ return failure ();
79+
80+ // Optional: only fuses if fmul has a single use.
81+ if (!fmul->hasOneUse ())
82+ return failure ();
83+
84+ Location loc = add.getLoc ();
85+ Type type = add.getType ();
86+
87+ auto fused = rewriter.create <neura::FMulFAddOp>(
88+ loc, type, fmul.getLhs (), fmul.getRhs (), other);
89+
90+ rewriter.replaceOp (add, fused.getResult ());
91+ rewriter.eraseOp (fmul);
92+ return success ();
93+ }
94+ };
95+
96+ struct FusePatternsPass : public PassWrapper <FusePatternsPass, OperationPass<ModuleOp>> {
3597 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (FusePatternsPass)
3698
3799 StringRef getArgument () const override { return " fuse-patterns" ; }
38100 StringRef getDescription () const override { return " Apply Neura fusion patterns." ; }
39101
40102 void runOnOperation () override {
41103 RewritePatternSet patterns (&getContext ());
42- patterns.add <FuseFAddFAddPattern>(&getContext ());
43- if (failed (applyPatternsAndFoldGreedily (getOperation (), std::move (patterns))))
44- signalPassFailure ();
104+ patterns.add <FuseFAddFAddPattern>(&getContext (), 2 );
105+ patterns.add <FuseFMulFAddPattern>(&getContext (), 3 );
106+ FrozenRewritePatternSet frozen (std::move (patterns));
107+
108+ ModuleOp module_op = getOperation ();
109+
110+ // Applies to every region inside the module (regardless of func type,
111+ // e.g., mlir func or llvm func).
112+ module_op.walk ([&](Operation *op) {
113+ if (!op->getRegions ().empty ()) {
114+ for (Region ®ion : op->getRegions ()) {
115+ if (failed (applyPatternsAndFoldGreedily (region, frozen))) {
116+ signalPassFailure ();
117+ }
118+ }
119+ }
120+ });
45121 }
46122};
47123
0 commit comments