11#include " Conversion/LlvmToNeura/LlvmToNeura.h"
2+ #include " Common/AcceleratorAttrs.h"
23#include " NeuraDialect/NeuraDialect.h"
34#include " NeuraDialect/NeuraOps.h"
45#include " mlir/Dialect/LLVMIR/LLVMAttrs.h"
@@ -75,6 +76,110 @@ struct LlvmVFMulToNeuraVFMul: public OpRewritePattern<mlir::LLVM::FMulOp> {
7576 }
7677};
7778
79+ struct LlvmICmpToNeuraICmp : public OpRewritePattern <LLVM::ICmpOp> {
80+ using OpRewritePattern::OpRewritePattern;
81+
82+ LogicalResult matchAndRewrite (LLVM::ICmpOp op,
83+ PatternRewriter &rewriter) const override {
84+ auto pred = op.getPredicate ();
85+ auto lhs = op.getLhs ();
86+ auto rhs = op.getRhs ();
87+ auto resultType = op.getType ();
88+
89+ rewriter.replaceOpWithNewOp <neura::ICmpOp>(
90+ op, resultType, lhs, rhs, rewriter.getStringAttr (LLVM::stringifyICmpPredicate (pred)));
91+ return success ();
92+ }
93+ };
94+
95+ struct LlvmGEPToNeuraGEP : public OpRewritePattern <mlir::LLVM::GEPOp> {
96+ using OpRewritePattern::OpRewritePattern;
97+
98+ LogicalResult matchAndRewrite (mlir::LLVM::GEPOp op,
99+ PatternRewriter &rewriter) const override {
100+ Value base = op.getBase ();
101+ SmallVector<Value> indexValues;
102+
103+ for (auto gepIndex : op.getIndices ()) {
104+ if (auto val = gepIndex.dyn_cast <Value>()) {
105+ indexValues.push_back (val);
106+ } else if (auto intAttr = gepIndex.dyn_cast <IntegerAttr>()) {
107+ auto cst = rewriter.create <neura::ConstantOp>(
108+ op.getLoc (), rewriter.getIndexType (), intAttr);
109+ indexValues.push_back (cst);
110+ } else {
111+ return op.emitOpError (" Unsupported GEP index kind" );
112+ }
113+ }
114+
115+ rewriter.replaceOpWithNewOp <neura::GEP>(op, op.getType (), base, indexValues);
116+ return success ();
117+ }
118+ };
119+
120+ struct LlvmLoadToNeuraLoad : public OpRewritePattern <mlir::LLVM::LoadOp> {
121+ using OpRewritePattern::OpRewritePattern;
122+
123+ LogicalResult matchAndRewrite (mlir::LLVM::LoadOp op,
124+ PatternRewriter &rewriter) const override {
125+ Value ptr = op.getAddr (); // getPointer() is deprecated
126+ Type resultType = op.getResult ().getType ();
127+ rewriter.replaceOpWithNewOp <neura::LoadOp>(op, resultType, ptr);
128+ return success ();
129+ }
130+ };
131+
132+ struct LlvmStoreToNeuraStore : public OpRewritePattern <mlir::LLVM::StoreOp> {
133+ using OpRewritePattern::OpRewritePattern;
134+
135+ LogicalResult matchAndRewrite (mlir::LLVM::StoreOp op,
136+ PatternRewriter &rewriter) const override {
137+ Value value = op.getValue ();
138+ Value addr = op.getAddr (); // getPointer() is deprecated
139+ rewriter.replaceOpWithNewOp <neura::StoreOp>(op, value, addr);
140+ return success ();
141+ }
142+ };
143+
144+ struct LlvmCondBrToNeuraCondBr : public OpRewritePattern <LLVM::CondBrOp> {
145+ using OpRewritePattern::OpRewritePattern;
146+ LogicalResult matchAndRewrite (LLVM::CondBrOp op,
147+ PatternRewriter &rewriter) const override {
148+ // Get the source operation's successors (basic blocks)
149+ Block *trueDest = op.getTrueDest ();
150+ Block *falseDest = op.getFalseDest ();
151+
152+ // Get the operands for each destination
153+ ValueRange trueOperands = op.getTrueDestOperands ();
154+ ValueRange falseOperands = op.getFalseDestOperands ();
155+
156+ // Create the new operation with proper successors
157+ auto newOp = rewriter.create <neura::CondBr>(
158+ op.getLoc (), // Location
159+ op.getCondition (), // Condition
160+ trueOperands, // True destination operands
161+ falseOperands, // False destination operands
162+ trueDest, // True destination block
163+ falseDest // False destination block
164+ );
165+
166+ // Replace the old op with the new one
167+ rewriter.replaceOp (op, newOp->getResults ());
168+
169+ return success ();
170+ }
171+ };
172+
173+ struct LlvmReturnToNeuraReturn : public OpRewritePattern <LLVM::ReturnOp> {
174+ using OpRewritePattern::OpRewritePattern;
175+
176+ LogicalResult matchAndRewrite (LLVM::ReturnOp op,
177+ PatternRewriter &rewriter) const override {
178+ rewriter.replaceOpWithNewOp <neura::ReturnOp>(op, op.getOperands ());
179+ return success ();
180+ }
181+ };
182+
78183struct LowerLlvmToNeuraPass
79184 : public PassWrapper<LowerLlvmToNeuraPass, OperationPass<ModuleOp>> {
80185
@@ -96,17 +201,27 @@ struct LowerLlvmToNeuraPass
96201 patterns.add <LlvmAddToNeuraAdd>(&getContext ());
97202 patterns.add <LlvmFMulToNeuraFMul>(&getContext ());
98203 patterns.add <LlvmVFMulToNeuraVFMul>(&getContext ());
204+ patterns.add <LlvmICmpToNeuraICmp>(&getContext ());
205+ patterns.add <LlvmGEPToNeuraGEP>(&getContext ());
206+ patterns.add <LlvmLoadToNeuraLoad>(&getContext ());
207+ patterns.add <LlvmStoreToNeuraStore>(&getContext ());
208+ patterns.add <LlvmCondBrToNeuraCondBr>(&getContext ());
209+ patterns.add <LlvmReturnToNeuraReturn>(&getContext ());
210+
99211 FrozenRewritePatternSet frozen (std::move (patterns));
100212
101213 ModuleOp module_op = getOperation ();
102214
103215 // Applies to every region inside the module (regardless of func type,
104216 // e.g., mlir func or llvm func).
105- module_op.walk ([&](Operation *op) {
106- if (!op->getRegions ().empty ()) {
107- for (Region ®ion : op->getRegions ()) {
108- if (failed (applyPatternsAndFoldGreedily (region, frozen))) {
109- signalPassFailure ();
217+ module_op.walk ([&](FunctionOpInterface func) {
218+ if (func->hasAttr (mlir::accel::kAcceleratorAttr )) {
219+ auto target = func->getAttrOfType <StringAttr>(mlir::accel::kAcceleratorAttr );
220+ if (target && target.getValue () == mlir::accel::kNeuraTarget ) {
221+ for (Region ®ion : func->getRegions ()) {
222+ if (failed (applyPatternsAndFoldGreedily (region, frozen))) {
223+ signalPassFailure ();
224+ }
110225 }
111226 }
112227 }
0 commit comments