Skip to content

Commit

Permalink
[CIR][ThroughMLIR] fix BinOp, CmpOp Lowering to MLIR and lowering cir…
Browse files Browse the repository at this point in the history
….vec.cmp to MLIR (#694)

This PR does Three things:
1. Fixes the BinOp lowering to MLIR issue where signed numbers were not
handled correctly, and adds support for vector types. The corresponding
test files have been modified.
2. Fixes the CmpOp lowering to MLIR issue where signed numbers were not
handled correctly And modified test files.
3. Adds cir.vec.cmp lowering to MLIR along with the corresponding test
files.

I originally planned to complete the remaining cir.vec.* lowerings in
this PR, but it seems there's quite a lot to do, so I'll split it into
multiple PRs.

---------

Co-authored-by: Kritoooo <[email protected]>
  • Loading branch information
2 people authored and lanza committed Jun 21, 2024
1 parent 9ea95d4 commit 9dbce79
Show file tree
Hide file tree
Showing 11 changed files with 518 additions and 392 deletions.
243 changes: 81 additions & 162 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,52 +628,60 @@ class CIRBinOpLowering : public mlir::OpConversionPattern<mlir::cir::BinOp> {
"inconsistent operands' types not supported yet");
mlir::Type mlirType = getTypeConverter()->convertType(op.getType());
assert((mlirType.isa<mlir::IntegerType>() ||
mlirType.isa<mlir::FloatType>()) &&
mlirType.isa<mlir::FloatType>() ||
mlirType.isa<mlir::VectorType>()) &&
"operand type not supported yet");

auto type = op.getLhs().getType();
if (auto VecType = type.dyn_cast<mlir::cir::VectorType>()) {
type = VecType.getEltType();
}

