Skip to content

Commit 089c6ba

Browse files
authored
Merge pull request #138 from ShangkunLi/support-more-cases
Support more test cases for mapping
2 parents b5f46ef + 683c80b commit 089c6ba

File tree

9 files changed

+225
-99
lines changed

9 files changed

+225
-99
lines changed

include/NeuraDialect/NeuraOps.td

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,26 @@ def Neura_MulOp : Op<NeuraDialect, "mul"> {
4343

4444
def Neura_DivOp : Op<NeuraDialect, "div"> {
4545
let summary = "Integer division operation";
46-
let arguments = (ins AnyType:$lhs, AnyType:$rhs, Optional<AnyType>:$predicate);
46+
let arguments = (ins AnyType:$lhs, Optional<AnyType>:$rhs);
4747
let results = (outs AnyType:$result);
4848
// let assemblyFormat = "$lhs `,` $rhs `,` $predicate attr-dict `:` type($result)";
4949
let traits = [SameOperandsAndResultElementType];
5050
}
5151

52+
def Neura_RemOp : Op<NeuraDialect, "rem">{
53+
let summary = "Integer remainder operation";
54+
let description = [{
55+
Performs an integer remainder operation, computing the result of
56+
a % b, where % is the remainder operator.
57+
58+
Example:
59+
%result = neura.rem %a, %b : i32
60+
}];
61+
let arguments = (ins AnyType:$lhs, Optional<AnyType>:$rhs);
62+
let results = (outs AnyType:$result);
63+
let traits = [SameOperandsAndResultElementType];
64+
}
65+
5266
// Defines a floating-point addition operation.
5367
def Neura_FAddOp : Op<NeuraDialect, "fadd"> {
5468
let summary = "Floating addition operation";
@@ -63,8 +77,8 @@ def Neura_FAddOp : Op<NeuraDialect, "fadd"> {
6377
def Neura_FSubOp: Op<NeuraDialect, "fsub"> {
6478
let summary = "Floating subtraction operation";
6579
let opName = "fsub";
66-
let arguments = (ins AnyFloat:$lhs, AnyFloat:$rhs, Optional<AnyType>:$predicate);
67-
let results = (outs AnyFloat:$result);
80+
let arguments = (ins AnyType:$lhs, Optional<AnyType>:$rhs);
81+
let results = (outs AnyType:$result);
6882
// let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
6983
let traits = [SameOperandsAndResultElementType];
7084
}
@@ -222,9 +236,9 @@ def Neura_AllocaOp : Op<NeuraDialect, "alloca"> {
222236
%ptr = neura.alloca %size : !neura.data<i32, i1> -> !llvm.ptr
223237
}];
224238

225-
let arguments = (ins AnyType:$size);
239+
let arguments = (ins Optional<AnyType>:$size);
226240
let results = (outs AnyType:$result);
227-
let assemblyFormat = "$size attr-dict `:` type($size) `->` type($result)";
241+
let assemblyFormat = "($size^ `:` type($size))? attr-dict `->` type($result)";
228242
}
229243

230244
// Defines a sign extension operation.
@@ -259,6 +273,20 @@ def Neura_ZExtOp : Op<NeuraDialect, "zext"> {
259273
let assemblyFormat = "$value attr-dict `:` type($value) `->` type($result)";
260274
}
261275

276+
// Defines a logical shift left operation.
277+
def Neura_ShlOp : Op<NeuraDialect, "shl"> {
278+
let summary = "Logical shift left operation";
279+
let description = [{
280+
Performs a logical left shift on an integer value.
281+
Similar to llvm.shl, but works with predicated values.
282+
283+
Example:
284+
%shifted = neura.shl %value, %shiftAmount : !neura.data<i32, i1> -> !neura.data<i32, i1>
285+
}];
286+
let arguments = (ins AnyType:$value, Optional<AnyType>:$shiftAmount);
287+
let results = (outs AnyType:$result);
288+
}
289+
262290
// ----------------------------------------------------
263291
// Defines vector operations.
264292

lib/Conversion/ArithToNeura/ArithToNeuraPass.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ struct ArithSubFToNeuraFSub : public OpRewritePattern<mlir::arith::SubFOp> {
9898
Type result_type = op.getType();
9999

100100
// Optional predicate: default to null.
101-
rewriter.replaceOpWithNewOp<neura::FSubOp>(op, result_type, lhs, rhs,
102-
nullptr);
101+
rewriter.replaceOpWithNewOp<neura::FSubOp>(op, result_type, lhs, rhs);
103102
return success();
104103
}
105104
};
@@ -144,8 +143,7 @@ struct ArithDivSIToNeuraDiv : public OpRewritePattern<mlir::arith::DivSIOp> {
144143
Type result_type = op.getType();
145144
// Converts arith DivSIOp to Neura DivOp.
146145
// Optional predicate: default to null.
147-
rewriter.replaceOpWithNewOp<neura::DivOp>(op, result_type, lhs, rhs,
148-
nullptr);
146+
rewriter.replaceOpWithNewOp<neura::DivOp>(op, result_type, lhs, rhs);
149147
return success();
150148
}
151149
};
@@ -176,8 +174,7 @@ struct ArithRemSIToNeuraOp : public OpRewritePattern<mlir::arith::RemSIOp> {
176174
Location loc = op.getLoc();
177175
// Converts arith RemSIOp to basic Neura Op.
178176
// Optional predicate: default to null.
179-
Value div =
180-
rewriter.create<neura::DivOp>(loc, result_type, lhs, rhs, nullptr);
177+
Value div = rewriter.create<neura::DivOp>(loc, result_type, lhs, rhs);
181178
Value mul = rewriter.create<neura::MulOp>(loc, result_type, rhs, div);
182179
Value rem = rewriter.create<neura::SubOp>(loc, result_type, lhs, mul);
183180

lib/Conversion/LlvmToNeura/LlvmToNeuraPass.cpp

Lines changed: 66 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,7 @@ struct LlvmFSubToNeuraFSub : public OpRewritePattern<mlir::LLVM::FSubOp> {
7373
}
7474

7575
// Sets optional predicate: default to 'none'.
76-
rewriter.replaceOpWithNewOp<neura::FSubOp>(op, result_type, lhs, rhs,
77-
Value());
76+
rewriter.replaceOpWithNewOp<neura::FSubOp>(op, result_type, lhs, rhs);
7877
return success();
7978
}
8079
};
@@ -109,6 +108,35 @@ struct LlvmFMulToNeuraFMul : public OpRewritePattern<mlir::LLVM::FMulOp> {
109108
}
110109
};
111110

