Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llvm/include/llvm/IR/ConstantRange.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ namespace llvm {

class MDNode;
class raw_ostream;
class CmpPredicate;
struct KnownBits;

/// This class represents a range of values.
Expand Down Expand Up @@ -106,7 +107,7 @@ class [[nodiscard]] ConstantRange {
///
/// Example: Pred = ult and Other = i8 [2, 5) returns Result = [0, 4)
LLVM_ABI static ConstantRange
makeAllowedICmpRegion(CmpInst::Predicate Pred, const ConstantRange &Other);
makeAllowedICmpRegion(CmpPredicate Pred, const ConstantRange &Other);

/// Produce the largest range such that all values in the returned range
/// satisfy the given predicate with all values contained within Other.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10325,7 +10325,7 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
computeConstantRange(Cmp->getOperand(1), /* ForSigned */ false,
UseInstrInfo, AC, I, DT, Depth + 1);
CR = CR.intersectWith(
ConstantRange::makeAllowedICmpRegion(Cmp->getPredicate(), RHS));
ConstantRange::makeAllowedICmpRegion(Cmp->getCmpPredicate(), RHS));
}
}

Expand Down
46 changes: 35 additions & 11 deletions llvm/lib/IR/ConstantRange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/IR/ConstantRange.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/CmpPredicate.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
Expand Down Expand Up @@ -106,12 +107,13 @@ std::pair<ConstantRange, ConstantRange> ConstantRange::splitPosNeg() const {
return {intersectWith(PosFilter), intersectWith(NegFilter)};
}

ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred,
ConstantRange ConstantRange::makeAllowedICmpRegion(CmpPredicate Pred,
const ConstantRange &CR) {
if (CR.isEmptySet())
return CR;

uint32_t W = CR.getBitWidth();
ConstantRange Result = getFull(W);
switch (Pred) {
default:
llvm_unreachable("Invalid ICmp predicate to makeAllowedICmpRegion()");
Expand All @@ -125,34 +127,56 @@ ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred,
APInt UMax(CR.getUnsignedMax());
if (UMax.isMinValue())
return getEmpty(W);
return ConstantRange(APInt::getMinValue(W), std::move(UMax));
Result = ConstantRange(APInt::getMinValue(W), std::move(UMax));
if (!Pred.hasSameSign())
return Result;
}
// For samesign, intersect with signed range.
LLVM_FALLTHROUGH;
case CmpInst::ICMP_SLT: {
APInt SMax(CR.getSignedMax());
if (SMax.isMinSignedValue())
return getEmpty(W);
return ConstantRange(APInt::getSignedMinValue(W), std::move(SMax));
return Result.intersectWith(
ConstantRange(APInt::getSignedMinValue(W), std::move(SMax)));
}
case CmpInst::ICMP_ULE:
return getNonEmpty(APInt::getMinValue(W), CR.getUnsignedMax() + 1);
case CmpInst::ICMP_ULE: {
Result = getNonEmpty(APInt::getMinValue(W), CR.getUnsignedMax() + 1);
if (!Pred.hasSameSign())
return Result;
}
// For samesign, intersect with signed range.
LLVM_FALLTHROUGH;
case CmpInst::ICMP_SLE:
return getNonEmpty(APInt::getSignedMinValue(W), CR.getSignedMax() + 1);
return Result.intersectWith(
getNonEmpty(APInt::getSignedMinValue(W), CR.getSignedMax() + 1));
case CmpInst::ICMP_UGT: {
APInt UMin(CR.getUnsignedMin());
if (UMin.isMaxValue())
return getEmpty(W);
return ConstantRange(std::move(UMin) + 1, APInt::getZero(W));
Result = ConstantRange(std::move(UMin) + 1, APInt::getZero(W));
if (!Pred.hasSameSign())
return Result;
}
// For samesign, intersect with signed range.
LLVM_FALLTHROUGH;
case CmpInst::ICMP_SGT: {
APInt SMin(CR.getSignedMin());
if (SMin.isMaxSignedValue())
return getEmpty(W);
return ConstantRange(std::move(SMin) + 1, APInt::getSignedMinValue(W));
return Result.intersectWith(
ConstantRange(std::move(SMin) + 1, APInt::getSignedMinValue(W)));
}
case CmpInst::ICMP_UGE:
return getNonEmpty(CR.getUnsignedMin(), APInt::getZero(W));
case CmpInst::ICMP_UGE: {
Result = getNonEmpty(CR.getUnsignedMin(), APInt::getZero(W));
if (!Pred.hasSameSign())
return Result;
}
// For samesign, intersect with signed range.
LLVM_FALLTHROUGH;
case CmpInst::ICMP_SGE:
return getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W));
return Result.intersectWith(
getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W)));
}
}

Expand Down
100 changes: 100 additions & 0 deletions llvm/unittests/Analysis/ValueTrackingTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3234,6 +3234,106 @@ TEST_F(ValueTrackingTest, ComputeConstantRange) {
EXPECT_EQ(10, CR2.getUpper());
}

{
// Assumptions:
// * stride >= 5 (unsigned)
//
// stride = [5, 0)
auto M = parseModule(R"(
declare void @llvm.assume(i1)

define i32 @test(i32 %stride) {
%gt = icmp uge i32 %stride, 5
call void @llvm.assume(i1 %gt)
%stride.plus.one = add nsw nuw i32 %stride, 1
ret i32 %stride.plus.one
})");
Function *F = M->getFunction("test");

AssumptionCache AC(*F);
Value *Stride = &*F->arg_begin();

Instruction *I = &findInstructionByName(F, "stride.plus.one");
ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I);
EXPECT_EQ(5, CR2.getLower());
EXPECT_EQ(0, CR2.getUpper());
}