switch (op.getKind()) {
case mlir::cir::BinOpKind::Add:
if (mlirType.isa<mlir::IntegerType>())
if (type.isa<mlir::cir::IntType>())
rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
else
rewriter.replaceOpWithNewOp<mlir::arith::AddFOp>(
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
break;
case mlir::cir::BinOpKind::Sub:
if (mlirType.isa<mlir::IntegerType>())
if (type.isa<mlir::cir::IntType>())
rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
else
rewriter.replaceOpWithNewOp<mlir::arith::SubFOp>(
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
break;
case mlir::cir::BinOpKind::Mul:
if (mlirType.isa<mlir::IntegerType>())
if (type.isa<mlir::cir::IntType>())
rewriter.replaceOpWithNewOp<mlir::arith::MulIOp>(
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
else
rewriter.replaceOpWithNewOp<mlir::arith::MulFOp>(
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
break;
case mlir::cir::BinOpKind::Div:
if (mlirType.isa<mlir::IntegerType>()) {
if (mlirType.isSignlessInteger())
if (auto ty = type.dyn_cast<mlir::cir::IntType>()) {
if (ty.isUnsigned())
rewriter.replaceOpWithNewOp<mlir::arith::DivUIOp>(
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
else
llvm_unreachable("integer mlirType not supported in CIR yet");
rewriter.replaceOpWithNewOp<mlir::arith::DivSIOp>(
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
} else
rewriter.replaceOpWithNewOp<mlir::arith::DivFOp>(
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
break;
case mlir::cir::BinOpKind::Rem:
if (mlirType.isa<mlir::IntegerType>()) {
if (mlirType.isSignlessInteger())
if (auto ty = type.dyn_cast<mlir::cir::IntType>()) {
if (ty.isUnsigned())
rewriter.replaceOpWithNewOp<mlir::arith::RemUIOp>(
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
else
llvm_unreachable("integer mlirType not supported in CIR yet");
rewriter.replaceOpWithNewOp<mlir::arith::RemSIOp>(
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
} else
rewriter.replaceOpWithNewOp<mlir::arith::RemFOp>(
op, mlirType, adaptor.getLhs(), adaptor.getRhs());
Expand Down Expand Up @@ -703,144 +711,22 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {
mlir::LogicalResult
matchAndRewrite(mlir::cir::CmpOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto type = adaptor.getLhs().getType();
auto integerType =
mlir::IntegerType::get(getContext(), 1, mlir::IntegerType::Signless);
auto type = op.getLhs().getType();

mlir::Value mlirResult;
switch (op.getKind()) {
case mlir::cir::CmpOpKind::gt: {
if (type.isa<mlir::IntegerType>()) {
mlir::arith::CmpIPredicate cmpIType;
if (!type.isSignlessInteger())
llvm_unreachable("integer type not supported in CIR yet");
cmpIType = mlir::arith::CmpIPredicate::ugt;
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
op.getLoc(), integerType,
mlir::arith::CmpIPredicateAttr::get(getContext(), cmpIType),
adaptor.getLhs(), adaptor.getRhs());
} else if (type.isa<mlir::FloatType>()) {
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), integerType,
mlir::arith::CmpFPredicateAttr::get(
getContext(), mlir::arith::CmpFPredicate::UGT),
adaptor.getLhs(), adaptor.getRhs(),
mlir::arith::FastMathFlagsAttr::get(
getContext(), mlir::arith::FastMathFlags::none));
} else {
llvm_unreachable("Unknown Operand Type");
}
break;
}
case mlir::cir::CmpOpKind::ge: {
if (type.isa<mlir::IntegerType>()) {
mlir::arith::CmpIPredicate cmpIType;
if (!type.isSignlessInteger())
llvm_unreachable("integer type not supported in CIR yet");
cmpIType = mlir::arith::CmpIPredicate::uge;
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
op.getLoc(), integerType,
mlir::arith::CmpIPredicateAttr::get(getContext(), cmpIType),
adaptor.getLhs(), adaptor.getRhs());
} else if (type.isa<mlir::FloatType>()) {
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), integerType,
mlir::arith::CmpFPredicateAttr::get(
getContext(), mlir::arith::CmpFPredicate::UGE),
adaptor.getLhs(), adaptor.getRhs(),
mlir::arith::FastMathFlagsAttr::get(
getContext(), mlir::arith::FastMathFlags::none));
} else {
llvm_unreachable("Unknown Operand Type");
}
break;
}
case mlir::cir::CmpOpKind::lt: {
if (type.isa<mlir::IntegerType>()) {
mlir::arith::CmpIPredicate cmpIType;
if (!type.isSignlessInteger())
llvm_unreachable("integer type not supported in CIR yet");
cmpIType = mlir::arith::CmpIPredicate::ult;
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
op.getLoc(), integerType,
mlir::arith::CmpIPredicateAttr::get(getContext(), cmpIType),
adaptor.getLhs(), adaptor.getRhs());
} else if (type.isa<mlir::FloatType>()) {
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), integerType,
mlir::arith::CmpFPredicateAttr::get(
getContext(), mlir::arith::CmpFPredicate::ULT),
adaptor.getLhs(), adaptor.getRhs(),
mlir::arith::FastMathFlagsAttr::get(
getContext(), mlir::arith::FastMathFlags::none));
} else {
llvm_unreachable("Unknown Operand Type");
}
break;
}
case mlir::cir::CmpOpKind::le: {
if (type.isa<mlir::IntegerType>()) {
mlir::arith::CmpIPredicate cmpIType;
if (!type.isSignlessInteger())
llvm_unreachable("integer type not supported in CIR yet");
cmpIType = mlir::arith::CmpIPredicate::ule;
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
op.getLoc(), integerType,
mlir::arith::CmpIPredicateAttr::get(getContext(), cmpIType),
adaptor.getLhs(), adaptor.getRhs());
} else if (type.isa<mlir::FloatType>()) {
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), integerType,
mlir::arith::CmpFPredicateAttr::get(
getContext(), mlir::arith::CmpFPredicate::ULE),
adaptor.getLhs(), adaptor.getRhs(),
mlir::arith::FastMathFlagsAttr::get(
getContext(), mlir::arith::FastMathFlags::none));
} else {
llvm_unreachable("Unknown Operand Type");
}
break;
}
case mlir::cir::CmpOpKind::eq: {
if (type.isa<mlir::IntegerType>()) {
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
op.getLoc(), integerType,
mlir::arith::CmpIPredicateAttr::get(getContext(),
mlir::arith::CmpIPredicate::eq),
adaptor.getLhs(), adaptor.getRhs());
} else if (type.isa<mlir::FloatType>()) {
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), integerType,
mlir::arith::CmpFPredicateAttr::get(
getContext(), mlir::arith::CmpFPredicate::UEQ),
adaptor.getLhs(), adaptor.getRhs(),
mlir::arith::FastMathFlagsAttr::get(
getContext(), mlir::arith::FastMathFlags::none));
} else {
llvm_unreachable("Unknown Operand Type");
}
break;
}
case mlir::cir::CmpOpKind::ne: {
if (type.isa<mlir::IntegerType>()) {
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
op.getLoc(), integerType,
mlir::arith::CmpIPredicateAttr::get(getContext(),
mlir::arith::CmpIPredicate::ne),
adaptor.getLhs(), adaptor.getRhs());
} else if (type.isa<mlir::FloatType>()) {
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), integerType,
mlir::arith::CmpFPredicateAttr::get(
getContext(), mlir::arith::CmpFPredicate::UNE),
adaptor.getLhs(), adaptor.getRhs(),
mlir::arith::FastMathFlagsAttr::get(
getContext(), mlir::arith::FastMathFlags::none));
} else {
llvm_unreachable("Unknown Operand Type");
}
break;
}

