Skip to content

Commit 64d2833

Browse files
authored
Merge pull request #91 from ShangkunLi/removecast
Remove non-sense neura::cast ops
2 parents 1b0eed3 + 33e89bd commit 64d2833

File tree

7 files changed

+228
-0
lines changed

7 files changed

+228
-0
lines changed

include/NeuraDialect/NeuraPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ std::unique_ptr<mlir::Pass> createMapToAcceleratorPass();
2828
std::unique_ptr<mlir::Pass> createGenerateCodePass();
2929
std::unique_ptr<mlir::Pass> createFuseControlFlowPass();
3030
std::unique_ptr<mlir::Pass> createCanonicalizeLiveInPass();
31+
std::unique_ptr<mlir::Pass> createCanonicalizeCastPass();
3132

3233
#define GEN_PASS_REGISTRATION
3334
#include "NeuraDialect/NeuraPasses.h.inc"

include/NeuraDialect/NeuraPasses.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,15 @@ def CanonicalizeLiveIn : Pass<"canonicalize-live-in", "ModuleOp"> {
8686
let constructor = "neura::createCanonicalizeLiveInPass()";
8787
}
8888

89+
def CanonicalizeCast : Pass<"canonicalize-cast", "ModuleOp"> {
90+
let summary = "Canonicalizes cast operations in the Neura dialect";
91+
let description = [{
92+
This pass applies canonicalization transformations to neura::cast operations.
93+
The canonicalization includes:
94+
1. Removing redundant casts.
95+
2. Converting index types to i64.
96+
}];
97+
let constructor = "neura::createCanonicalizeCastPass()";
98+
}
99+
89100
#endif // NEURA_PASSES_TD