111+
struct LlvmSDivToNeuraDiv : public OpRewritePattern<LLVM::SDivOp> {
112+
using OpRewritePattern::OpRewritePattern;
113+
114+
LogicalResult matchAndRewrite(LLVM::SDivOp op,
115+
PatternRewriter &rewriter) const override {
116+
Value lhs = op.getLhs();
117+
Value rhs = op.getRhs();
118+
Type resultType = op.getType();
119+
120+
rewriter.replaceOpWithNewOp<neura::DivOp>(op, resultType, lhs, rhs);
121+
return success();
122+
}
123+
};
124+
125+
struct LlvmSRemToNeuraRem : public OpRewritePattern<LLVM::SRemOp> {
126+
using OpRewritePattern<LLVM::SRemOp>::OpRewritePattern;
127+
128+
LogicalResult matchAndRewrite(LLVM::SRemOp op,
129+
PatternRewriter &rewriter) const override {
130+
Value lhs = op.getLhs();
131+
Value rhs = op.getRhs();
132+
Type resultType = op.getType();
133+
134+
// Creates a neura.rem operation to replace llvm.srem.
135+
rewriter.replaceOpWithNewOp<neura::RemOp>(op, resultType, lhs, rhs);
136+
return success();
137+
}
138+
};
139+
112140
struct LlvmVFMulToNeuraVFMul : public OpRewritePattern<mlir::LLVM::FMulOp> {
113141
using OpRewritePattern::OpRewritePattern;
114142

@@ -311,11 +339,11 @@ struct LlvmAllocaToNeuraAlloca : public OpRewritePattern<LLVM::AllocaOp> {
311339
PatternRewriter &rewriter) const override {
312340
Value size = op.getArraySize();
313341
Type resultType = op.getType();
314-
342+
315343
// Converts the size to neura.data<i32, i1> if it's not already.
316344
// Assumes the size is already in the right format.
317345
// Handles type conversion here.
318-
346+
319347
rewriter.replaceOpWithNewOp<neura::AllocaOp>(op, resultType, size);
320348
return success();
321349
}
@@ -328,7 +356,7 @@ struct LlvmSExtToNeuraSExt : public OpRewritePattern<LLVM::SExtOp> {
328356
PatternRewriter &rewriter) const override {
329357
Value input = op.getArg();
330358
Type resultType = op.getType();
331-
359+
332360
rewriter.replaceOpWithNewOp<neura::SExtOp>(op, resultType, input);
333361
return success();
334362
}
@@ -341,7 +369,7 @@ struct LlvmZExtToNeuraZExt : public OpRewritePattern<LLVM::ZExtOp> {
341369
PatternRewriter &rewriter) const override {
342370
Value input = op.getArg();
343371
Type resultType = op.getType();
344-
372+
345373
rewriter.replaceOpWithNewOp<neura::ZExtOp>(op, resultType, input);
346374
return success();
347375
}
@@ -355,36 +383,48 @@ struct LlvmMulToNeuraMul : public OpRewritePattern<LLVM::MulOp> {
355383
Value lhs = op.getLhs();
356384
Value rhs = op.getRhs();
357385
Type resultType = op.getType();
358-
386+
359387
rewriter.replaceOpWithNewOp<neura::MulOp>(op, resultType, lhs, rhs);
360388
return success();
361389
}
362390
};
363391

