Skip to content

Commit 6441046

Browse files
authored
Merge pull request #14 from coredac/lower_llvm_func
Enable pass running on module
2 parents ac03483 + 9b9f802 commit 6441046

11 files changed

Lines changed: 288 additions & 26 deletions

File tree

include/NeuraDialect/NeuraOps.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,35 @@ def Neura_FAddOp : Op<NeuraDialect, "fadd"> {
2222
let traits = [SameOperandsAndResultElementType];
2323
}
2424

25+
// Defines a floating-point multiplication operation.
26+
def Neura_FMulOp : Op<NeuraDialect, "fmul"> {
27+
let summary = "Floating multiplication operation";
28+
let opName = "fmul";
29+
let arguments = (ins AnyFloat:$lhs, AnyFloat:$rhs);
30+
let results = (outs AnyFloat:$result);
31+
// let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
32+
let traits = [SameOperandsAndResultElementType];
33+
}
34+
35+
def VectorOfAnyFloat :
36+
TypeConstraint<
37+
CPred<
38+
"mlir::isa<::mlir::VectorType>($_self) && "
39+
"mlir::isa<::mlir::FloatType>(mlir::cast<::mlir::VectorType>($_self).getElementType())"
40+
>,
41+
"vector of floats"
42+
>;
43+
44+
// Defines a vector multiplication operation.
45+
def Neura_VFMulOp : Op<NeuraDialect, "vfmul"> {
46+
let summary = "Vector floating multiplication operation";
47+
let opName = "vfmul";
48+
let arguments = (ins VectorOfAnyFloat:$lhs, VectorOfAnyFloat:$rhs);
49+
let results = (outs VectorOfAnyFloat:$result);
50+
// let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
51+
let traits = [SameOperandsAndResultElementType];
52+
}
53+
2554
def Neura_FAddFAddOp : Op<NeuraDialect, "fadd_fadd"> {
2655
let summary = "Fused fadd(fadd(a, b), c)";
2756
let arguments = (ins AnyFloat:$a, AnyFloat:$b, AnyFloat:$c);
@@ -30,6 +59,13 @@ def Neura_FAddFAddOp : Op<NeuraDialect, "fadd_fadd"> {
3059
let traits = [SameOperandsAndResultElementType];
3160
}
3261

62+
def Neura_FMulFAddOp : Op<NeuraDialect, "fmul_fadd"> {
63+
let summary = "Fused fadd(fmul(a, b), c)";
64+
let arguments = (ins AnyFloat:$a, AnyFloat:$b, AnyFloat:$c);
65+
let results = (outs AnyFloat:$result);
66+
// let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)";
67+
let traits = [SameOperandsAndResultElementType];
68+
}
3369

3470
// Defines a move operation for data communication.
3571
def Neura_MovOp : Op<NeuraDialect, "mov"> {

lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,59 @@ using namespace mlir;
2424

2525
namespace {
2626

27+
// Lowers integer add from mlir.llvm.add to neura.add. We provide the lowering
28+
// here instead of tablegen due to that mlir.llvm.add uses an EnumProperty
29+
// (IntegerOverflowFlags) defined via MLIR interfaces — which DRR cannot match
30+
// on or extract from.
31+
struct LlvmAddToNeuraAdd : public OpRewritePattern<mlir::LLVM::AddOp> {
32+
using OpRewritePattern::OpRewritePattern;
33+
34+
LogicalResult matchAndRewrite(mlir::LLVM::AddOp op,
35+
PatternRewriter &rewriter) const override {
36+
rewriter.replaceOpWithNewOp<neura::AddOp>(op, op.getType(), op.getLhs(), op.getRhs());
37+
return success();
38+
}
39+
};
40+
41+
struct LlvmFMulToNeuraFMul : public OpRewritePattern<mlir::LLVM::FMulOp> {
42+
using OpRewritePattern::OpRewritePattern;
43+
44+
LogicalResult matchAndRewrite(mlir::LLVM::FMulOp op,
45+
PatternRewriter &rewriter) const override {
46+
Value lhs = op->getOperand(0);
47+
Value rhs = op->getOperand(1);
48+
Type result_type = op->getResult(0).getType();
49+
50+
// Only matches scalar float.
51+
if (!mlir::isa<FloatType>(result_type))
52+
return failure();
53+
54+
rewriter.replaceOpWithNewOp<neura::FMulOp>(op, result_type, lhs, rhs);
55+
return success();
56+
}
57+
};
58+
59+
struct LlvmVFMulToNeuraVFMul: public OpRewritePattern<mlir::LLVM::FMulOp> {
60+
using OpRewritePattern::OpRewritePattern;
61+
62+
LogicalResult matchAndRewrite(mlir::LLVM::FMulOp op,
63+
PatternRewriter &rewriter) const override {
64+
Value lhs = op->getOperand(0);
65+
Value rhs = op->getOperand(1);
66+
Type result_type = op->getResult(0).getType();
67+
68+
// Only matches vectors with a floating-point element type (e.g., vector<4xf32>).
69+
auto vecTy = mlir::dyn_cast<VectorType>(result_type);
70+
if (!vecTy || !mlir::isa<FloatType>(vecTy.getElementType()))
71+
return failure();
72+
73+
rewriter.replaceOpWithNewOp<neura::VFMulOp>(op, result_type, lhs, rhs);
74+
return success();
75+
}
76+
};
77+
2778
struct LowerLlvmToNeuraPass
28-
: public PassWrapper<LowerLlvmToNeuraPass, OperationPass<func::FuncOp>> {
79+
: public PassWrapper<LowerLlvmToNeuraPass, OperationPass<ModuleOp>> {
2980

3081
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerLlvmToNeuraPass)
3182

@@ -40,10 +91,26 @@ struct LowerLlvmToNeuraPass
4091

4192
void runOnOperation() override {
4293
RewritePatternSet patterns(&getContext());
94+
// Adds DRR patterns.
4395
mlir::neura::llvm2neura::populateWithGenerated(patterns);
44-
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
45-
signalPassFailure();
46-
}
96+
patterns.add<LlvmAddToNeuraAdd>(&getContext());
97+
patterns.add<LlvmFMulToNeuraFMul>(&getContext());
98+
patterns.add<LlvmVFMulToNeuraVFMul>(&getContext());
99+
FrozenRewritePatternSet frozen(std::move(patterns));
100+
101+
ModuleOp module_op = getOperation();
102+
103+
// Applies to every region inside the module (regardless of func type,
104+
// e.g., mlir func or llvm func).
105+
module_op.walk([&](Operation *op) {
106+
if (!op->getRegions().empty()) {
107+
for (Region &region : op->getRegions()) {
108+
if (failed(applyPatternsAndFoldGreedily(region, frozen))) {
109+
signalPassFailure();
110+
}
111+
}
112+
}
113+
});
47114
}
48115
};
49116
} // namespace

lib/Conversion/LlvmToNeura/LlvmToNeuraPatterns.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ def : Pat<
77
(LLVM_FAddOp $lhs, $rhs, $_fastmath),
88
(Neura_FAddOp $lhs, $rhs)
99
>;
10+

lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_mlir_library(NeuraTransforms
1616
LINK_LIBS PUBLIC
1717
MLIRIR
1818
MLIRFuncDialect
19+
MLIRLLVMDialect
1920
MLIRPass
2021
MLIRSupport
2122
MLIRTransformUtils

lib/Transforms/FusePatternsPass.cpp

Lines changed: 93 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,40 +8,116 @@ using namespace mlir;
88

99
namespace {
1010

11-
struct FuseFAddFAddPattern : public RewritePattern {
12-
FuseFAddFAddPattern(MLIRContext *ctx)
13-
: RewritePattern("neura.fadd", /*benefit=*/1, ctx) {}
11+
struct FuseFAddFAddPattern : public OpRewritePattern<neura::FAddOp> {
12+
using OpRewritePattern::OpRewritePattern;
1413

15-
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
16-
auto first = dyn_cast<neura::FAddOp>(op);
17-
if (!first || !first->hasOneUse()) return failure();
14+
LogicalResult matchAndRewrite(neura::FAddOp second,
15+
PatternRewriter &rewriter) const override {
16+
Value lhs = second.getLhs();
17+
Value rhs = second.getRhs();
1818

19-
auto user = dyn_cast<neura::FAddOp>(*first->getUsers().begin());
20-
if (!user) return failure();
19+
auto lhs_op = lhs.getDefiningOp<neura::FAddOp>();
20+
auto rhs_op = rhs.getDefiningOp<neura::FAddOp>();
2121

22-
Location loc = user.getLoc();
23-
Type type = user.getType();
22+
neura::FAddOp first = nullptr;
23+
Value tail;
2424

25-
auto fused = rewriter.create<neura::FAddFAddOp>(loc, type,
26-
first.getLhs(), first.getRhs(), user.getRhs());
25+
// Case 1: LHS is another fadd.
26+
if (lhs_op && second.getRhs()) {
27+
first = lhs_op;
28+
tail = second.getRhs();
29+
}
30+
// Case 2: RHS is another fadd.
31+
else if (rhs_op && second.getLhs()) {
32+
first = rhs_op;
33+
tail = second.getLhs();
34+
}
2735

28-
rewriter.replaceOp(user, fused.getResult());
36+
if (!first)
37+
return failure();
38+
39+
// Only fuses if the first fadd is not reused elsewhere.
40+
if (!first->hasOneUse())
41+
return failure();
42+
43+
Location loc = second.getLoc();
44+
Type type = second.getType();
45+
46+
auto fused = rewriter.create<neura::FAddFAddOp>(
47+
loc, type, first.getLhs(), first.getRhs(), tail);
48+
49+
rewriter.replaceOp(second, fused.getResult());
2950
rewriter.eraseOp(first);
3051
return success();
3152
}
3253
};
3354

34-
struct FusePatternsPass : public PassWrapper<FusePatternsPass, OperationPass<func::FuncOp>> {
55+
struct FuseFMulFAddPattern : public OpRewritePattern<neura::FAddOp> {
56+
using OpRewritePattern::OpRewritePattern;
57+
58+
LogicalResult matchAndRewrite(neura::FAddOp add,
59+
PatternRewriter &rewriter) const override {
60+
auto lhs_op = add.getLhs().getDefiningOp<neura::FMulOp>();
61+
auto rhs_op = add.getRhs().getDefiningOp<neura::FMulOp>();
62+
63+
neura::FMulOp fmul = nullptr;
64+
Value other;
65+
66+
// Case 1: fmul is on the LHS.
67+
if (lhs_op && add.getRhs()) {
68+
fmul = lhs_op;
69+
other = add.getRhs();
70+
}
71+
// Case 2: fmul is on the RHS.
72+
else if (rhs_op && add.getLhs()) {
73+
fmul = rhs_op;
74+
other = add.getLhs();
75+
}
76+
77+
if (!fmul)
78+
return failure();
79+
80+
// Only fuses if the fmul result has a single use, so no other consumer is broken.
81+
if (!fmul->hasOneUse())
82+
return failure();
83+
84+
Location loc = add.getLoc();
85+
Type type = add.getType();
86+
87+
auto fused = rewriter.create<neura::FMulFAddOp>(
88+
loc, type, fmul.getLhs(), fmul.getRhs(), other);
89+
90+
rewriter.replaceOp(add, fused.getResult());
91+
rewriter.eraseOp(fmul);
92+
return success();
93+
}
94+
};
95+
96+
struct FusePatternsPass : public PassWrapper<FusePatternsPass, OperationPass<ModuleOp>> {
3597
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FusePatternsPass)
3698

3799
StringRef getArgument() const override { return "fuse-patterns"; }
38100
StringRef getDescription() const override { return "Apply Neura fusion patterns."; }
39101

40102
void runOnOperation() override {
41103
RewritePatternSet patterns(&getContext());
42-
patterns.add<FuseFAddFAddPattern>(&getContext());
43-
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
44-
signalPassFailure();
104+
patterns.add<FuseFAddFAddPattern>(&getContext(), 2);
105+
patterns.add<FuseFMulFAddPattern>(&getContext(), 3);
106+
FrozenRewritePatternSet frozen(std::move(patterns));
107+
108+
ModuleOp module_op = getOperation();
109+
110+
// Applies to every region inside the module (regardless of func type,
111+
// e.g., mlir func or llvm func).
112+
module_op.walk([&](Operation *op) {
113+
if (!op->getRegions().empty()) {
114+
for (Region &region : op->getRegions()) {
115+
if (failed(applyPatternsAndFoldGreedily(region, frozen))) {
116+
signalPassFailure();
117+
}
118+
}
119+
}
120+
});
45121
}
46122
};
47123

lib/Transforms/InsertMovPass.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#include "NeuraDialect/NeuraDialect.h"
22
#include "NeuraDialect/NeuraOps.h"
3+
#include "mlir/Dialect/Func/IR/FuncOps.h"
4+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
35
#include "mlir/IR/PatternMatch.h"
46
#include "mlir/Pass/Pass.h"
57
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
6-
#include "mlir/Dialect/Func/IR/FuncOps.h"
78

89
using namespace mlir;
910

@@ -54,7 +55,7 @@ struct InsertMovForNeuraOps : public RewritePattern {
5455
};
5556

5657
struct InsertMovPass
57-
: public PassWrapper<InsertMovPass, OperationPass<func::FuncOp>> {
58+
: public PassWrapper<InsertMovPass, OperationPass<ModuleOp>> {
5859
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertMovPass)
5960

6061
StringRef getArgument() const override { return "insert-mov"; }
@@ -69,8 +70,21 @@ struct InsertMovPass
6970
void runOnOperation() override {
7071
RewritePatternSet patterns(&getContext());
7172
patterns.add<InsertMovForNeuraOps>(&getContext());
72-
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
73-
signalPassFailure();
73+
FrozenRewritePatternSet frozen(std::move(patterns));
74+
75+
ModuleOp module_op = getOperation();
76+
77+
// Applies to every region inside the module (regardless of func type,
78+
// e.g., mlir func or llvm func).
79+
module_op.walk([&](Operation *op) {
80+
if (!op->getRegions().empty()) {
81+
for (Region &region : op->getRegions()) {
82+
if (failed(applyPatternsAndFoldGreedily(region, frozen))) {
83+
signalPassFailure();
84+
}
85+
}
86+
}
87+
});
7488
}
7589
};
7690
} // namespace

test/neura/fadd_fadd.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
// Applies pattern fusion before mov insertion.
12
// RUN: mlir-neura-opt --lower-arith-to-neura --fuse-patterns --insert-mov %s | FileCheck %s
23

34
func.func @test(%a: f32, %b: f32) -> f32 {

test/neura/for_loop/kernel.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: mlir-neura-opt %s | FileCheck %s
2+
3+
#include <stdio.h>
4+
5+
#define NTAPS 32
6+
7+
float input[NTAPS] = {
8+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
9+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
10+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
11+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
12+
};
13+
float output[NTAPS];
14+
float coefficients[NTAPS] = {0.25, 1.50, 3.75, -2.25, 0.50, 0.75, -3.00, 1.25,
15+
0.25, 1.50, 3.75, -2.25, 0.50, 0.75, -3.00, 1.25,
16+
0.25, 1.50, 3.75, -2.25, 0.50, 0.75, -3.00, 1.25,
17+
0.25, 1.50, 3.75, -2.25, 0.50, 0.75, -3.00, 1.25};
18+
19+
void kernel(float input[], float output[], float coefficient[]);
20+
21+
int main()
22+
{
23+
24+
// input_dsp (input, NTAPS, 0);
25+
26+
kernel(input, output, coefficients);
27+
28+
// output_dsp (input, NTAPS, 0);
29+
// output_dsp (coefficients, NTAPS, 0);
30+
// output_dsp (output, NTAPS, 0);
31+
printf("output: %f\n", output[0]);
32+
return 0;
33+
}
34+
35+
/* input : input sample array */
36+
/* output: output sample array */
37+
/* coefficient: coefficient array */
38+
void kernel(float input[], float output[], float coefficient[]) {
39+
int i;
40+
int j = 0;
41+
42+
for (i = 0; i < NTAPS; ++i) {
43+
float tmp = input[i] * coefficient[i];
44+
output[j] += tmp;
45+
}
46+
}

0 commit comments

Comments
 (0)