-
Notifications
You must be signed in to change notification settings - Fork 16.4k
[ConstantRange] Expand makeAllowedICmpRegion to use samesign to give tighter range #174355
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
dca9fb2
94734c3
a361009
8f3757d
1c98459
004eb60
210546d
3727dd1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10325,7 +10325,7 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned, | |
| computeConstantRange(Cmp->getOperand(1), /* ForSigned */ false, | ||
| UseInstrInfo, AC, I, DT, Depth + 1); | ||
| CR = CR.intersectWith( | ||
| ConstantRange::makeAllowedICmpRegion(Cmp->getPredicate(), RHS)); | ||
| ConstantRange::makeAllowedICmpRegion(Cmp->getCmpPredicate(), RHS)); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I imagine there could be other users in the codebase that may start benefiting from using the whole CmpPredicate, so may be worth postponing this change in a follow-up? |
||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| #include "llvm/IR/ConstantRange.h" | ||
| #include "llvm/ADT/APInt.h" | ||
| #include "llvm/Config/llvm-config.h" | ||
| #include "llvm/IR/CmpPredicate.h" | ||
| #include "llvm/IR/Constants.h" | ||
| #include "llvm/IR/InstrTypes.h" | ||
| #include "llvm/IR/Instruction.h" | ||
|
|
@@ -106,7 +107,7 @@ std::pair<ConstantRange, ConstantRange> ConstantRange::splitPosNeg() const { | |
| return {intersectWith(PosFilter), intersectWith(NegFilter)}; | ||
| } | ||
|
|
||
| ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred, | ||
| ConstantRange ConstantRange::makeAllowedICmpRegion(CmpPredicate Pred, | ||
| const ConstantRange &CR) { | ||
| if (CR.isEmptySet()) | ||
| return CR; | ||
|
|
@@ -125,7 +126,26 @@ ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred, | |
| APInt UMax(CR.getUnsignedMax()); | ||
| if (UMax.isMinValue()) | ||
| return getEmpty(W); | ||
| return ConstantRange(APInt::getMinValue(W), std::move(UMax)); | ||
| if (!Pred.hasSameSign()) | ||
| return ConstantRange(APInt::getMinValue(W), std::move(UMax)); | ||
| if (W == 1) | ||
| return getEmpty(W); | ||
|
||
|
|
||
| // deal with edge cases | ||
artagnon marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| APInt AugmentedUpper = CR.getUpper(); | ||
artagnon marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if (CR.getUnsignedMax().isMinSignedValue()) | ||
| AugmentedUpper = APInt::getSignedMinValue(W); | ||
| else if (CR.getSignedMax().isMinValue()) | ||
| AugmentedUpper = APInt::getMinValue(W); | ||
|
|
||
| if (AugmentedUpper == CR.getLower() && !CR.isFullSet()) | ||
| return getEmpty(W); | ||
|
|
||
| ConstantRange Augmented(CR.getLower(), AugmentedUpper); | ||
| if (Augmented.isAllNonNegative() || | ||
| (!Augmented.isAllNegative() && Augmented.isSignWrappedSet())) | ||
| return getNonEmpty(APInt::getMinValue(W), Augmented.getUnsignedMax()); | ||
| return getNonEmpty(APInt::getSignedMinValue(W), Augmented.getSignedMax()); | ||
| } | ||
| case CmpInst::ICMP_SLT: { | ||
| APInt SMax(CR.getSignedMax()); | ||
|
|
@@ -134,14 +154,42 @@ ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred, | |
| return ConstantRange(APInt::getSignedMinValue(W), std::move(SMax)); | ||
| } | ||
| case CmpInst::ICMP_ULE: | ||
| if (!Pred.hasSameSign()) | ||
| return getNonEmpty(APInt::getMinValue(W), CR.getUnsignedMax() + 1); | ||
| if (W == 1) | ||
| return CR; | ||
| if (CR.isAllNegative() || | ||
| (!CR.isAllNonNegative() && !CR.isSignWrappedSet())) | ||
| return getNonEmpty(APInt::getSignedMinValue(W), CR.getSignedMax() + 1); | ||
| return getNonEmpty(APInt::getMinValue(W), CR.getUnsignedMax() + 1); | ||
| case CmpInst::ICMP_SLE: | ||
| return getNonEmpty(APInt::getSignedMinValue(W), CR.getSignedMax() + 1); | ||
| case CmpInst::ICMP_UGT: { | ||
| APInt UMin(CR.getUnsignedMin()); | ||
| if (UMin.isMaxValue()) | ||
artagnon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return getEmpty(W); | ||
| return ConstantRange(std::move(UMin) + 1, APInt::getZero(W)); | ||
| if (!Pred.hasSameSign()) | ||
| return ConstantRange(std::move(UMin) + 1, APInt::getZero(W)); | ||
| if (W == 1) | ||
| return getEmpty(W); | ||
|
|
||
| // deal with edge cases | ||
|
||
| APInt AugmentedLower = CR.getLower(); | ||
| if (CR.getLower().isMaxSignedValue()) | ||
| AugmentedLower = APInt::getSignedMinValue(W); | ||
| else if (CR.getLower().isMaxValue()) | ||
| AugmentedLower = APInt::getMinValue(W); | ||
|
|
||
| if (AugmentedLower == CR.getUpper()) | ||
| return getEmpty(W); | ||
|
|
||
| ConstantRange Augmented(AugmentedLower, CR.getUpper()); | ||
| if (Augmented.isAllNegative()) | ||
| return getNonEmpty(Augmented.getSignedMin() + 1, APInt::getZero(W)); | ||
| if (!Augmented.isAllNonNegative() && Augmented.isSignWrappedSet()) | ||
| return getNonEmpty(Augmented.getUnsignedMin() + 1, APInt::getZero(W)); | ||
| return getNonEmpty(Augmented.getSignedMin() + 1, | ||
| APInt::getSignedMinValue(W)); | ||
| } | ||
| case CmpInst::ICMP_SGT: { | ||
| APInt SMin(CR.getSignedMin()); | ||
|
|
@@ -150,6 +198,9 @@ ConstantRange ConstantRange::makeAllowedICmpRegion(CmpInst::Predicate Pred, | |
| return ConstantRange(std::move(SMin) + 1, APInt::getSignedMinValue(W)); | ||
| } | ||
| case CmpInst::ICMP_UGE: | ||
| if (Pred.hasSameSign() && (CR.isAllNonNegative() || | ||
| (!CR.isAllNegative() && !CR.isSignWrappedSet()))) | ||
| return getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W)); | ||
| return getNonEmpty(CR.getUnsignedMin(), APInt::getZero(W)); | ||
| case CmpInst::ICMP_SGE: | ||
| return getNonEmpty(CR.getSignedMin(), APInt::getSignedMinValue(W)); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3234,6 +3234,106 @@ TEST_F(ValueTrackingTest, ComputeConstantRange) { | |
| EXPECT_EQ(10, CR2.getUpper()); | ||
| } | ||
|
|
||
| { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please also add an exhaustive test in ConstantRangeTest.cpp to make sure the result is optimal. I also doubt the correctness of Boolean constant ranges.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice catch, there is an edge case for booleans Added tests to check my changes
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For exhaustive tests, I mean you can use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I printed out the elements I create for each test to validate the correct elements are created, attached file with that print |
||
| // Assumptions: | ||
| // * stride >= 5 (unsigned) | ||
| // | ||
| // stride = [5, 0) | ||
| auto M = parseModule(R"( | ||
| declare void @llvm.assume(i1) | ||
|
|
||
| define i32 @test(i32 %stride) { | ||
| %gt = icmp uge i32 %stride, 5 | ||
artagnon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| call void @llvm.assume(i1 %gt) | ||
| %stride.plus.one = add nsw nuw i32 %stride, 1 | ||
| ret i32 %stride.plus.one | ||
| })"); | ||
| Function *F = M->getFunction("test"); | ||
|
|
||
| AssumptionCache AC(*F); | ||
| Value *Stride = &*F->arg_begin(); | ||
|
|
||
| Instruction *I = &findInstructionByName(F, "stride.plus.one"); | ||
| ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I); | ||
| EXPECT_EQ(5, CR2.getLower()); | ||
| EXPECT_EQ(0, CR2.getUpper()); | ||
| } | ||
|
|
||
| { | ||
| // Assumptions: | ||
| // * stride > 5 (unsigned) | ||
| // | ||
| // stride = [5, 0) | ||
| auto M = parseModule(R"( | ||
| declare void @llvm.assume(i1) | ||
|
|
||
| define i32 @test(i32 %stride) { | ||
| %gt = icmp ugt i32 %stride, 5 | ||
| call void @llvm.assume(i1 %gt) | ||
| %stride.plus.one = add nsw nuw i32 %stride, 1 | ||
| ret i32 %stride.plus.one | ||
| })"); | ||
| Function *F = M->getFunction("test"); | ||
|
|
||
| AssumptionCache AC(*F); | ||
| Value *Stride = &*F->arg_begin(); | ||
|
|
||
| Instruction *I = &findInstructionByName(F, "stride.plus.one"); | ||
| ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I); | ||
| EXPECT_EQ(6, CR2.getLower()); | ||
| EXPECT_EQ(0, CR2.getUpper()); | ||
| } | ||
|
|
||
| { | ||
| // Assumptions: | ||
| // * stride >= 5 (samesign unsigned) | ||
| // | ||
| // stride = [5, MIN_SIGNED) | ||
| auto M = parseModule(R"( | ||
| declare void @llvm.assume(i1) | ||
|
|
||
| define i32 @test(i32 %stride) { | ||
| %gt = icmp samesign uge i32 %stride, 5 | ||
| call void @llvm.assume(i1 %gt) | ||
| %stride.plus.one = add nsw nuw i32 %stride, 1 | ||
| ret i32 %stride.plus.one | ||
| })"); | ||
| Function *F = M->getFunction("test"); | ||
|
|
||
| AssumptionCache AC(*F); | ||
| Value *Stride = &*F->arg_begin(); | ||
|
|
||
| Instruction *I = &findInstructionByName(F, "stride.plus.one"); | ||
| ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I); | ||
| EXPECT_EQ(5, CR2.getLower()); | ||
| EXPECT_EQ(APInt::getSignedMinValue(32), CR2.getUpper()); | ||
| } | ||
|
|
||
| { | ||
| // Assumptions: | ||
| // * stride > 5 (samesign unsigned) | ||
| // | ||
| // stride = [5, MIN_SIGNED) | ||
| auto M = parseModule(R"( | ||
| declare void @llvm.assume(i1) | ||
|
|
||
| define i32 @test(i32 %stride) { | ||
| %gt = icmp samesign ugt i32 %stride, 5 | ||
| call void @llvm.assume(i1 %gt) | ||
| %stride.plus.one = add nsw nuw i32 %stride, 1 | ||
| ret i32 %stride.plus.one | ||
| })"); | ||
| Function *F = M->getFunction("test"); | ||
|
|
||
| AssumptionCache AC(*F); | ||
| Value *Stride = &*F->arg_begin(); | ||
|
|
||
| Instruction *I = &findInstructionByName(F, "stride.plus.one"); | ||
| ConstantRange CR2 = computeConstantRange(Stride, false, true, &AC, I); | ||
| EXPECT_EQ(6, CR2.getLower()); | ||
| EXPECT_EQ(APInt::getSignedMinValue(32), CR2.getUpper()); | ||
| } | ||
|
|
||
| { | ||
| // Assumptions: | ||
| // * stride >= 5 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd expect this to simplify isImpliedCondCommonOperandWithCR:
llvm-project/llvm/lib/Analysis/ValueTracking.cpp
Lines 9456 to 9467 in d5a5678
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately no, we need to do the same to ConstantRange::icmp to simplify this