@@ -73,8 +73,7 @@ struct LlvmFSubToNeuraFSub : public OpRewritePattern<mlir::LLVM::FSubOp> {
7373 }
7474
7575 // Sets optional predicate: default to 'none'.
76- rewriter.replaceOpWithNewOp <neura::FSubOp>(op, result_type, lhs, rhs,
77- Value ());
76+ rewriter.replaceOpWithNewOp <neura::FSubOp>(op, result_type, lhs, rhs);
7877 return success ();
7978 }
8079};
@@ -109,6 +108,35 @@ struct LlvmFMulToNeuraFMul : public OpRewritePattern<mlir::LLVM::FMulOp> {
109108 }
110109};
111110
111+ struct LlvmSDivToNeuraDiv : public OpRewritePattern <LLVM::SDivOp> {
112+ using OpRewritePattern::OpRewritePattern;
113+
114+ LogicalResult matchAndRewrite (LLVM::SDivOp op,
115+ PatternRewriter &rewriter) const override {
116+ Value lhs = op.getLhs ();
117+ Value rhs = op.getRhs ();
118+ Type resultType = op.getType ();
119+
120+ rewriter.replaceOpWithNewOp <neura::DivOp>(op, resultType, lhs, rhs);
121+ return success ();
122+ }
123+ };
124+
125+ struct LlvmSRemToNeuraRem : public OpRewritePattern <LLVM::SRemOp> {
126+ using OpRewritePattern<LLVM::SRemOp>::OpRewritePattern;
127+
128+ LogicalResult matchAndRewrite (LLVM::SRemOp op,
129+ PatternRewriter &rewriter) const override {
130+ Value lhs = op.getLhs ();
131+ Value rhs = op.getRhs ();
132+ Type resultType = op.getType ();
133+
134+ // Create neura.rem operation to replace llvm.srem
135+ rewriter.replaceOpWithNewOp <neura::RemOp>(op, resultType, lhs, rhs);
136+ return success ();
137+ }
138+ };
139+
112140struct LlvmVFMulToNeuraVFMul : public OpRewritePattern <mlir::LLVM::FMulOp> {
113141 using OpRewritePattern::OpRewritePattern;
114142
@@ -311,11 +339,11 @@ struct LlvmAllocaToNeuraAlloca : public OpRewritePattern<LLVM::AllocaOp> {
311339 PatternRewriter &rewriter) const override {
312340 Value size = op.getArraySize ();
313341 Type resultType = op.getType ();
314-
342+
315343 // Converts the size to neura.data<i32, i1> if it's not already.
316344 // Assumes the size is already in the right format.
317345 // Handles type conversion here.
318-
346+
319347 rewriter.replaceOpWithNewOp <neura::AllocaOp>(op, resultType, size);
320348 return success ();
321349 }
@@ -328,7 +356,7 @@ struct LlvmSExtToNeuraSExt : public OpRewritePattern<LLVM::SExtOp> {
328356 PatternRewriter &rewriter) const override {
329357 Value input = op.getArg ();
330358 Type resultType = op.getType ();
331-
359+
332360 rewriter.replaceOpWithNewOp <neura::SExtOp>(op, resultType, input);
333361 return success ();
334362 }
@@ -341,7 +369,7 @@ struct LlvmZExtToNeuraZExt : public OpRewritePattern<LLVM::ZExtOp> {
341369 PatternRewriter &rewriter) const override {
342370 Value input = op.getArg ();
343371 Type resultType = op.getType ();
344-
372+
345373 rewriter.replaceOpWithNewOp <neura::ZExtOp>(op, resultType, input);
346374 return success ();
347375 }
@@ -355,36 +383,48 @@ struct LlvmMulToNeuraMul : public OpRewritePattern<LLVM::MulOp> {
355383 Value lhs = op.getLhs ();
356384 Value rhs = op.getRhs ();
357385 Type resultType = op.getType ();
358-
386+
359387 rewriter.replaceOpWithNewOp <neura::MulOp>(op, resultType, lhs, rhs);
360388 return success ();
361389 }
362390};
363391
392+ struct LlvmShlToNeuraShl : public OpRewritePattern <LLVM::ShlOp> {
393+ using OpRewritePattern::OpRewritePattern;
394+
395+ LogicalResult matchAndRewrite (LLVM::ShlOp op,
396+ PatternRewriter &rewriter) const override {
397+ Value lhs = op.getLhs ();
398+ Value rhs = op.getRhs ();
399+ Type resultType = op.getType ();
400+
401+ rewriter.replaceOpWithNewOp <neura::ShlOp>(op, resultType, lhs, rhs);
402+ return success ();
403+ }
404+ };
405+
364406struct LlvmFuncToNeuraFunc : public OpRewritePattern <LLVM::LLVMFuncOp> {
365407 using OpRewritePattern::OpRewritePattern;
366408
367409 LogicalResult matchAndRewrite (LLVM::LLVMFuncOp op,
368410 PatternRewriter &rewriter) const override {
369411
370-
371412 auto target = op->getAttrOfType <StringAttr>(mlir::accel::kAcceleratorAttr );
372413 if (!target || target.getValue () != mlir::accel::kNeuraTarget ) {
373414 return failure ();
374415 }
375416
376417 // Converts LLVMFunctionType to FunctionType.
377418 auto llvmFuncType = op.getFunctionType ();
378- auto funcType = rewriter.getFunctionType (
379- llvmFuncType.getParams (),
380- llvmFuncType.getReturnType ()
381- );
419+ auto funcType = rewriter.getFunctionType (llvmFuncType.getParams (),
420+ llvmFuncType.getReturnType ());
382421
383- // Creates the new func.func operation using OperationState to have full control.
422+ // Creates the new func.func operation using OperationState to have full
423+ // control.
384424 OperationState state (op.getLoc (), func::FuncOp::getOperationName ());
385425 state.addAttribute (" sym_name" , rewriter.getStringAttr (op.getName ()));
386426 state.addAttribute (" function_type" , TypeAttr::get (funcType));
387-
427+
388428 // Copies ALL attributes from the original llvm.func exactly as they are.
389429 // Skips function type and name attributes as they are handled separately.
390430 SmallVector<NamedAttribute> attrs;
@@ -395,15 +435,16 @@ struct LlvmFuncToNeuraFunc : public OpRewritePattern<LLVM::LLVMFuncOp> {
395435 attrs.push_back (attr);
396436 }
397437 state.addAttributes (attrs);
398-
438+
399439 // Adds the function body region.
400440 state.addRegion ();
401-
441+
402442 auto newFunc = cast<func::FuncOp>(rewriter.create (state));
403443
404444 // Moves the function body.
405- rewriter.inlineRegionBefore (op.getBody (), newFunc.getBody (), newFunc.getBody ().end ());
406-
445+ rewriter.inlineRegionBefore (op.getBody (), newFunc.getBody (),
446+ newFunc.getBody ().end ());
447+
407448 // Replaces the old function.
408449 rewriter.replaceOp (op, newFunc);
409450 return success ();
@@ -435,20 +476,19 @@ struct LlvmCallToFuncCall : public OpRewritePattern<LLVM::CallOp> {
435476
436477 // Gets the result types from the function signature.
437478 auto resultTypes = funcOp.getFunctionType ().getResults ();
438-
479+
439480 // Converts the call to func.call.
440481 auto newCall = rewriter.create <func::CallOp>(
441- op.getLoc (), resultTypes, callee.value (), op.getArgOperands ()
442- );
443-
482+ op.getLoc (), resultTypes, callee.value (), op.getArgOperands ());
483+
444484 // Replaces the old call with the new one.
445485 // Handles both cases: calls with results and calls without results.
446486 if (op.getNumResults () == 0 ) {
447487 rewriter.eraseOp (op);
448488 } else {
449489 rewriter.replaceOp (op, newCall->getResults ());
450490 }
451-
491+
452492 return success ();
453493 }
454494};
@@ -494,6 +534,9 @@ struct LowerLlvmToNeuraPass
494534 patterns.add <LlvmMulToNeuraMul>(&getContext ());
495535 patterns.add <LlvmFuncToNeuraFunc>(&getContext ());
496536 patterns.add <LlvmCallToFuncCall>(&getContext ());
537+ patterns.add <LlvmShlToNeuraShl>(&getContext ());
538+ patterns.add <LlvmSDivToNeuraDiv>(&getContext ());
539+ patterns.add <LlvmSRemToNeuraRem>(&getContext ());
497540
498541 FrozenRewritePatternSet frozen (std::move (patterns));
499542
0 commit comments