1+ #include " NeuraDialect/NeuraOps.h"
2+ #include " mlir/Dialect/Func/IR/FuncOps.h"
3+ #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
4+ #include " mlir/IR/Block.h"
5+ #include " mlir/IR/BuiltinAttributes.h"
6+ #include " mlir/IR/Operation.h"
7+ #include " mlir/IR/Region.h"
8+ #include " mlir/IR/Value.h"
9+ #include " mlir/Pass/Pass.h"
10+ #include " mlir/Support/LLVM.h"
11+
12+ using namespace mlir ;
13+
14+ namespace {
15+
16+ LogicalResult canonicalizeCast (Region ®ion) {
17+ // Handles block arguments.
18+ for (Block &block : region.getBlocks ()) {
19+ for (BlockArgument arg : block.getArguments ()) {
20+ if (arg.getType ().isIndex ()) {
21+ // Replaces index type with i64.
22+ arg.setType (IntegerType::get (arg.getContext (), 64 ));
23+ }
24+ }
25+ }
26+
27+ region.walk ([&](Operation *op) {
28+ // Handles the value attributes in neura::ConstantOp.
29+ if (isa<neura::ConstantOp>(op)) {
30+ Attribute value_attr = op->getAttr (" value" );
31+ if (!value_attr) {
32+ return ;
33+ }
34+ if (IntegerAttr int_attr = dyn_cast<IntegerAttr>(value_attr)) {
35+ if (isa<IntegerType>(op->getResult (0 ).getType ())) {
36+ return ;
37+ }
38+ if (isa<IndexType>(op->getResult (0 ).getType ())) {
39+ IntegerAttr new_attr = IntegerAttr::get (
40+ IntegerType::get (op->getContext (), 64 ), int_attr.getInt ());
41+ op->setAttr (" value" , new_attr);
42+ }
43+ }
44+ }
45+
46+ // Replaces all index types with i64.
47+ for (OpResult result : op->getOpResults ()) {
48+ auto type = result.getType ();
49+ if (isa<IndexType>(type)) {
50+ result.setType (mlir::IntegerType::get (op->getContext (), 64 ));
51+ }
52+ }
53+
54+ if (neura::CastOp cast_op = dyn_cast<neura::CastOp>(op)) {
55+ StringAttr cast_type_attr =
56+ cast_op->getAttrOfType <StringAttr>(" cast_type" );
57+ if (!cast_type_attr)
58+ return ;
59+ StringRef cast_type = cast_type_attr.getValue ();
60+
61+ Type src_type = cast_op->getOperand (0 ).getType ();
62+ Type dst_type = cast_op->getResult (0 ).getType ();
63+
64+ // Reomoves the index->i64 or i64->index cast operations.
65+ if ((cast_type == " index_to_int" && isa<IntegerType>(src_type) &&
66+ isa<IntegerType>(dst_type) &&
67+ dyn_cast<IntegerType>(src_type).getWidth () == 64 &&
68+ dyn_cast<IntegerType>(dst_type).getWidth () == 64 ) ||
69+ (cast_type == " int_to_index" && isa<IntegerType>(src_type) &&
70+ isa<IntegerType>(dst_type) &&
71+ dyn_cast<IntegerType>(src_type).getWidth () == 64 &&
72+ dyn_cast<IntegerType>(dst_type).getWidth () == 64 )) {
73+ cast_op->getResult (0 ).replaceAllUsesWith (cast_op->getOperand (0 ));
74+ cast_op->erase ();
75+ return ;
76+ }
77+
78+ // Changes index->i32 or i32->index casts to i64->i32 or i32->i64.
79+ if (cast_type == " index_to_int" && isa<IntegerType>(dst_type) &&
80+ dyn_cast<IntegerType>(dst_type).getWidth () == 32 ) {
81+ cast_op->setAttr (" cast_type" ,
82+ StringAttr::get (op->getContext (), " i64_to_i32" ));
83+ return ;
84+ }
85+ if (cast_type == " int_to_index" && isa<IntegerType>(src_type) &&
86+ dyn_cast<IntegerType>(src_type).getWidth () == 32 ) {
87+ cast_op->setAttr (" cast_type" ,
88+ StringAttr::get (op->getContext (), " i32_to_i64" ));
89+ return ;
90+ }
91+ // TODO: Handles other cast types if needed.
92+ }
93+ });
94+ return success ();
95+ }
96+
97+ struct CanonicalizeCastPass
98+ : public PassWrapper<CanonicalizeCastPass, OperationPass<ModuleOp>> {
99+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (CanonicalizeCastPass)
100+ StringRef getArgument () const override { return " canonicalize-cast" ; }
101+ StringRef getDescription () const override {
102+ return " Canonicalizes cast operations in the Neura dialect, specifically "
103+ " removing unnecessary index to i64 casts and vice versa." ;
104+ }
105+
106+ void runOnOperation () override {
107+ auto module_op = getOperation ();
108+
109+ module_op.walk ([&](Operation *op) {
110+ Region *region = nullptr ;
111+ if (auto func_op = dyn_cast<func::FuncOp>(op)) {
112+ auto accel_attr = func_op->getAttrOfType <StringAttr>(" accelerator" );
113+ if (!accel_attr || accel_attr.getValue () != " neura" ) {
114+ return ;
115+ }
116+ region = &func_op.getBody ();
117+ } else if (auto llvm_func = dyn_cast<LLVM::LLVMFuncOp>(op)) {
118+ auto accel_attr = llvm_func->getAttrOfType <StringAttr>(" accelerator" );
119+ if (!accel_attr || accel_attr.getValue () != " neura" ) {
120+ return ;
121+ }
122+ region = &llvm_func.getBody ();
123+ } else {
124+ return ;
125+ }
126+
127+ if (!region || region->empty ()) {
128+ return ;
129+ }
130+
131+ if (failed (canonicalizeCast (*region))) {
132+ signalPassFailure ();
133+ return ;
134+ }
135+ });
136+ }
137+ };
138+ } // namespace
139+
140+ namespace mlir ::neura {
141+ std::unique_ptr<mlir::Pass> createCanonicalizeCastPass () {
142+ return std::make_unique<CanonicalizeCastPass>();
143+ }
144+ } // namespace mlir::neura
0 commit comments