lib/NeuraDialect/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_mlir_library(
1212
GenerateCodePass.cpp
1313
FuseControlFlowPass.cpp
1414
CanonicalizeLiveInPass.cpp
15+
CanonicalizeCastPass.cpp
1516

1617
DEPENDS
1718
MLIRNeuraTransformsIncGen
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
#include "NeuraDialect/NeuraOps.h"
2+
#include "mlir/Dialect/Func/IR/FuncOps.h"
3+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
4+
#include "mlir/IR/Block.h"
5+
#include "mlir/IR/BuiltinAttributes.h"
6+
#include "mlir/IR/Operation.h"
7+
#include "mlir/IR/Region.h"
8+
#include "mlir/IR/Value.h"
9+
#include "mlir/Pass/Pass.h"
10+
#include "mlir/Support/LLVM.h"
11+
12+
using namespace mlir;
13+
14+
namespace {
15+
16+
LogicalResult canonicalizeCast(Region &region) {
17+
// Handles block arguments.
18+
for (Block &block : region.getBlocks()) {
19+
for (BlockArgument arg : block.getArguments()) {
20+
if (arg.getType().isIndex()) {
21+
// Replaces index type with i64.
22+
arg.setType(IntegerType::get(arg.getContext(), 64));
23+
}
24+
}
25+
}
26+
27+
region.walk([&](Operation *op) {
28+
// Handles the value attributes in neura::ConstantOp.
29+
if (isa<neura::ConstantOp>(op)) {
30+
Attribute value_attr = op->getAttr("value");
31+
if (!value_attr) {
32+
return;
33+
}
34+
if (IntegerAttr int_attr = dyn_cast<IntegerAttr>(value_attr)) {
35+
if (isa<IntegerType>(op->getResult(0).getType())) {
36+
return;
37+
}
38+
if (isa<IndexType>(op->getResult(0).getType())) {
39+
IntegerAttr new_attr = IntegerAttr::get(
40+
IntegerType::get(op->getContext(), 64), int_attr.getInt());
41+
op->setAttr("value", new_attr);
42+
}
43+
}
44+
}
45+
46+
// Replaces all index types with i64.
47+
for (OpResult result : op->getOpResults()) {
48+
auto type = result.getType();
49+
if (isa<IndexType>(type)) {
50+
result.setType(mlir::IntegerType::get(op->getContext(), 64));
51+
}
52+
}
53+
54+
if (neura::CastOp cast_op = dyn_cast<neura::CastOp>(op)) {
55+
StringAttr cast_type_attr =
56+
cast_op->getAttrOfType<StringAttr>("cast_type");
57+
if (!cast_type_attr)
58+
return;
59+
StringRef cast_type = cast_type_attr.getValue();
60+
61+
Type src_type = cast_op->getOperand(0).getType();
62+
Type dst_type = cast_op->getResult(0).getType();
63+
64+
// Reomoves the index->i64 or i64->index cast operations.
65+
if ((cast_type == "index_to_int" && isa<IntegerType>(src_type) &&
66+
isa<IntegerType>(dst_type) &&
67+
dyn_cast<IntegerType>(src_type).getWidth() == 64 &&
68+
dyn_cast<IntegerType>(dst_type).getWidth() == 64) ||
69+
(cast_type == "int_to_index" && isa<IntegerType>(src_type) &&
70+
isa<IntegerType>(dst_type) &&
71+
dyn_cast<IntegerType>(src_type).getWidth() == 64 &&
72+
dyn_cast<IntegerType>(dst_type).getWidth() == 64)) {
73+
cast_op->getResult(0).replaceAllUsesWith(cast_op->getOperand(0));
74+
cast_op->erase();
75+
return;
76+
}
77+
78+
// Changes index->i32 or i32->index casts to i64->i32 or i32->i64.
79+
if (cast_type == "index_to_int" && isa<IntegerType>(dst_type) &&
80+
dyn_cast<IntegerType>(dst_type).getWidth() == 32) {
81+
cast_op->setAttr("cast_type",
82+
StringAttr::get(op->getContext(), "i64_to_i32"));
83+
return;
84+
}
85+
if (cast_type == "int_to_index" && isa<IntegerType>(src_type) &&
86+
dyn_cast<IntegerType>(src_type).getWidth() == 32) {
87+
cast_op->setAttr("cast_type",
88+
StringAttr::get(op->getContext(), "i32_to_i64"));
89+
return;
90+
}
91+
// TODO: Handles other cast types if needed.
92+
}
93+
});
94+
return success();
95+
}
96+
97+
struct CanonicalizeCastPass
98+
: public PassWrapper<CanonicalizeCastPass, OperationPass<ModuleOp>> {
99+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CanonicalizeCastPass)
100+
StringRef getArgument() const override { return "canonicalize-cast"; }
101+
StringRef getDescription() const override {
102+
return "Canonicalizes cast operations in the Neura dialect, specifically "
103+
"removing unnecessary index to i64 casts and vice versa.";
104+
}
105+
106+
void runOnOperation() override {
107+
auto module_op = getOperation();
108+
109+
module_op.walk([&](Operation *op) {
110+
Region *region = nullptr;
111+
if (auto func_op = dyn_cast<func::FuncOp>(op)) {
112+
auto accel_attr = func_op->getAttrOfType<StringAttr>("accelerator");
113+
if (!accel_attr || accel_attr.getValue() != "neura") {
114+
return;
115+
}
116+
region = &func_op.getBody();
117+
} else if (auto llvm_func = dyn_cast<LLVM::LLVMFuncOp>(op)) {
118+
auto accel_attr = llvm_func->getAttrOfType<StringAttr>("accelerator");
119+
if (!accel_attr || accel_attr.getValue() != "neura") {
120+
return;
121+
}
122+
region = &llvm_func.getBody();
123+
} else {
124+
return;
125+
}
126+
127+
if (!region || region->empty()) {
128+
return;
129+
}
130+
131+
if (failed(canonicalizeCast(*region))) {
132+
signalPassFailure();
133+
return;
134+
}
135+
});
136+
}
137+
};
138+
} // namespace
139+
140+
namespace mlir::neura {
141+
std::unique_ptr<mlir::Pass> createCanonicalizeCastPass() {
142+
return std::make_unique<CanonicalizeCastPass>();
143+
}
144+
} // namespace mlir::neura