if (auto ty = type.dyn_cast<mlir::cir::IntType>()) {
auto kind = convertCmpKindToCmpIPredicate(op.getKind(), ty.isSigned());
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
op.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
} else if (auto ty = type.dyn_cast<mlir::cir::CIRFPTypeInterface>()) {
auto kind = convertCmpKindToCmpFPredicate(op.getKind());
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
} else if (auto ty = type.dyn_cast<mlir::cir::PointerType>()) {
llvm_unreachable("pointer comparison not supported yet");
} else {
return op.emitError() << "unsupported type for CmpOp: " << type;
}

// MLIR comparison ops return i1, but cir::CmpOp returns the same type as
Expand Down Expand Up @@ -1143,6 +1029,39 @@ class CIRVectorExtractLowering
}
};

class CIRVectorCmpOpLowering
: public mlir::OpConversionPattern<mlir::cir::VecCmpOp> {
public:
using OpConversionPattern<mlir::cir::VecCmpOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::VecCmpOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
assert(op.getType().isa<mlir::cir::VectorType>() &&
op.getLhs().getType().isa<mlir::cir::VectorType>() &&
op.getRhs().getType().isa<mlir::cir::VectorType>() &&
"Vector compare with non-vector type");
auto elementType =
op.getLhs().getType().cast<mlir::cir::VectorType>().getEltType();
mlir::Value bitResult;
if (auto intType = elementType.dyn_cast<mlir::cir::IntType>()) {
bitResult = rewriter.create<mlir::arith::CmpIOp>(
op.getLoc(),
convertCmpKindToCmpIPredicate(op.getKind(), intType.isSigned()),
adaptor.getLhs(), adaptor.getRhs());
} else if (elementType.isa<mlir::cir::CIRFPTypeInterface>()) {
bitResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), convertCmpKindToCmpFPredicate(op.getKind()),
adaptor.getLhs(), adaptor.getRhs());
} else {
return op.emitError() << "unsupported type for VecCmpOp: " << elementType;
}
rewriter.replaceOpWithNewOp<mlir::arith::ExtSIOp>(
op, typeConverter->convertType(op.getType()), bitResult);
return mlir::success();
}
};

class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> {
public:
using OpConversionPattern<mlir::cir::CastOp>::OpConversionPattern;
Expand Down Expand Up @@ -1345,22 +1264,22 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &converter) {
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());

