Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llvm/include/llvm/IR/ConstantRange.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ namespace llvm {

class MDNode;
class raw_ostream;
class CmpPredicate;
struct KnownBits;

/// This class represents a range of values.
Expand Down Expand Up @@ -106,7 +107,7 @@ class [[nodiscard]] ConstantRange {
///
/// Example: Pred = ult and Other = i8 [2, 5) returns Result = [0, 4)
LLVM_ABI static ConstantRange
makeAllowedICmpRegion(CmpInst::Predicate Pred, const ConstantRange &Other);
makeAllowedICmpRegion(CmpPredicate Pred, const ConstantRange &Other);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd expect this to simplify isImpliedCondCommonOperandWithCR:

if (auto Res = CRImpliesPred(ConstantRange::makeAllowedICmpRegion(LPred, LCR),
RPred))
return Res;
if (LPred.hasSameSign() ^ RPred.hasSameSign()) {
LPred = LPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(LPred)
: LPred.dropSameSign();
RPred = RPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(RPred)
: RPred.dropSameSign();
return CRImpliesPred(ConstantRange::makeAllowedICmpRegion(LPred, LCR),
RPred);
}
return std::nullopt;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately no, we need to do the same to ConstantRange::icmp to simplify this


/// Produce the largest range such that all values in the returned range
/// satisfy the given predicate with all values contained within Other.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10325,7 +10325,7 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
computeConstantRange(Cmp->getOperand(1), /* ForSigned */ false,
UseInstrInfo, AC, I, DT, Depth + 1);
CR = CR.intersectWith(
ConstantRange::makeAllowedICmpRegion(Cmp->getPredicate(), RHS));
ConstantRange::makeAllowedICmpRegion(Cmp->getCmpPredicate(), RHS));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I imagine there could be other users in the codebase that may start benefiting from using the whole CmpPredicate, so may be worth postponing this change in a follow-up?

}
}

Expand Down
60 changes: 57 additions & 3 deletions llvm/lib/IR/ConstantRange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/IR/ConstantRange.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/CmpPredicate.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
Expand Down Expand Up @@ -106,7 +107,7 @@ std::pair<ConstantRange, ConstantRange> ConstantRange::splitPosNeg() const {
return {intersectWith(PosFilter), intersectWith(NegFilter)};
}

ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred,
ConstantRange ConstantRange::makeAllowedICmpRegion(CmpPredicate Pred,
const ConstantRange &CR) {
if (CR.isEmptySet())
return CR;
Expand All @@ -125,7 +126,26 @@ ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred,
APInt UMax(CR.getUnsignedMax());
if (UMax.isMinValue())
return getEmpty(W);
return ConstantRange(APInt::getMinValue(W), std::move(UMax));
if (!Pred.hasSameSign())
return ConstantRange(APInt::getMinValue(W), std::move(UMax));
if (W == 1)
return getEmpty(W);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment explaining that can never be satisfied for i1?


// deal with edge cases
APInt AugmentedUpper = CR.getUpper();
if (CR.getUnsignedMax().isMinSignedValue())
AugmentedUpper = APInt::getSignedMinValue(W);
else if (CR.getSignedMax().isMinValue())
AugmentedUpper = APInt::getMinValue(W);

if (AugmentedUpper == CR.getLower() && !CR.isFullSet())
return getEmpty(W);

ConstantRange Augmented(CR.getLower(), AugmentedUpper);
if (Augmented.isAllNonNegative() ||
(!Augmented.isAllNegative() && Augmented.isSignWrappedSet()))
return getNonEmpty(APInt::getMinValue(W), Augmented.getUnsignedMax());
return getNonEmpty(APInt::getSignedMinValue(W), Augmented.getSignedMax());
}
case CmpInst::ICMP_SLT: {
APInt SMax(CR.getSignedMax());
Expand All @@ -134,14 +154,42 @@ ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred,
return ConstantRange(APInt::getSignedMinValue(W), std::move(SMax));
}
case CmpInst::ICMP_ULE:
if (!Pred.hasSameSign())
return getNonEmpty(APInt::getMinValue(W), CR.getUnsignedMax() + 1);
if (W == 1)
return CR;
if (CR.isAllNegative() ||
(!CR.isAllNonNegative() && !CR.isSignWrappedSet()))
return getNonEmpty(APInt::getSignedMinValue(W), CR.getSignedMax() + 1);
return getNonEmpty(APInt::getMinValue(W), CR.getUnsignedMax() + 1);
case CmpInst::ICMP_SLE:
return getNonEmpty(APInt::getSignedMinValue(W), CR.getSignedMax() + 1);
case CmpInst::ICMP_UGT: {
APInt UMin(CR.getUnsignedMin());
if (UMin.isMaxValue())
return getEmpty(W);
return ConstantRange(std::move(UMin) + 1, APInt::getZero(W));
if (!Pred.hasSameSign())
return ConstantRange(std::move(UMin) + 1, APInt::getZero(W));
if (W == 1)
return getEmpty(W);

// deal with edge cases
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be worth elaborating on the edge cases?

APInt AugmentedLower = CR.getLower();
if (CR.getLower().isMaxSignedValue())
AugmentedLower = APInt::getSignedMinValue(W);
else if (CR.getLower().isMaxValue())
AugmentedLower = APInt::getMinValue(W);

if (AugmentedLower == CR.getUpper())
return getEmpty(W);

ConstantRange Augmented(AugmentedLower, CR.getUpper());
if (Augmented.isAllNegative())
return getNonEmpty(Augmented.getSignedMin() + 1, APInt::getZero(W));
if (!Augmented.isAllNonNegative() && Augmented.isSignWrappedSet())
return getNonEmpty(Augmented.getUnsignedMin() + 1, APInt::getZero(W));
return getNonEmpty(Augmented.getSignedMin() + 1,
APInt::getSignedMinValue(W));
}
case CmpInst::ICMP_SGT: {
APInt SMin(CR.getSignedMin());
Expand All @@ -150,6 +198,12 @@ ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred,
return ConstantRange(std::move(SMin) + 1, APInt::getSignedMinValue(W));
}
case CmpInst::ICMP_UGE:
if (!Pred.hasSameSign())
return getNonEmpty(CR.getUnsignedMin(), APInt::getZero(W));

