diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h index f4d4f1f555fa4..47b9b60b64b2e 100644 --- a/llvm/include/llvm/IR/ConstantRange.h +++ b/llvm/include/llvm/IR/ConstantRange.h @@ -41,6 +41,7 @@ namespace llvm { class MDNode; class raw_ostream; +class CmpPredicate; struct KnownBits; /// This class represents a range of values. @@ -106,7 +107,7 @@ class [[nodiscard]] ConstantRange { /// /// Example: Pred = ult and Other = i8 [2, 5) returns Result = [0, 4) LLVM_ABI static ConstantRange - makeAllowedICmpRegion(CmpInst::Predicate Pred, const ConstantRange &Other); + makeAllowedICmpRegion(CmpPredicate Pred, const ConstantRange &Other); /// Produce the largest range such that all values in the returned range /// satisfy the given predicate with all values contained within Other. diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 171952120fc40..dc73bcf1b657b 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -10325,7 +10325,7 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned, computeConstantRange(Cmp->getOperand(1), /* ForSigned */ false, UseInstrInfo, AC, I, DT, Depth + 1); CR = CR.intersectWith( - ConstantRange::makeAllowedICmpRegion(Cmp->getPredicate(), RHS)); + ConstantRange::makeAllowedICmpRegion(Cmp->getCmpPredicate(), RHS)); } } diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp index 9beaee60d0bc1..2a393cdc904f3 100644 --- a/llvm/lib/IR/ConstantRange.cpp +++ b/llvm/lib/IR/ConstantRange.cpp @@ -23,6 +23,7 @@ #include "llvm/IR/ConstantRange.h" #include "llvm/ADT/APInt.h" #include "llvm/Config/llvm-config.h" +#include "llvm/IR/CmpPredicate.h" #include "llvm/IR/Constants.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" @@ -106,12 +107,13 @@ std::pair ConstantRange::splitPosNeg() const { return {intersectWith(PosFilter), intersectWith(NegFilter)}; } -ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred, +ConstantRange ConstantRange::makeAllowedICmpRegion(CmpPredicate Pred, const ConstantRange &CR) { if (CR.isEmptySet()) return CR; uint32_t W = CR.getBitWidth(); + ConstantRange Result = getFull(W); switch (Pred) { default: llvm_unreachable("Invalid ICmp predicate to makeAllowedICmpRegion()"); @@ -125,34 +127,56 @@ ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred, APInt UMax(CR.getUnsignedMax()); if (UMax.isMinValue()) return getEmpty(W); - return ConstantRange(APInt::getMinValue(W), std::move(UMax)); + Result = ConstantRange(APInt::getMinValue(W), std::move(UMax)); + if (!Pred.hasSameSign()) + return Result; } + // For samesign, intersect with signed range. + LLVM_FALLTHROUGH; case CmpInst::ICMP_SLT: { APInt SMax(CR.getSignedMax()); if (SMax.isMinSignedValue()) return getEmpty(W); - return ConstantRange(APInt::getSignedMinValue(W), std::move(SMax)); + return Result.intersectWith( + ConstantRange(APInt::getSignedMinValue(W), std::move(SMax))); } - case CmpInst::ICMP_ULE: - return getNonEmpty(APInt::getMinValue(W), CR.getUnsignedMax() + 1); + case CmpInst::ICMP_ULE: { + Result = getNonEmpty(APInt::getMinValue(W), CR.getUnsignedMax() + 1); + if (!Pred.hasSameSign()) + return Result; + } + // For samesign, intersect with signed range. + LLVM_FALLTHROUGH; case CmpInst::ICMP_SLE: - return getNonEmpty(APInt::getSignedMinValue(W), CR.getSignedMax() + 1); + return Result.intersectWith( + getNonEmpty(APInt::getSignedMinValue(W), CR.getSignedMax() + 1)); case CmpInst::ICMP_UGT: { APInt UMin(CR.getUnsignedMin()); if (UMin.isMaxValue()) return getEmpty(W); - return ConstantRange(std::move(UMin) + 1, APInt::getZero(W)); + Result = ConstantRange(std::move(UMin) + 1, APInt::getZero(W)); + if (!Pred.hasSameSign()) + return Result; } + // For samesign, intersect with signed range. + LLVM_FALLTHROUGH; case CmpInst::ICMP_SGT: { APInt SMin(CR.getSignedMin()); if (SMin.isMaxSignedValue()) return getEmpty(W); - return ConstantRange(std::move(SMin) + 1, APInt::getSignedMinValue(W)); + return Result.intersectWith( + ConstantRange(std::move(SMin) + 1, APInt::getSignedMinValue(W))); } - case CmpInst::ICMP_UGE: - return getNonEmpty(CR.getUnsignedMin(), APInt::getZero(W)); + case CmpInst::ICMP_UGE: { + Result = getNonEmpty(CR.getUnsignedMin(), APInt::getZero(W)); + if (!Pred.hasSameSign()) + return Result; + } + // For samesign, intersect with signed range. + LLVM_FALLTHROUGH; case CmpInst::ICMP_SGE: - return getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W)); + return Result.intersectWith( + getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W))); } } diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp index 1707ed2c39264..ef839d07b174c 100644 --- a/llvm/unittests/Analysis/ValueTrackingTest.cpp +++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp @@ -3234,6 +3234,106 @@ TEST_F(ValueTrackingTest, ComputeConstantRange) { EXPECT_EQ(10, CR2.getUpper()); } + { + // Assumptions: + // * stride >= 5 (unsigned) + // + // stride = [5, 0) + auto M = parseModule(R"( + declare void @llvm.assume(i1) + + define i32 @test(i32 %stride) { + %gt = icmp uge i32 %stride, 5 + call void @llvm.assume(i1 %gt) + %stride.plus.one = add nsw nuw i32 %stride, 1 + ret i32 %stride.plus.one + })"); + Function *F = M->getFunction("test"); + + AssumptionCache AC(*F); + Value *Stride = &*F->arg_begin(); + + Instruction *I = &findInstructionByName(F, "stride.plus.one"); + ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I); + EXPECT_EQ(5, CR2.getLower()); + EXPECT_EQ(0, CR2.getUpper()); + } + + { + // Assumptions: + // * stride > 5 (unsigned) + // + // stride = [5, 0) + auto M = parseModule(R"( + declare void @llvm.assume(i1) + + define i32 @test(i32 %stride) { + %gt = icmp ugt i32 %stride, 5 + call void @llvm.assume(i1 %gt) + %stride.plus.one = add nsw nuw i32 %stride, 1 + ret i32 %stride.plus.one + })"); + Function *F = M->getFunction("test"); + + AssumptionCache AC(*F); + Value *Stride = &*F->arg_begin(); + + Instruction *I = &findInstructionByName(F, "stride.plus.one"); + ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I); + EXPECT_EQ(6, CR2.getLower()); + EXPECT_EQ(0, CR2.getUpper()); + } + + { + // Assumptions: + // * stride >= 5 (samesign unsigned) + // + // stride = [5, MIN_SIGNED) + auto M = parseModule(R"( + declare void @llvm.assume(i1) + + define i32 @test(i32 %stride) { + %gt = icmp samesign uge i32 %stride, 5 + call void @llvm.assume(i1 %gt) + %stride.plus.one = add nsw nuw i32 %stride, 1 + ret i32 %stride.plus.one + })"); + Function *F = M->getFunction("test"); + + AssumptionCache AC(*F); + Value *Stride = &*F->arg_begin(); + + Instruction *I = &findInstructionByName(F, "stride.plus.one"); + ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I); + EXPECT_EQ(5, CR2.getLower()); + EXPECT_EQ(APInt::getSignedMinValue(32), CR2.getUpper()); + } + + { + // Assumptions: + // * stride > 5 (samesign unsigned) + // + // stride = [5, MIN_SIGNED) + auto M = parseModule(R"( + declare void @llvm.assume(i1) + + define i32 @test(i32 %stride) { + %gt = icmp samesign ugt i32 %stride, 5 + call void @llvm.assume(i1 %gt) + %stride.plus.one = add nsw nuw i32 %stride, 1 + ret i32 %stride.plus.one + })"); + Function *F = M->getFunction("test"); + + AssumptionCache AC(*F); + Value *Stride = &*F->arg_begin(); + + Instruction *I = &findInstructionByName(F, "stride.plus.one"); + ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I); + EXPECT_EQ(6, CR2.getLower()); + EXPECT_EQ(APInt::getSignedMinValue(32), CR2.getUpper()); + } + { // Assumptions: // * stride >= 5 diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp index 13712a76d3edf..bb6b2a1cb7249 100644 --- a/llvm/unittests/IR/ConstantRangeTest.cpp +++ b/llvm/unittests/IR/ConstantRangeTest.cpp @@ -10,6 +10,7 @@ #include "llvm/ADT/BitVector.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/IR/CmpPredicate.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/Support/KnownBits.h" @@ -1730,6 +1731,64 @@ TEST(ConstantRange, MakeAllowedICmpRegionEdgeCases) { .isFullSet()); } +template +auto getSameSignTester(SIV ShouldIncludeValue, CmpInst::Predicate Cmp) { + return [Cmp, ShouldIncludeValue](const ConstantRange &CR) { + uint32_t BitWidth = CR.getBitWidth(); + unsigned Max = 1 << BitWidth; + SmallBitVector Elems(Max); + if (!CR.isEmptySet()) { + for (unsigned I : llvm::seq(Max)) { + APInt Current(BitWidth, I); + if (ShouldIncludeValue(Current, CR)) + Elems.set(I); + } + } + + CmpPredicate CmpPred(Cmp, true); + TestRange(ConstantRange::makeAllowedICmpRegion(CmpPred, CR), Elems, + PreferSmallest, {}); + }; +} + +TEST(ConstantRange, MakeAllowedICmpRegionExaustive) { + EnumerateInterestingConstantRanges(getSameSignTester( + [](const APInt &A, const ConstantRange &B) { + if (A.isNegative()) + return A.sge(B.getSignedMin()); + return A.uge(B.getUnsignedMin()); + }, + ICmpInst::ICMP_UGE)); + + EnumerateInterestingConstantRanges(getSameSignTester( + [](const APInt &A, const ConstantRange &B) { + if (A.isNegative()) + return A.sgt(B.getSignedMin()); + return A.ugt(B.getUnsignedMin()); + }, + ICmpInst::ICMP_UGT)); + + EnumerateInterestingConstantRanges(getSameSignTester( + [](const APInt &A, const ConstantRange &B) { + if (A.isNegative() && B.getUnsignedMax().isNegative()) + return A.sle(B.getUnsignedMax()); + if (A.isNonNegative() && B.getSignedMax().isNonNegative()) + return A.ule(B.getSignedMax()); + return false; + }, + ICmpInst::ICMP_ULE)); + + EnumerateInterestingConstantRanges(getSameSignTester( + [](const APInt &A, const ConstantRange &B) { + if (A.isNegative() && B.getUnsignedMax().isNegative()) + return A.slt(B.getUnsignedMax()); + if (A.isNonNegative() && B.getSignedMax().isNonNegative()) + return A.ult(B.getSignedMax()); + return false; + }, + ICmpInst::ICMP_ULT)); +} + TEST(ConstantRange, MakeExactICmpRegion) { for (unsigned Bits : {1, 4}) { EnumerateAPInts(Bits, [](const APInt &N) {