@@ -1444,28 +1444,119 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
14441444 }
14451445};
14461446
1447+ class CIRSwitchOpLowering : public mlir ::OpConversionPattern<cir::SwitchOp> {
1448+ public:
1449+ using OpConversionPattern<cir::SwitchOp>::OpConversionPattern;
1450+
1451+ mlir::LogicalResult
1452+ matchAndRewrite (cir::SwitchOp op, OpAdaptor adaptor,
1453+ mlir::ConversionPatternRewriter &rewriter) const override {
1454+ rewriter.setInsertionPointAfter (op);
1455+ llvm::SmallVector<CaseOp> cases;
1456+ if (!op.isSimpleForm (cases))
1457+ llvm_unreachable (" NYI" );
1458+
1459+ llvm::SmallVector<int64_t > caseValues;
1460+ // Maps the index of a CaseOp in `cases`, to the index in `caseValues`.
1461+ // This is necessary because some CaseOp might carry 0 or multiple values.
1462+ llvm::DenseMap<size_t , unsigned > indexMap;
1463+ caseValues.reserve (cases.size ());
1464+ for (auto [i, caseOp] : llvm::enumerate (cases)) {
1465+ switch (caseOp.getKind ()) {
1466+ case CaseOpKind::Equal: {
1467+ auto valueAttr = caseOp.getValue ()[0 ];
1468+ auto value = cast<cir::IntAttr>(valueAttr);
1469+ indexMap[i] = caseValues.size ();
1470+ caseValues.push_back (value.getUInt ());
1471+ break ;
1472+ }
1473+ case CaseOpKind::Default:
1474+ break ;
1475+ case CaseOpKind::Range:
1476+ case CaseOpKind::Anyof:
1477+ llvm_unreachable (" NYI" );
1478+ }
1479+ }
1480+
1481+ auto operand = adaptor.getOperands ()[0 ];
1482+ // `scf.index_switch` expects an index of type `index`.
1483+ auto indexType = mlir::IndexType::get (getContext ());
1484+ auto indexCast = rewriter.create <mlir::arith::IndexCastOp>(
1485+ op.getLoc (), indexType, operand);
1486+ auto indexSwitch = rewriter.create <mlir::scf::IndexSwitchOp>(
1487+ op.getLoc (), mlir::TypeRange{}, indexCast, caseValues, cases.size ());
1488+
1489+ bool metDefault = false ;
1490+ for (auto [i, caseOp] : llvm::enumerate (cases)) {
1491+ auto ®ion = caseOp.getRegion ();
1492+ switch (caseOp.getKind ()) {
1493+ case CaseOpKind::Equal: {
1494+ auto &caseRegion = indexSwitch.getCaseRegions ()[indexMap[i]];
1495+ rewriter.inlineRegionBefore (region, caseRegion, caseRegion.end ());
1496+ break ;
1497+ }
1498+ case CaseOpKind::Default: {
1499+ auto &defaultRegion = indexSwitch.getDefaultRegion ();
1500+ rewriter.inlineRegionBefore (region, defaultRegion, defaultRegion.end ());
1501+ metDefault = true ;
1502+ break ;
1503+ }
1504+ case CaseOpKind::Range:
1505+ case CaseOpKind::Anyof:
1506+ llvm_unreachable (" NYI" );
1507+ }
1508+ }
1509+
1510+ // `scf.index_switch` expects its default region to contain exactly one
1511+ // block. If we don't have a default region in `cir.switch`, we need to
1512+ // supply it here.
1513+ if (!metDefault) {
1514+ auto &defaultRegion = indexSwitch.getDefaultRegion ();
1515+ mlir::Block *block =
1516+ rewriter.createBlock (&defaultRegion, defaultRegion.end ());
1517+ rewriter.setInsertionPointToEnd (block);
1518+ rewriter.create <mlir::scf::YieldOp>(op.getLoc ());
1519+ }
1520+
1521+ // The final `cir.break` should be replaced to `scf.yield`.
1522+ // After MLIRLoweringPrepare pass, every case must end with a `cir.break`.
1523+ for (auto ®ion : indexSwitch.getCaseRegions ()) {
1524+ auto &lastBlock = region.back ();
1525+ auto &lastOp = lastBlock.back ();
1526+ assert (isa<BreakOp>(lastOp));
1527+ rewriter.setInsertionPointAfter (&lastOp);
1528+ rewriter.replaceOpWithNewOp <mlir::scf::YieldOp>(&lastOp);
1529+ }
1530+
1531+ rewriter.replaceOp (op, indexSwitch);
1532+
1533+ return mlir::success ();
1534+ }
1535+ };
1536+
14471537void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
14481538 mlir::TypeConverter &converter) {
14491539 patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
14501540
14511541 patterns
1452- .add <CIRATanOpLowering, CIRCmpOpLowering, CIRCallOpLowering,
1453- CIRUnaryOpLowering, CIRBinOpLowering, CIRLoadOpLowering,
1454- CIRConstantOpLowering, CIRStoreOpLowering, CIRAllocaOpLowering,
1455- CIRFuncOpLowering, CIRScopeOpLowering, CIRBrCondOpLowering,
1456- CIRTernaryOpLowering, CIRYieldOpLowering, CIRCosOpLowering,
1457- CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRCastOpLowering,
1458- CIRPtrStrideOpLowering, CIRSqrtOpLowering, CIRCeilOpLowering,
1459- CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
1460- CIRAbsOpLowering, CIRFloorOpLowering, CIRLog10OpLowering,
1461- CIRLog2OpLowering, CIRLogOpLowering, CIRRoundOpLowering,
1462- CIRPtrStrideOpLowering, CIRSinOpLowering, CIRShiftOpLowering,
1463- CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitPopcountOpLowering,
1464- CIRBitClrsbOpLowering, CIRBitFfsOpLowering, CIRBitParityOpLowering,
1465- CIRIfOpLowering, CIRVectorCreateLowering, CIRVectorInsertLowering,
1466- CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
1467- CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
1468- CIRTrapOpLowering>(converter, patterns.getContext ());
1542+ .add <CIRSwitchOpLowering, CIRATanOpLowering, CIRCmpOpLowering,
1543+ CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
1544+ CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
1545+ CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
1546+ CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
1547+ CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
1548+ CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
1549+ CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering,
1550+ CIRFAbsOpLowering, CIRAbsOpLowering, CIRFloorOpLowering,
1551+ CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
1552+ CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering,
1553+ CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
1554+ CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
1555+ CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
1556+ CIRVectorInsertLowering, CIRVectorExtractLowering,
1557+ CIRVectorCmpOpLowering, CIRACosOpLowering, CIRASinOpLowering,
1558+ CIRUnreachableOpLowering, CIRTanOpLowering, CIRTrapOpLowering>(
1559+ converter, patterns.getContext ());
14691560}
14701561
14711562static mlir::TypeConverter prepareTypeConverter () {
@@ -1571,6 +1662,7 @@ mlir::ModuleOp lowerFromCIRToMLIRToLLVMDialect(mlir::ModuleOp theModule,
15711662
15721663 mlir::PassManager pm (mlirCtx);
15731664
1665+ pm.addPass (createMLIRCoreDialectsLoweringPreparePass ());
15741666 pm.addPass (createConvertCIRToMLIRPass ());
15751667 pm.addPass (createConvertMLIRToLLVMPass ());
15761668
@@ -1616,6 +1708,7 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
16161708
16171709 mlir::PassManager pm (mlirCtx);
16181710
1711+ pm.addPass (createMLIRCoreDialectsLoweringPreparePass ());
16191712 pm.addPass (createConvertCIRToMLIRPass ());
16201713
16211714 auto result = !mlir::failed (pm.run (theModule));
0 commit comments