if (CR.isAllNonNegative() ||
(!CR.isAllNegative() && !CR.isSignWrappedSet()))
return getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W));
return getNonEmpty(CR.getUnsignedMin(), APInt::getZero(W));
case CmpInst::ICMP_SGE:
return getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W));
Expand Down
100 changes: 100 additions & 0 deletions llvm/unittests/Analysis/ValueTrackingTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3234,6 +3234,106 @@ TEST_F(ValueTrackingTest, ComputeConstantRange) {
EXPECT_EQ(10, CR2.getUpper());
}

{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also add an exhaustive test in ConstantRangeTest.cpp to make sure the result is optimal. I also doubt the correctness of Boolean constant ranges.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch, there is an edge case for booleans

Added tests to check my changes

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For exhaustive tests, I mean you can use EnumerateInterestingConstantRanges to enumerate all possible ranges from i1 to i4. Then you use TestRange to check for correctness and optimality. If the optimality is infeasible, please make sure it is always smaller than the result without samesign.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I printed out the elements I create for each test to validate the correct elements are created, attached file with that print
res.txt

// Assumptions:
// * stride >= 5 (unsigned)
//
// stride = [5, 0)
auto M = parseModule(R"(
declare void @llvm.assume(i1)

define i32 @test(i32 %stride) {
%gt = icmp uge i32 %stride, 5
call void @llvm.assume(i1 %gt)
%stride.plus.one = add nsw nuw i32 %stride, 1
ret i32 %stride.plus.one
})");
Function *F = M->getFunction("test");

AssumptionCache AC(*F);
Value *Stride = &*F->arg_begin();

Instruction *I = &findInstructionByName(F, "stride.plus.one");
ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I);
EXPECT_EQ(5, CR2.getLower());
EXPECT_EQ(0, CR2.getUpper());
}

{
// Assumptions:
// * stride > 5 (unsigned)
//
// stride = [5, 0)
auto M = parseModule(R"(
declare void @llvm.assume(i1)

define i32 @test(i32 %stride) {
%gt = icmp ugt i32 %stride, 5
call void @llvm.assume(i1 %gt)
%stride.plus.one = add nsw nuw i32 %stride, 1
ret i32 %stride.plus.one
})");
Function *F = M->getFunction("test");

AssumptionCache AC(*F);
Value *Stride = &*F->arg_begin();

Instruction *I = &findInstructionByName(F, "stride.plus.one");
ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I);
EXPECT_EQ(6, CR2.getLower());
EXPECT_EQ(0, CR2.getUpper());
}