392+
struct LlvmShlToNeuraShl : public OpRewritePattern<LLVM::ShlOp> {
393+
using OpRewritePattern::OpRewritePattern;
394+
395+
LogicalResult matchAndRewrite(LLVM::ShlOp op,
396+
PatternRewriter &rewriter) const override {
397+
Value lhs = op.getLhs();
398+
Value rhs = op.getRhs();
399+
Type resultType = op.getType();
400+
401+
rewriter.replaceOpWithNewOp<neura::ShlOp>(op, resultType, lhs, rhs);
402+
return success();
403+
}
404+
};
405+
364406
struct LlvmFuncToNeuraFunc : public OpRewritePattern<LLVM::LLVMFuncOp> {
365407
using OpRewritePattern::OpRewritePattern;
366408

367409
LogicalResult matchAndRewrite(LLVM::LLVMFuncOp op,
368410
PatternRewriter &rewriter) const override {
369411

370-
371412
auto target = op->getAttrOfType<StringAttr>(mlir::accel::kAcceleratorAttr);
372413
if (!target || target.getValue() != mlir::accel::kNeuraTarget) {
373414
return failure();
374415
}
375416

376417
// Converts LLVMFunctionType to FunctionType.
377418
auto llvmFuncType = op.getFunctionType();
378-
auto funcType = rewriter.getFunctionType(
379-
llvmFuncType.getParams(),
380-
llvmFuncType.getReturnType()
381-
);
419+
auto funcType = rewriter.getFunctionType(llvmFuncType.getParams(),
420+
llvmFuncType.getReturnType());
382421

383-
// Creates the new func.func operation using OperationState to have full control.
422+
// Creates the new func.func operation using OperationState to have full
423+
// control.
384424
OperationState state(op.getLoc(), func::FuncOp::getOperationName());
385425
state.addAttribute("sym_name", rewriter.getStringAttr(op.getName()));
386426
state.addAttribute("function_type", TypeAttr::get(funcType));
387-
427+
388428
// Copies ALL attributes from the original llvm.func exactly as they are.
389429
// Skips function type and name attributes as they are handled separately.
390430
SmallVector<NamedAttribute> attrs;
@@ -395,15 +435,16 @@ struct LlvmFuncToNeuraFunc : public OpRewritePattern<LLVM::LLVMFuncOp> {
395435
attrs.push_back(attr);
396436
}
397437
state.addAttributes(attrs);
398-
438+
399439
// Adds the function body region.
400440
state.addRegion();
401-
441+
402442
auto newFunc = cast<func::FuncOp>(rewriter.create(state));
403443

404444
// Moves the function body.
405-
rewriter.inlineRegionBefore(op.getBody(), newFunc.getBody(), newFunc.getBody().end());
406-
445+
rewriter.inlineRegionBefore(op.getBody(), newFunc.getBody(),
446+
newFunc.getBody().end());
447+
407448
// Replaces the old function.
408449
rewriter.replaceOp(op, newFunc);
409450
return success();
@@ -435,20 +476,19 @@ struct LlvmCallToFuncCall : public OpRewritePattern<LLVM::CallOp> {
435476

436477
// Gets the result types from the function signature.
437478
auto resultTypes = funcOp.getFunctionType().getResults();
438-
479+
439480
// Converts the call to func.call.
440481
auto newCall = rewriter.create<func::CallOp>(
441-
op.getLoc(), resultTypes, callee.value(), op.getArgOperands()
442-
);
443-
482+
op.getLoc(), resultTypes, callee.value(), op.getArgOperands());
483+
444484
// Replaces the old call with the new one.
445485
// Handles both cases: calls with results and calls without results.
446486
if (op.getNumResults() == 0) {
447487
rewriter.eraseOp(op);
448488
} else {
449489
rewriter.replaceOp(op, newCall->getResults());
450490
}
451-
491+
452492
return success();
453493
}
454494
};
@@ -494,6 +534,9 @@ struct LowerLlvmToNeuraPass
494534
patterns.add<LlvmMulToNeuraMul>(&getContext());
495535
patterns.add<LlvmFuncToNeuraFunc>(&getContext());
496536
patterns.add<LlvmCallToFuncCall>(&getContext());
537+
patterns.add<LlvmShlToNeuraShl>(&getContext());
538+
patterns.add<LlvmSDivToNeuraDiv>(&getContext());
539+
patterns.add<LlvmSRemToNeuraRem>(&getContext());
497540

