@@ -1496,29 +1496,118 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
14961496  }
14971497};
14981498
1499+ class  CIRSwitchOpLowering  : public  mlir ::OpConversionPattern<cir::SwitchOp> {
1500+ public: 
1501+   using  OpConversionPattern<cir::SwitchOp>::OpConversionPattern;
1502+ 
1503+   mlir::LogicalResult
1504+   matchAndRewrite (cir::SwitchOp op, OpAdaptor adaptor,
1505+                   mlir::ConversionPatternRewriter &rewriter) const  override  {
1506+     rewriter.setInsertionPointAfter (op);
1507+     llvm::SmallVector<CaseOp> cases;
1508+     if  (!op.isSimpleForm (cases))
1509+       mlir::emitError (op.getLoc (), " not yet implemented" 
1510+ 
1511+     llvm::SmallVector<int64_t > caseValues;
1512+     //  Maps the index of a CaseOp in `cases`, to the index in `caseValues`.
1513+     //  This is necessary because some CaseOp might carry 0 or multiple values.
1514+     llvm::DenseMap<size_t , unsigned > indexMap;
1515+     caseValues.reserve (cases.size ());
1516+     for  (auto  [i, caseOp] : llvm::enumerate (cases)) {
1517+       switch  (caseOp.getKind ()) {
1518+       case  CaseOpKind::Equal: {
1519+         auto  valueAttr = caseOp.getValue ()[0 ];
1520+         auto  value = cast<cir::IntAttr>(valueAttr);
1521+         indexMap[i] = caseValues.size ();
1522+         caseValues.push_back (value.getUInt ());
1523+         break ;
1524+       }
1525+       case  CaseOpKind::Default:
1526+         break ;
1527+       case  CaseOpKind::Range:
1528+       case  CaseOpKind::Anyof:
1529+         llvm_unreachable (" NYI" 
1530+       }
1531+     }
1532+ 
1533+     auto  operand = adaptor.getOperands ()[0 ];
1534+     //  `scf.index_switch` expects an index of type `index`.
1535+     auto  indexType = mlir::IndexType::get (getContext ());
1536+     auto  indexCast = rewriter.create <mlir::arith::IndexCastOp>(
1537+         op.getLoc (), indexType, operand);
1538+     auto  indexSwitch = rewriter.create <mlir::scf::IndexSwitchOp>(
1539+         op.getLoc (), mlir::TypeRange{}, indexCast, caseValues, cases.size ());
1540+ 
1541+     bool  metDefault = false ;
1542+     for  (auto  [i, caseOp] : llvm::enumerate (cases)) {
1543+       auto  ®ion = caseOp.getRegion ();
1544+       switch  (caseOp.getKind ()) {
1545+       case  CaseOpKind::Equal: {
1546+         auto  &caseRegion = indexSwitch.getCaseRegions ()[indexMap[i]];
1547+         rewriter.inlineRegionBefore (region, caseRegion, caseRegion.end ());
1548+         break ;
1549+       }
1550+       case  CaseOpKind::Default: {
1551+         auto  &defaultRegion = indexSwitch.getDefaultRegion ();
1552+         rewriter.inlineRegionBefore (region, defaultRegion, defaultRegion.end ());
1553+         metDefault = true ;
1554+         break ;
1555+       }
1556+       case  CaseOpKind::Range:
1557+       case  CaseOpKind::Anyof:
1558+         llvm_unreachable (" NYI" 
1559+       }
1560+     }
1561+ 
1562+     //  `scf.index_switch` expects its default region to contain exactly one
1563+     //  block. If we don't have a default region in `cir.switch`, we need to
1564+     //  supply it here.
1565+     if  (!metDefault) {
1566+       auto  &defaultRegion = indexSwitch.getDefaultRegion ();
1567+       mlir::Block *block =
1568+           rewriter.createBlock (&defaultRegion, defaultRegion.end ());
1569+       rewriter.setInsertionPointToEnd (block);
1570+       rewriter.create <mlir::scf::YieldOp>(op.getLoc ());
1571+     }
1572+ 
1573+     //  The final `cir.break` should be replaced to `scf.yield`.
1574+     //  After MLIRLoweringPrepare pass, every case must end with a `cir.break`.
1575+     for  (auto  ®ion : indexSwitch.getCaseRegions ()) {
1576+       auto  &lastBlock = region.back ();
1577+       auto  &lastOp = lastBlock.back ();
1578+       assert (isa<BreakOp>(lastOp));
1579+       rewriter.setInsertionPointAfter (&lastOp);
1580+       rewriter.replaceOpWithNewOp <mlir::scf::YieldOp>(&lastOp);
1581+     }
1582+ 
1583+     rewriter.replaceOp (op, indexSwitch);
1584+ 
1585+     return  mlir::success ();
1586+   }
1587+ };
1588+ 
14991589void  populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
15001590                                         mlir::TypeConverter &converter) {
15011591  patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
15021592
1503-   patterns
1504-       .add <CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1505-            CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1506-            CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1507-            CIRFuncOpLowering, CIRScopeOpLowering, CIRBrCondOpLowering,
1508-            CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1509-            CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1510-            CIRPtrStrideOpLowering, CIRGetElementOpLowering, CIRSqrtOpLowering,
1511-            CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1512-            CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1513-            CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1514-            CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering,
1515-            CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1516-            CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1517-            CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
1518-            CIRVectorInsertLowering, CIRVectorExtractLowering,
1519-            CIRVectorCmpOpLowering, CIRACosOpLowering, CIRASinOpLowering,
1520-            CIRUnreachableOpLowering, CIRTanOpLowering, CIRTrapOpLowering>(
1521-           converter, patterns.getContext ());
1593+   patterns.add <
1594+       CIRSwitchOpLowering, CIRGetElementOpLowering, CIRATanOpLowering,
1595+       CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1596+       CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1597+       CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
1598+       CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
1599+       CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1600+       CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
1601+       CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1602+       CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
1603+       CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering,
1604+       CIRPtrStrideOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1605+       CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1606+       CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1607+       CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1608+       CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
1609+       CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
1610+       CIRTrapOpLowering>(converter, patterns.getContext ());
15221611}
15231612
15241613static  mlir::TypeConverter prepareTypeConverter () {
@@ -1624,6 +1713,7 @@ mlir::ModuleOp lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule,
16241713
16251714  mlir::PassManager pm (mlirCtx);
16261715
1716+   pm.addPass (createMLIRCoreDialectsLoweringPreparePass ());
16271717  pm.addPass (createConvertCIRToMLIRPass ());
16281718  pm.addPass (createConvertMLIRToLLVMPass ());
16291719
@@ -1669,6 +1759,7 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
16691759
16701760  mlir::PassManager pm (mlirCtx);
16711761
1762+   pm.addPass (createMLIRCoreDialectsLoweringPreparePass ());
16721763  pm.addPass (createConvertCIRToMLIRPass ());
16731764
16741765  auto  result = !mlir::failed (pm.run (theModule));
0 commit comments