{
// Assumptions:
// * stride >= 5 (samesign unsigned)
//
// stride = [5, MIN_SIGNED)
auto M = parseModule(R"(
declare void @llvm.assume(i1)

define i32 @test(i32 %stride) {
%gt = icmp samesign uge i32 %stride, 5
call void @llvm.assume(i1 %gt)
%stride.plus.one = add nsw nuw i32 %stride, 1
ret i32 %stride.plus.one
})");
Function *F = M->getFunction("test");

AssumptionCache AC(*F);
Value *Stride = &*F->arg_begin();

Instruction *I = &findInstructionByName(F, "stride.plus.one");
ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I);
EXPECT_EQ(5, CR2.getLower());
EXPECT_EQ(APInt::getSignedMinValue(32), CR2.getUpper());
}

{
// Assumptions:
// * stride > 5 (samesign unsigned)
//
// stride = [5, MIN_SIGNED)
auto M = parseModule(R"(
declare void @llvm.assume(i1)

define i32 @test(i32 %stride) {
%gt = icmp samesign ugt i32 %stride, 5
call void @llvm.assume(i1 %gt)
%stride.plus.one = add nsw nuw i32 %stride, 1
ret i32 %stride.plus.one
})");
Function *F = M->getFunction("test");

AssumptionCache AC(*F);
Value *Stride = &*F->arg_begin();

Instruction *I = &findInstructionByName(F, "stride.plus.one");
ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I);
EXPECT_EQ(6, CR2.getLower());
EXPECT_EQ(APInt::getSignedMinValue(32), CR2.getUpper());
}

{
// Assumptions:
// * stride >= 5
Expand Down
59 changes: 59 additions & 0 deletions llvm/unittests/IR/ConstantRangeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/IR/CmpPredicate.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/KnownBits.h"
Expand Down Expand Up @@ -1730,6 +1731,64 @@ TEST(ConstantRange, MakeAllowedICmpRegionEdgeCases) {
.isFullSet());
}

template <typename SIV>
auto getSameSignTester(SIV ShouldIncludeValue, CmpInst::Predicate Cmp) {
return [Cmp, ShouldIncludeValue](const ConstantRange &CR) {
uint32_t BitWidth = CR.getBitWidth();
unsigned Max = 1 << BitWidth;
SmallBitVector Elems(Max);
if (!CR.isEmptySet()) {
for (unsigned I : llvm::seq(Max)) {
APInt Current(BitWidth, I);
if (ShouldIncludeValue(Current, CR))
Elems.set(I);
}
}

CmpPredicate CmpPred(Cmp, true);
TestRange(ConstantRange::makeAllowedICmpRegion(CmpPred, CR), Elems,
PreferSmallest, {});
};
}

TEST(ConstantRange, MakeAllowedICmpRegionExaustive) {
EnumerateInterestingConstantRanges(getSameSignTester(
[](const APInt &A, const ConstantRange &B) {
if (A.isNegative())
return A.sge(B.getSignedMin());
return A.uge(B.getUnsignedMin());
},
ICmpInst::ICMP_UGE));

EnumerateInterestingConstantRanges(getSameSignTester(
[](const APInt &A, const ConstantRange &B) {
if (A.isNegative())
return A.sgt(B.getSignedMin());
return A.ugt(B.getUnsignedMin());
},
ICmpInst::ICMP_UGT));

EnumerateInterestingConstantRanges(getSameSignTester(
[](const APInt &A, const ConstantRange &B) {
if (A.isNegative() && B.getUnsignedMax().isNegative())
return A.sle(B.getUnsignedMax());
if (A.isNonNegative() && B.getSignedMax().isNonNegative())
return A.ule(B.getSignedMax());
return false;
},
ICmpInst::ICMP_ULE));

EnumerateInterestingConstantRanges(getSameSignTester(
[](const APInt &A, const ConstantRange &B) {
if (A.isNegative() && B.getUnsignedMax().isNegative())
return A.slt(B.getUnsignedMax());
if (A.isNonNegative() && B.getSignedMax().isNonNegative())
return A.ult(B.getSignedMax());
return false;
},
ICmpInst::ICMP_ULT));
}

TEST(ConstantRange, MakeExactICmpRegion) {
for (unsigned Bits : {1, 4}) {
EnumerateAPInts(Bits, [](const APInt &N) {
Expand Down