{
// Assumptions:
// * stride > 5 (unsigned)
//
// stride = [5, 0)
auto M = parseModule(R"(
declare void @llvm.assume(i1)

define i32 @test(i32 %stride) {
%gt = icmp ugt i32 %stride, 5
call void @llvm.assume(i1 %gt)
%stride.plus.one = add nsw nuw i32 %stride, 1
ret i32 %stride.plus.one
})");
Function *F = M->getFunction("test");

AssumptionCache AC(*F);
Value *Stride = &*F->arg_begin();

Instruction *I = &findInstructionByName(F, "stride.plus.one");
ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I);
EXPECT_EQ(6, CR2.getLower());
EXPECT_EQ(0, CR2.getUpper());
}

{
// Assumptions:
// * stride >= 5 (samesign unsigned)
//
// stride = [5, MIN_SIGNED)
auto M = parseModule(R"(
declare void @llvm.assume(i1)

define i32 @test(i32 %stride) {
%gt = icmp samesign uge i32 %stride, 5
call void @llvm.assume(i1 %gt)
%stride.plus.one = add nsw nuw i32 %stride, 1
ret i32 %stride.plus.one
})");
Function *F = M->getFunction("test");

AssumptionCache AC(*F);
Value *Stride = &*F->arg_begin();

Instruction *I = &findInstructionByName(F, "stride.plus.one");
ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I);
EXPECT_EQ(5, CR2.getLower());
EXPECT_EQ(APInt::getSignedMinValue(32), CR2.getUpper());
}

{
// Assumptions:
// * stride > 5 (samesign unsigned)
//
// stride = [5, MIN_SIGNED)
auto M = parseModule(R"(
declare void @llvm.assume(i1)

define i32 @test(i32 %stride) {
%gt = icmp samesign ugt i32 %stride, 5
call void @llvm.assume(i1 %gt)
%stride.plus.one = add nsw nuw i32 %stride, 1
ret i32 %stride.plus.one
})");
Function *F = M->getFunction("test");

AssumptionCache AC(*F);
Value *Stride = &*F->arg_begin();

Instruction *I = &findInstructionByName(F, "stride.plus.one");
ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I);
EXPECT_EQ(6, CR2.getLower());
EXPECT_EQ(APInt::getSignedMinValue(32), CR2.getUpper());
}

{
// Assumptions:
// * stride >= 5
Expand Down
59 changes: 59 additions & 0 deletions llvm/unittests/IR/ConstantRangeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/IR/CmpPredicate.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/KnownBits.h"
Expand Down Expand Up @@ -1730,6 +1731,64 @@ TEST(ConstantRange, MakeAllowedICmpRegionEdgeCases) {
.isFullSet());
}

template <typename SIV>
auto getSameSignTester(SIV ShouldIncludeValue, CmpInst::Predicate Cmp) {
return [Cmp, ShouldIncludeValue](const ConstantRange &CR) {
uint32_t BitWidth = CR.getBitWidth();
unsigned Max = 1 << BitWidth;
SmallBitVector Elems(Max);
if (!CR.isEmptySet()) {
for (unsigned I : llvm::seq(Max)) {
APInt Current(BitWidth, I);
if (ShouldIncludeValue(Current, CR))
Elems.set(I);
}
}

CmpPredicate CmpPred(Cmp, true);
TestRange(ConstantRange::makeAllowedICmpRegion(CmpPred, CR), Elems,
PreferSmallest, {});
};
}

TEST(ConstantRange, MakeAllowedICmpRegionExaustive) {
EnumerateInterestingConstantRanges(getSameSignTester(
[](const APInt &A, const ConstantRange &B) {
if (A.isNegative())
return A.sge(B.getSignedMin());
return A.uge(B.getUnsignedMin());
},
ICmpInst::ICMP_UGE));

EnumerateInterestingConstantRanges(getSameSignTester(
[](const APInt &A, const ConstantRange &B) {
if (A.isNegative())
return A.sgt(B.getSignedMin());
return A.ugt(B.getUnsignedMin());
},
ICmpInst::ICMP_UGT));

EnumerateInterestingConstantRanges(getSameSignTester(
[](const APInt &A, const ConstantRange &B) {
if (A.isNegative() && B.getUnsignedMax().isNegative())
return A.sle(B.getUnsignedMax());
if (A.isNonNegative() && B.getSignedMax().isNonNegative())
return A.ule(B.getSignedMax());
return false;
},
ICmpInst::ICMP_ULE));

EnumerateInterestingConstantRanges(getSameSignTester(
[](const APInt &A, const ConstantRange &B) {
if (A.isNegative() && B.getUnsignedMax().isNegative())
return A.slt(B.getUnsignedMax());
if (A.isNonNegative() && B.getSignedMax().isNonNegative())
return A.ult(B.getSignedMax());
return false;
},
ICmpInst::ICMP_ULT));
}

TEST(ConstantRange, MakeExactICmpRegion) {
for (unsigned Bits : {1, 4}) {
EnumerateAPInts(Bits, [](const APInt &N) {
Expand Down