498541
FrozenRewritePatternSet frozen(std::move(patterns));
499542

lib/Conversion/MemRefToNeura/MemRefToNeuraPass.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "mlir/IR/PatternMatch.h"
99
#include "mlir/Pass/Pass.h"
1010
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
11+
#include "llvm/Support/LogicalResult.h"
1112

1213
using namespace mlir;
1314
using namespace mlir::neura;
@@ -46,6 +47,30 @@ struct MemRefStoreLowering : public OpRewritePattern<memref::StoreOp> {
4647
}
4748
};
4849

50+
struct MemRefAllocaToNeuraAlloca : public OpRewritePattern<memref::AllocaOp> {
51+
using OpRewritePattern<memref::AllocaOp>::OpRewritePattern;
52+
53+
LogicalResult matchAndRewrite(memref::AllocaOp alloca_op,
54+
PatternRewriter &rewriter) const override {
55+
// Gets the result type.
56+
Type result_type = alloca_op.getType();
57+
58+
// Checks if we have dynamic dimensions.
59+
if (!alloca_op.getDynamicSizes().empty()) {
60+
// For dynamic dimensions, we need to create the alloca with the size
61+
// arguments.
62+
rewriter.replaceOpWithNewOp<neura::AllocaOp>(alloca_op, result_type,
63+
alloca_op.getDynamicSizes());
64+
} else {
65+
// For static dimensions, we can create the alloca without size arguments.
66+
rewriter.replaceOpWithNewOp<neura::AllocaOp>(alloca_op, result_type,
67+
Value());
68+
}
69+
70+
return success();
71+
}
72+
};
73+
4974
struct LowerMemRefToNeuraPass
5075
: public PassWrapper<LowerMemRefToNeuraPass, OperationPass<ModuleOp>> {
5176

@@ -64,8 +89,11 @@ struct LowerMemRefToNeuraPass
6489
ModuleOp module_op = getOperation();
6590
MLIRContext *context = &getContext();
6691
RewritePatternSet patterns(&getContext());
92+
6793
patterns.add<MemRefLoadLowering>(context);
6894
patterns.add<MemRefStoreLowering>(context);
95+
patterns.add<MemRefAllocaToNeuraAlloca>(context);
96+
6997
module_op.walk([&](func::FuncOp func_op) {
7098
if (func_op->hasAttr(mlir::accel::kAcceleratorAttr)) {
7199
auto target =

lib/NeuraDialect/Architecture/Architecture.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ Architecture::Architecture(int width, int height) {
209209
for (int x = 0; x < width; ++x) {
210210
// Gets the tile by coordinates.
211211
Tile *tile = getTile(x, y);
212-
const int kNUM_REGS_PER_REGFILE = 4;
213-
const int kNUM_REGFILES_PER_CLUSTER = 2;
212+
const int kNUM_REGS_PER_REGFILE = 8;
213+
const int kNUM_REGFILES_PER_CLUSTER = 4;
214214

215215
// Assembles register files into a cluster.
216216
RegisterFileCluster *register_file_cluster =

lib/NeuraDialect/Transforms/MapToAcceleratorPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ struct MapToAcceleratorPass
173173
int res_mii = calculateResMii(func, architecture);
174174

175175
const int possibleMinII = std::max(rec_mii, res_mii);
176-
constexpr int maxII = 15;
176+
constexpr int maxII = 20;
177177
std::vector<Operation *> topologically_sorted_ops =
178178
getTopologicallySortedOps(func);
179179
if (topologically_sorted_ops.empty()) {

0 commit comments

Comments
 (0)