patterns
.add<CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering,
CIRBinOpLowering, CIRLoadOpLowering, CIRConstantOpLowering,
CIRStoreOpLowering, CIRAllocaOpLowering, CIRFuncOpLowering,
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
CIRYieldOpLowering, CIRCosOpLowering, CIRGlobalOpLowering,
CIRGetGlobalOpLowering, CIRCastOpLowering, CIRPtrStrideOpLowering,
CIRSqrtOpLowering, CIRCeilOpLowering, CIRExp2OpLowering,
CIRExpOpLowering, CIRFAbsOpLowering, CIRFloorOpLowering,
CIRLog10OpLowering, CIRLog2OpLowering, CIRLogOpLowering,
CIRRoundOpLowering, CIRPtrStrideOpLowering, CIRSinOpLowering,
CIRShiftOpLowering, CIRBitClzOpLowering, CIRBitCtzOpLowering,
CIRBitPopcountOpLowering, CIRBitClrsbOpLowering, CIRBitFfsOpLowering,
CIRBitParityOpLowering, CIRIfOpLowering, CIRVectorCreateLowering,
CIRVectorInsertLowering, CIRVectorExtractLowering>(
converter, patterns.getContext());
patterns.add<
CIRCmpOpLowering, CIRCallOpLowering, CIRUnaryOpLowering, CIRBinOpLowering,
CIRLoadOpLowering, CIRConstantOpLowering, CIRStoreOpLowering,
CIRAllocaOpLowering, CIRFuncOpLowering, CIRScopeOpLowering,
CIRBrCondOpLowering, CIRTernaryOpLowering, CIRYieldOpLowering,
CIRCosOpLowering, CIRGlobalOpLowering, CIRGetGlobalOpLowering,
CIRCastOpLowering, CIRPtrStrideOpLowering, CIRSqrtOpLowering,
CIRCeilOpLowering, CIRExp2OpLowering, CIRExpOpLowering, CIRFAbsOpLowering,
CIRFloorOpLowering, CIRLog10OpLowering, CIRLog2OpLowering,
CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering,
CIRSinOpLowering, CIRShiftOpLowering, CIRBitClzOpLowering,
CIRBitCtzOpLowering, CIRBitPopcountOpLowering, CIRBitClrsbOpLowering,
CIRBitFfsOpLowering, CIRBitParityOpLowering, CIRIfOpLowering,
CIRVectorCreateLowering, CIRVectorInsertLowering,
CIRVectorExtractLowering, CIRVectorCmpOpLowering>(converter,
patterns.getContext());
}

static mlir::TypeConverter prepareTypeConverter() {
Expand Down
43 changes: 43 additions & 0 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerToMLIRHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"

template <typename T>
mlir::Value getConst(mlir::ConversionPatternRewriter &rewriter,
Expand Down Expand Up @@ -37,4 +38,46 @@ mlir::Value createIntCast(mlir::ConversionPatternRewriter &rewriter,
return rewriter.create<mlir::arith::BitcastOp>(loc, dstTy, src);
}

mlir::arith::CmpIPredicate
convertCmpKindToCmpIPredicate(mlir::cir::CmpOpKind kind, bool isSigned) {
using CIR = mlir::cir::CmpOpKind;
using arithCmpI = mlir::arith::CmpIPredicate;
switch (kind) {
case CIR::eq:
return arithCmpI::eq;
case CIR::ne:
return arithCmpI::ne;
case CIR::lt:
return (isSigned ? arithCmpI::slt : arithCmpI::ult);
case CIR::le:
return (isSigned ? arithCmpI::sle : arithCmpI::ule);
case CIR::gt:
return (isSigned ? arithCmpI::sgt : arithCmpI::ugt);
case CIR::ge:
return (isSigned ? arithCmpI::sge : arithCmpI::uge);
}
llvm_unreachable("Unknown CmpOpKind");
}

mlir::arith::CmpFPredicate
convertCmpKindToCmpFPredicate(mlir::cir::CmpOpKind kind) {
using CIR = mlir::cir::CmpOpKind;
using arithCmpF = mlir::arith::CmpFPredicate;
switch (kind) {
case CIR::eq:
return arithCmpF::OEQ;
case CIR::ne:
return arithCmpF::UNE;
case CIR::lt:
return arithCmpF::OLT;
case CIR::le:
return arithCmpF::OLE;
case CIR::gt:
return arithCmpF::OGT;
case CIR::ge:
return arithCmpF::OGE;
}
llvm_unreachable("Unknown CmpOpKind");
}

#endif
Loading

0 comments on commit 9dbce79

Please sign in to comment.