test/controflow_fuse/perfect_nested/perfect_nested.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: mlir-opt %s --lower-affine --convert-scf-to-cf --convert-cf-to-llvm -o %t-llvm.mlir
22
// RUN: mlir-neura-opt %t-llvm.mlir --assign-accelerator --lower-arith-to-neura --lower-memref-to-neura --lower-builtin-to-neura --lower-llvm-to-neura | FileCheck %s
3+
// RUN: mlir-neura-opt %t-llvm.mlir --assign-accelerator --lower-arith-to-neura --lower-memref-to-neura --lower-builtin-to-neura --lower-llvm-to-neura --canonicalize-cast | FileCheck %s --check-prefix=CAST
34
// RUN: mlir-neura-opt %t-llvm.mlir --assign-accelerator --lower-arith-to-neura --lower-memref-to-neura --lower-builtin-to-neura --lower-llvm-to-neura --leverage-predicated-value --transform-ctrl-to-data-flow | FileCheck %s -check-prefix=CTRL2DATA
45

56
module attributes {} {
@@ -45,6 +46,31 @@ module attributes {} {
4546
// CHECK-NEXT: "neura.return"() : () -> ()
4647
// CHECK-NEXT: }
4748

49+
// CAST: func.func @_Z10bert_node1PA1_A1_A1_A1_A128_bPA1_A128_S1_(%arg0: memref<?x1x1x1x1x128xi8>, %arg1: memref<?x1x128x1x1x128xi8>) attributes {accelerator = "neura", llvm.linkage = #llvm.linkage<external>} {
50+
// CAST-NEXT: %0 = "neura.constant"() <{predicate = true, value = 1 : i64}> : () -> i64
51+
// CAST-NEXT: %1 = "neura.constant"() <{predicate = true, value = 128 : i64}> : () -> i64
52+
// CAST-NEXT: %2 = "neura.constant"() <{predicate = true, value = 0 : i64}> : () -> i64
53+
// CAST-NEXT: neura.br %2 : i64 to ^bb1
54+
// CAST-NEXT: ^bb1(%3: i64): // 2 preds: ^bb0, ^bb5
55+
// CAST-NEXT: %4 = "neura.icmp"(%3, %1) <{cmpType = "slt"}> : (i64, i64) -> i1
56+
// CAST-NEXT: neura.cond_br %4 : i1 then to ^bb2 else to ^bb6
57+
// CAST-NEXT: ^bb2: // pred: ^bb1
58+
// CAST-NEXT: neura.br %2 : i64 to ^bb3
59+
// CAST-NEXT: ^bb3(%5: i64): // 2 preds: ^bb2, ^bb4
60+
// CAST-NEXT: %6 = "neura.icmp"(%5, %1) <{cmpType = "slt"}> : (i64, i64) -> i1
61+
// CAST-NEXT: neura.cond_br %6 : i1 then to ^bb4 else to ^bb5
62+
// CAST-NEXT: ^bb4: // pred: ^bb3
63+
// CAST-NEXT: %7 = neura.load_indexed %arg0[%2, %2, %2, %2, %2, %5 : i64, i64, i64, i64, i64, i64] memref<?x1x1x1x1x128xi8> : i8
64+
// CAST-NEXT: neura.store_indexed %7 to %arg1[%2, %2, %3, %2, %2, %5 : i64, i64, i64, i64, i64, i64] memref<?x1x128x1x1x128xi8> : i8
65+
// CAST-NEXT: %8 = "neura.add"(%5, %0) : (i64, i64) -> i64
66+
// CAST-NEXT: neura.br %8 : i64 to ^bb3
67+
// CAST-NEXT: ^bb5: // pred: ^bb3
68+
// CAST-NEXT: %9 = "neura.add"(%3, %0) : (i64, i64) -> i64
69+
// CAST-NEXT: neura.br %9 : i64 to ^bb1
70+
// CAST-NEXT: ^bb6: // pred: ^bb1
71+
// CAST-NEXT: "neura.return"() : () -> ()
72+
// CAST-NEXT: }
73+
4874
// CTRL2DATA: func.func @_Z10bert_node1PA1_A1_A1_A1_A128_bPA1_A128_S1_(%arg0: memref<?x1x1x1x1x128xi8>, %arg1: memref<?x1x128x1x1x128xi8>) attributes {accelerator = "neura", llvm.linkage = #llvm.linkage<external>} {
4975
// CTRL2DATA-NEXT: %0 = "neura.constant"() <{predicate = true, value = 1 : index}> : () -> !neura.data<index, i1>
5076
// CTRL2DATA-NEXT: %1 = "neura.grant_always"(%0) : (!neura.data<index, i1>) -> !neura.data<index, i1>

test/controflow_fuse/perfect_reduction/perfect_reduction.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: mlir-opt %s --lower-affine --convert-scf-to-cf --convert-cf-to-llvm -o %t-llvm.mlir
22
// RUN: mlir-neura-opt %t-llvm.mlir --assign-accelerator --lower-arith-to-neura --lower-memref-to-neura --lower-builtin-to-neura --lower-llvm-to-neura | FileCheck %s
3+
// RUN: mlir-neura-opt %t-llvm.mlir --assign-accelerator --lower-arith-to-neura --lower-memref-to-neura --lower-builtin-to-neura --lower-llvm-to-neura --canonicalize-cast | FileCheck %s --check-prefix=CAST
34
// RUN: mlir-neura-opt %t-llvm.mlir --assign-accelerator --lower-arith-to-neura --lower-memref-to-neura --lower-builtin-to-neura --lower-llvm-to-neura --leverage-predicated-value --transform-ctrl-to-data-flow | FileCheck %s -check-prefix=CTRL2DATA
45

56
module attributes {} {
@@ -50,6 +51,32 @@ module attributes {} {
5051
// CHECK-NEXT: "neura.return"(%6) : (i32) -> ()
5152
// CHECK-NEXT: }
5253

54+
// CAST: func.func @_Z27perfect_nested_reduction_2dPA128_i(%arg0: memref<?x128xi32>) -> i32 attributes {accelerator = "neura", llvm.linkage = #llvm.linkage<external>} {
55+
// CAST-NEXT: %0 = "neura.constant"() <{predicate = true, value = 1 : i64}> : () -> i64
56+
// CAST-NEXT: %1 = "neura.constant"() <{predicate = true, value = 128 : i64}> : () -> i64
57+
// CAST-NEXT: %2 = "neura.constant"() <{predicate = true, value = 0 : i32}> : () -> i32
58+
// CAST-NEXT: %3 = "neura.constant"() <{predicate = true, value = 0 : i64}> : () -> i64
59+
// CAST-NEXT: neura.br %3, %2 : i64, i32 to ^bb1
60+
// CAST-NEXT: ^bb1(%4: i64, %5: i32): // 2 preds: ^bb0, ^bb5
61+
// CAST-NEXT: %6 = "neura.icmp"(%4, %1) <{cmpType = "slt"}> : (i64, i64) -> i1
62+
// CAST-NEXT: neura.cond_br %6 : i1 then to ^bb2 else to ^bb6
63+
// CAST-NEXT: ^bb2: // pred: ^bb1
64+
// CAST-NEXT: neura.br %3, %5 : i64, i32 to ^bb3
65+
// CAST-NEXT: ^bb3(%7: i64, %8: i32): // 2 preds: ^bb2, ^bb4
66+
// CAST-NEXT: %9 = "neura.icmp"(%7, %1) <{cmpType = "slt"}> : (i64, i64) -> i1
67+
// CAST-NEXT: neura.cond_br %9 : i1 then to ^bb4 else to ^bb5
68+
// CAST-NEXT: ^bb4: // pred: ^bb3
69+
// CAST-NEXT: %10 = neura.load_indexed %arg0[%4, %7 : i64, i64] memref<?x128xi32> : i32
70+
// CAST-NEXT: %11 = "neura.add"(%8, %10) : (i32, i32) -> i32
71+
// CAST-NEXT: %12 = "neura.add"(%7, %0) : (i64, i64) -> i64
72+
// CAST-NEXT: neura.br %12, %11 : i64, i32 to ^bb3
73+
// CAST-NEXT: ^bb5: // pred: ^bb3
74+
// CAST-NEXT: %13 = "neura.add"(%4, %0) : (i64, i64) -> i64
75+
// CAST-NEXT: neura.br %13, %8 : i64, i32 to ^bb1
76+
// CAST-NEXT: ^bb6: // pred: ^bb1
77+
// CAST-NEXT: "neura.return"(%5) : (i32) -> ()
78+
// CAST-NEXT: }
79+
5380
// CTRL2DATA: func.func @_Z27perfect_nested_reduction_2dPA128_i(%arg0: memref<?x128xi32>) -> i32 attributes {accelerator = "neura", llvm.linkage = #llvm.linkage<external>} {
5481
// CTRL2DATA-NEXT: %0 = "neura.constant"() <{predicate = true, value = 1 : index}> : () -> !neura.data<index, i1>
5582
// CTRL2DATA-NEXT: %1 = "neura.grant_always"(%0) : (!neura.data<index, i1>) -> !neura.data<index, i1>

test/controflow_fuse/simpleloop/simpleloop.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: mlir-opt %s --lower-affine --convert-scf-to-cf --convert-cf-to-llvm -o %t-llvm.mlir
22
// RUN: mlir-neura-opt %t-llvm.mlir --assign-accelerator --lower-arith-to-neura --lower-memref-to-neura --lower-builtin-to-neura --lower-llvm-to-neura | FileCheck %s
3+
// RUN: mlir-neura-opt %t-llvm.mlir --assign-accelerator --lower-arith-to-neura --lower-memref-to-neura --lower-builtin-to-neura --lower-llvm-to-neura --canonicalize-cast | FileCheck %s --check-prefix=CAST
34
// RUN: mlir-neura-opt %t-llvm.mlir --assign-accelerator --lower-arith-to-neura --lower-memref-to-neura --lower-builtin-to-neura --lower-llvm-to-neura --leverage-predicated-value --transform-ctrl-to-data-flow | FileCheck %s -check-prefix=CTRL2DATA
45

56
module attributes {} {
@@ -35,6 +36,23 @@ module attributes {} {
3536
// CHECK-NEXT: "neura.return"(%6) : (i32) -> ()
3637
// CHECK-NEXT: }
3738

39+
// CAST: func.func @_Z10simpleloopv() -> i32 attributes {accelerator = "neura", llvm.linkage = #llvm.linkage<external>} {
40+
// CAST-NEXT: %0 = "neura.constant"() <{predicate = true, value = 1 : i64}> : () -> i64
41+
// CAST-NEXT: %1 = "neura.constant"() <{predicate = true, value = 128 : i64}> : () -> i64
42+
// CAST-NEXT: %2 = "neura.constant"() <{predicate = true, value = 0 : i32}> : () -> i32
43+
// CAST-NEXT: %3 = "neura.constant"() <{predicate = true, value = 0 : i64}> : () -> i64
44+
// CAST-NEXT: neura.br %3, %2 : i64, i32 to ^bb1
45+
// CAST-NEXT: ^bb1(%4: i64, %5: i32): // 2 preds: ^bb0, ^bb2
46+
// CAST-NEXT: %6 = "neura.icmp"(%4, %1) <{cmpType = "slt"}> : (i64, i64) -> i1
47+
// CAST-NEXT: neura.cond_br %6 : i1 then to ^bb2 else to ^bb3
48+
// CAST-NEXT: ^bb2: // pred: ^bb1
49+
// CAST-NEXT: %7 = "neura.cast"(%4) <{cast_type = "i64_to_i32"}> : (i64) -> i32
50+
// CAST-NEXT: %8 = "neura.add"(%5, %7) : (i32, i32) -> i32
51+
// CAST-NEXT: %9 = "neura.add"(%4, %0) : (i64, i64) -> i64
52+
// CAST-NEXT: neura.br %9, %8 : i64, i32 to ^bb1
53+
// CAST-NEXT: ^bb3: // pred: ^bb1
54+
// CAST-NEXT: "neura.return"(%5) : (i32) -> ()
55+
// CAST-NEXT: }
3856

3957
// CTRL2DATA: func.func @_Z10simpleloopv() -> i32 attributes {accelerator = "neura", llvm.linkage = #llvm.linkage<external>} {
4058
// CTRL2DATA-NEXT: %0 = "neura.constant"() <{predicate = true, value = 1 : index}> : () -> !neura.data<index, i1>

0 commit comments

Comments
 (0)