Skip to content

Commit ff914bc

Browse files
authored
Merge pull request #17 from coredac/llvm_for_lower
Tag kernel with accelerator attr for optional lowering
2 parents a167370 + 0f9d6e3 commit ff914bc

File tree

13 files changed

+301
-44
lines changed

13 files changed

+301
-44
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ jobs:
6464
- name: setup dataflow tool-chain
6565
working-directory: ${{github.workspace}}
6666
run: |
67-
git clone https://github.com/coredac/dataflow.git
68-
cd dataflow
6967
mkdir build && cd build
7068
cmake -G Ninja .. \
7169
-DLLVM_DIR=${{github.workspace}}/llvm-project/build/lib/cmake/llvm \
@@ -102,6 +100,6 @@ jobs:
102100
- name: run test
103101
working-directory: ${{github.workspace}}
104102
run: |
105-
cd ${{github.workspace}}/dataflow/test
103+
cd ${{github.workspace}}/test
106104
${{github.workspace}}/llvm-project/build/bin/llvm-lit * -v
107105

include/Common/AcceleratorAttrs.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef COMMON_ACCELERATOR_ATTRS_H
2+
#define COMMON_ACCELERATOR_ATTRS_H
3+
4+
#include "llvm/ADT/StringRef.h"
5+
6+
namespace mlir {
7+
namespace accel {
8+
9+
// Common attribute key.
10+
constexpr llvm::StringRef kAcceleratorAttr = "accelerator";
11+
12+
// Common accelerator targets.
13+
constexpr llvm::StringRef kNeuraTarget = "neura";
14+
constexpr llvm::StringRef kGpuTarget = "gpu";
15+
constexpr llvm::StringRef kTpuTarget = "tpu";
16+
17+
} // namespace accel
18+
} // namespace mlir
19+
20+
#endif // COMMON_ACCELERATOR_ATTRS_H

include/NeuraDialect/NeuraOps.td

Lines changed: 78 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22

33
include "NeuraDialect/NeuraDialect.td"
44

5+
// ----------------------------------------------------
6+
// Defines basic scalar operations.
7+
8+
def Neura_ConstantOp : Op<NeuraDialect, "constant"> {
9+
let arguments = (ins AnyAttr:$value);
10+
let results = (outs AnyType:$result);
11+
// let assemblyFormat = "attr-dict `:` type($result)";
12+
}
13+
514
// Defines an addition operation.
615
def Neura_AddOp : Op<NeuraDialect, "add"> {
716
let summary = "Integer addition operation";
@@ -12,7 +21,7 @@ def Neura_AddOp : Op<NeuraDialect, "add"> {
1221
let traits = [SameOperandsAndResultElementType];
1322
}
1423

15-
// Defines an addition operation.
24+
// Defines a floating-point addition operation.
1625
def Neura_FAddOp : Op<NeuraDialect, "fadd"> {
1726
let summary = "Floating addition operation";
1827
let opName = "fadd";
@@ -22,7 +31,7 @@ def Neura_FAddOp : Op<NeuraDialect, "fadd"> {
2231
let traits = [SameOperandsAndResultElementType];
2332
}
2433

25-
// Defines a multiplication operation.
34+
// Defines a floating-point multiplication operation.
2635
def Neura_FMulOp : Op<NeuraDialect, "fmul"> {
2736
let summary = "Floating multiplication operation";
2837
let opName = "fmul";
@@ -32,6 +41,69 @@ def Neura_FMulOp : Op<NeuraDialect, "fmul"> {
3241
let traits = [SameOperandsAndResultElementType];
3342
}
3443

44+
def Neura_OrOp : Op<NeuraDialect, "or"> {
45+
let summary = "Bitwise OR operation";
46+
let arguments = (ins AnySignlessInteger:$lhs, AnySignlessInteger:$rhs);
47+
let results = (outs AnySignlessInteger:$result);
48+
let traits = [SameOperandsAndResultElementType];
49+
// let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
50+
}
51+
52+
// Defines a move operation for data communication.
53+
def Neura_MovOp : Op<NeuraDialect, "mov"> {
54+
let summary = "Move operation";
55+
let opName = "mov";
56+
let arguments = (ins AnyType:$lhs);
57+
let results = (outs AnyType:$result);
58+
// let assemblyFormat = "$lhs attr-dict `:` type($lhs) `->` type($result)";
59+
// let traits = [Pure];
60+
}
61+
62+
def Neura_ICmpOp : Op<NeuraDialect, "icmp"> {
63+
let summary = "Integer compare operation";
64+
let opName = "icmp";
65+
let arguments = (ins AnyInteger:$lhs, AnyInteger:$rhs,
66+
StrAttr:$predicate);
67+
let results = (outs I1:$result);
68+
// let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)";
69+
// let traits = [SameOperandsAndResultElementType];
70+
}
71+
72+
def Neura_LoadOp : Op<NeuraDialect, "load"> {
73+
let arguments = (ins AnyType:$addr);
74+
let results = (outs AnyType:$value);
75+
// let assemblyFormat = "$addr attr-dict `:` type($value)";
76+
}
77+
78+
def Neura_StoreOp : Op<NeuraDialect, "store"> {
79+
let arguments = (ins AnyType:$value, AnyType:$addr);
80+
let results = (outs);
81+
// let assemblyFormat = "$value `,` $addr attr-dict";
82+
}
83+
84+
def Neura_GEP : Op<NeuraDialect, "gep"> {
85+
let summary = "Pointer computation using offset indices";
86+
let arguments = (ins AnyType:$base, Variadic<AnyInteger>:$indices);
87+
let results = (outs AnyType:$result);
88+
// let assemblyFormat = "$base `[` $indices `]` attr-dict";
89+
}
90+
91+
def Neura_CondBr : Op<NeuraDialect, "cond_br", [Terminator, AttrSizedOperandSegments]> {
92+
let arguments = (ins I1:$condition,
93+
Variadic<AnyType>:$trueArgs,
94+
Variadic<AnyType>:$falseArgs);
95+
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
96+
let assemblyFormat = "$condition `then` $trueArgs `:` type($trueArgs) `to` $trueDest `else` $falseArgs `:` type($falseArgs) `to` $falseDest attr-dict";
97+
}
98+
99+
def Neura_ReturnOp : Op<NeuraDialect, "return", [Terminator]> {
100+
let arguments = (ins Variadic<AnyType>:$values);
101+
// let assemblyFormat = "($values^)? attr-dict";
102+
}
103+
104+
// ----------------------------------------------------
105+
// Defines vector operations.
106+
35107
def VectorOfAnyFloat :
36108
TypeConstraint<
37109
CPred<
@@ -51,6 +123,9 @@ def Neura_VFMulOp : Op<NeuraDialect, "vfmul"> {
51123
let traits = [SameOperandsAndResultElementType];
52124
}
53125

126+
// ----------------------------------------------------
127+
// Defines fused operations.
128+
54129
def Neura_FAddFAddOp : Op<NeuraDialect, "fadd_fadd"> {
55130
let summary = "Fused fadd(fadd(a, b), c)";
56131
let arguments = (ins AnyFloat:$a, AnyFloat:$b, AnyFloat:$c);
@@ -67,12 +142,4 @@ def Neura_FMulFAddOp : Op<NeuraDialect, "fmul_fadd"> {
67142
let traits = [SameOperandsAndResultElementType];
68143
}
69144

70-
// Defines a move operation for data communication.
71-
def Neura_MovOp : Op<NeuraDialect, "mov"> {
72-
let summary = "Move operation";
73-
let opName = "mov";
74-
let arguments = (ins AnyType:$lhs);
75-
let results = (outs AnyType:$result);
76-
let assemblyFormat = "$lhs attr-dict `:` type($lhs) `->` type($result)";
77-
// let traits = [Pure];
78-
}
145+
[filename line missing from extraction — the include guard NEURA_TRANSFORMS_ASSIGN_ACCELERATORPASS_H in the hunk below suggests this is the new header include/NeuraDialect/NeuraPasses/AssignAcceleratorPass.h or similar; confirm against the commit's file tree]

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#ifndef NEURA_TRANSFORMS_ASSIGN_ACCELERATORPASS_H
2+
#define NEURA_TRANSFORMS_ASSIGN_ACCELERATORPASS_H
3+
4+
#include "mlir/Pass/Pass.h"
5+
6+
namespace mlir {
7+
namespace neura {
8+
std::unique_ptr<mlir::Pass> createAssignAcceleratorPass();
9+
} // namespace neura
10+
} // namespace mlir
11+
12+
#endif // NEURA_TRANSFORMS_ASSIGN_ACCELERATORPASS_H
13+

lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp

Lines changed: 120 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "Conversion/LlvmToNeura/LlvmToNeura.h"
2+
#include "Common/AcceleratorAttrs.h"
23
#include "NeuraDialect/NeuraDialect.h"
34
#include "NeuraDialect/NeuraOps.h"
45
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
@@ -75,6 +76,110 @@ struct LlvmVFMulToNeuraVFMul: public OpRewritePattern<mlir::LLVM::FMulOp> {
7576
}
7677
};
7778

79+
struct LlvmICmpToNeuraICmp : public OpRewritePattern<LLVM::ICmpOp> {
80+
using OpRewritePattern::OpRewritePattern;
81+
82+
LogicalResult matchAndRewrite(LLVM::ICmpOp op,
83+
PatternRewriter &rewriter) const override {
84+
auto pred = op.getPredicate();
85+
auto lhs = op.getLhs();
86+
auto rhs = op.getRhs();
87+
auto resultType = op.getType();
88+
89+
rewriter.replaceOpWithNewOp<neura::ICmpOp>(
90+
op, resultType, lhs, rhs, rewriter.getStringAttr(LLVM::stringifyICmpPredicate(pred)));
91+
return success();
92+
}
93+
};
94+
95+
struct LlvmGEPToNeuraGEP : public OpRewritePattern<mlir::LLVM::GEPOp> {
96+
using OpRewritePattern::OpRewritePattern;
97+
98+
LogicalResult matchAndRewrite(mlir::LLVM::GEPOp op,
99+
PatternRewriter &rewriter) const override {
100+
Value base = op.getBase();
101+
SmallVector<Value> indexValues;
102+
103+
for (auto gepIndex : op.getIndices()) {
104+
if (auto val = gepIndex.dyn_cast<Value>()) {
105+
indexValues.push_back(val);
106+
} else if (auto intAttr = gepIndex.dyn_cast<IntegerAttr>()) {
107+
auto cst = rewriter.create<neura::ConstantOp>(
108+
op.getLoc(), rewriter.getIndexType(), intAttr);
109+
indexValues.push_back(cst);
110+
} else {
111+
return op.emitOpError("Unsupported GEP index kind");
112+
}
113+
}
114+
115+
rewriter.replaceOpWithNewOp<neura::GEP>(op, op.getType(), base, indexValues);
116+
return success();
117+
}
118+
};
119+
120+
struct LlvmLoadToNeuraLoad : public OpRewritePattern<mlir::LLVM::LoadOp> {
121+
using OpRewritePattern::OpRewritePattern;
122+
123+
LogicalResult matchAndRewrite(mlir::LLVM::LoadOp op,
124+
PatternRewriter &rewriter) const override {
125+
Value ptr = op.getAddr(); // getPointer() is deprecated
126+
Type resultType = op.getResult().getType();
127+
rewriter.replaceOpWithNewOp<neura::LoadOp>(op, resultType, ptr);
128+
return success();
129+
}
130+
};
131+
132+
struct LlvmStoreToNeuraStore : public OpRewritePattern<mlir::LLVM::StoreOp> {
133+
using OpRewritePattern::OpRewritePattern;
134+
135+
LogicalResult matchAndRewrite(mlir::LLVM::StoreOp op,
136+
PatternRewriter &rewriter) const override {
137+
Value value = op.getValue();
138+
Value addr = op.getAddr(); // getPointer() is deprecated
139+
rewriter.replaceOpWithNewOp<neura::StoreOp>(op, value, addr);
140+
return success();
141+
}
142+
};
143+
144+
struct LlvmCondBrToNeuraCondBr : public OpRewritePattern<LLVM::CondBrOp> {
145+
using OpRewritePattern::OpRewritePattern;
146+
LogicalResult matchAndRewrite(LLVM::CondBrOp op,
147+
PatternRewriter &rewriter) const override {
148+
// Get the source operation's successors (basic blocks)
149+
Block *trueDest = op.getTrueDest();
150+
Block *falseDest = op.getFalseDest();
151+
152+
// Get the operands for each destination
153+
ValueRange trueOperands = op.getTrueDestOperands();
154+
ValueRange falseOperands = op.getFalseDestOperands();
155+
156+
// Create the new operation with proper successors
157+
auto newOp = rewriter.create<neura::CondBr>(
158+
op.getLoc(), // Location
159+
op.getCondition(), // Condition
160+
trueOperands, // True destination operands
161+
falseOperands, // False destination operands
162+
trueDest, // True destination block
163+
falseDest // False destination block
164+
);
165+
166+
// Replace the old op with the new one
167+
rewriter.replaceOp(op, newOp->getResults());
168+
169+
return success();
170+
}
171+
};
172+
173+
struct LlvmReturnToNeuraReturn : public OpRewritePattern<LLVM::ReturnOp> {
174+
using OpRewritePattern::OpRewritePattern;
175+
176+
LogicalResult matchAndRewrite(LLVM::ReturnOp op,
177+
PatternRewriter &rewriter) const override {
178+
rewriter.replaceOpWithNewOp<neura::ReturnOp>(op, op.getOperands());
179+
return success();
180+
}
181+
};
182+
78183
struct LowerLlvmToNeuraPass
79184
: public PassWrapper<LowerLlvmToNeuraPass, OperationPass<ModuleOp>> {
80185

@@ -96,17 +201,27 @@ struct LowerLlvmToNeuraPass
96201
patterns.add<LlvmAddToNeuraAdd>(&getContext());
97202
patterns.add<LlvmFMulToNeuraFMul>(&getContext());
98203
patterns.add<LlvmVFMulToNeuraVFMul>(&getContext());
204+
patterns.add<LlvmICmpToNeuraICmp>(&getContext());
205+
patterns.add<LlvmGEPToNeuraGEP>(&getContext());
206+
patterns.add<LlvmLoadToNeuraLoad>(&getContext());
207+
patterns.add<LlvmStoreToNeuraStore>(&getContext());
208+
patterns.add<LlvmCondBrToNeuraCondBr>(&getContext());
209+
patterns.add<LlvmReturnToNeuraReturn>(&getContext());
210+
99211
FrozenRewritePatternSet frozen(std::move(patterns));
100212

101213
ModuleOp module_op = getOperation();
102214

103215
// Applies to every region inside the module (regardless of func type,
104216
// e.g., mlir func or llvm func).
105-
module_op.walk([&](Operation *op) {
106-
if (!op->getRegions().empty()) {
107-
for (Region &region : op->getRegions()) {
108-
if (failed(applyPatternsAndFoldGreedily(region, frozen))) {
109-
signalPassFailure();
217+
module_op.walk([&](FunctionOpInterface func) {
218+
if (func->hasAttr(mlir::accel::kAcceleratorAttr)) {
219+
auto target = func->getAttrOfType<StringAttr>(mlir::accel::kAcceleratorAttr);
220+
if (target && target.getValue() == mlir::accel::kNeuraTarget) {
221+
for (Region &region : func->getRegions()) {
222+
if (failed(applyPatternsAndFoldGreedily(region, frozen))) {
223+
signalPassFailure();
224+
}
110225
}
111226
}
112227
}

lib/Conversion/LlvmToNeura/LlvmToNeuraPatterns.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,13 @@ def : Pat<
88
(Neura_FAddOp $lhs, $rhs)
99
>;
1010

11+
def : Pat<
12+
(LLVM_ConstantOp $value),
13+
(Neura_ConstantOp $value)
14+
>;
15+
16+
def : Pat<
17+
(LLVM_OrOp $lhs, $rhs),
18+
(Neura_OrOp $lhs, $rhs)
19+
>;
20+
[filename line missing from extraction — the hunk below defines AssignAcceleratorPass and mlir::neura::createAssignAcceleratorPass(), suggesting this is the new file lib/NeuraDialect/AssignAcceleratorPass.cpp or similar; confirm against the commit's file tree]

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#include "Common/AcceleratorAttrs.h"
2+
#include "mlir/IR/Builders.h"
3+
#include "mlir/IR/BuiltinOps.h"
4+
#include "mlir/Dialect/Func/IR/FuncOps.h"
5+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
6+
#include "mlir/Pass/Pass.h"
7+
8+
using namespace mlir;
9+
10+
namespace {
11+
struct AssignAcceleratorPass : public PassWrapper<AssignAcceleratorPass, OperationPass<ModuleOp>> {
12+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AssignAcceleratorPass)
13+
14+
StringRef getArgument() const override { return "assign-accelerator"; }
15+
StringRef getDescription() const override { return "Tags non-main functions as neura.kernel."; }
16+
17+
void runOnOperation() override {
18+
ModuleOp module = getOperation();
19+
Builder builder(&getContext());
20+
21+
module.walk([&](Operation *op) {
22+
if (auto func = dyn_cast<FunctionOpInterface>(op)) {
23+
if (func.getName() != "main" &&
24+
!func.isExternal() &&
25+
!func->hasAttr(mlir::accel::kAcceleratorAttr)) {
26+
func->setAttr(mlir::accel::kAcceleratorAttr, builder.getStringAttr(mlir::accel::kNeuraTarget));
27+
}
28+
}
29+
});
30+
}
31+
};
32+
} // namespace
33+
34+
/// Register the pass
35+
namespace mlir {
36+
namespace neura {
37+
std::unique_ptr<Pass> createAssignAcceleratorPass() {
38+
return std::make_unique<AssignAcceleratorPass>();
39+
}
40+
} // namespace neura
41+
} // namespace mlir
42+

0 commit comments

Comments
 (0)