Skip to content

Commit

Permalink
Enhance unionBoundingBox utility
Browse files Browse the repository at this point in the history
Enhance `unionBoundingBox` utility to work with input
constraints having local variables.
  • Loading branch information
arnab-polymage committed Feb 25, 2025
1 parent 03cb46d commit daaad70
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 48 deletions.
9 changes: 4 additions & 5 deletions mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,11 +474,10 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
bool areVarsAlignedWithOther(const FlatLinearConstraints &other);

/// Updates the constraints to be the smallest bounding (enclosing) box that
/// contains the points of `this` set and that of `other`, with the symbols
/// being treated specially. For each of the dimensions, the min of the lower
/// bounds (symbolic) and the max of the upper bounds (symbolic) is computed
/// to determine such a bounding box. `other` is expected to have the same
/// dimensional variables as this constraint system (in the same order).
/// contains the points of `this` set and that of `other`. For each of the
/// dimensions, the min of the lower bounds and the max of the upper bounds is
/// computed to determine such a bounding box. `other` is expected to have the
/// same dimensional variables as this constraint system (in the same order).
///
/// E.g.:
/// 1) this = {0 <= d0 <= 127},
Expand Down
24 changes: 11 additions & 13 deletions mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
Original file line number Diff line number Diff line change
Expand Up @@ -489,11 +489,10 @@ class IntegerRelation {
void constantFoldVarRange(unsigned pos, unsigned num);

/// Updates the constraints to be the smallest bounding (enclosing) box that
/// contains the points of `this` set and that of `other`, with the symbols
/// being treated specially. For each of the dimensions, the min of the lower
/// bounds (symbolic) and the max of the upper bounds (symbolic) is computed
/// to determine such a bounding box. `other` is expected to have the same
/// dimensional variables as this constraint system (in the same order).
/// contains the points of `this` set and that of `other`. For each of the
/// dimensions, the min of the lower bounds and the max of the upper bounds is
/// computed to determine such a bounding box. `other` is expected to have the
/// same dimensional variables as this constraint system (in the same order).
///
/// E.g.:
/// 1) this = {0 <= d0 <= 127},
Expand All @@ -512,14 +511,13 @@ class IntegerRelation {
/// than or equal to 'exclusive upper bound' - 'lower bound' of the
/// variable. This constant bound is guaranteed to be non-negative. Returns
/// std::nullopt if it's not a constant. This method employs trivial (low
/// complexity / cost) checks and detection. Symbolic variables are treated
/// specially, i.e., it looks for constant differences between affine
/// expressions involving only the symbolic variables. `lb` and `ub` (along
/// with the `boundFloorDivisor`) are set to represent the lower and upper
/// bound associated with the constant difference: `lb`, `ub` have the
/// coefficients, and `boundFloorDivisor`, their divisor. `minLbPos` and
/// `minUbPos` if non-null are set to the position of the constant lower bound
/// and upper bound respectively (to the same if they are from an
/// complexity / cost) checks and detection. It looks for constant differences
/// between affine expressions involving symbolic and local variables. `lb`
/// and `ub` (along with the `boundFloorDivisor`) are set to represent the
/// lower and upper bound associated with the constant difference: `lb`, `ub`
/// have the coefficients, and `boundFloorDivisor`, their divisor. `minLbPos`
/// and `minUbPos` if non-null are set to the position of the constant lower
/// bound and upper bound respectively (to the same if they are from an
/// equality). Ex: if the lower bound is [(s0 + s2 - 1) floordiv 32] for a
/// system with three symbolic variables, *lb = [1, 0, 1], lbDivisor = 32. See
/// comments at function definition for examples.
Expand Down
2 changes: 0 additions & 2 deletions mlir/lib/Analysis/FlatLinearValueConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1303,8 +1303,6 @@ LogicalResult FlatLinearValueConstraints::unionBoundingBox(
otherMaybeValues.begin(),
otherMaybeValues.begin() + getNumDimVars()) &&
"dim values mismatch");
assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
assert(getNumLocalVars() == 0 && "local vars not supported yet here");

// Align `other` to this.
if (!areVarsAligned(*this, otherCst)) {
Expand Down
46 changes: 18 additions & 28 deletions mlir/lib/Analysis/Presburger/IntegerRelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1578,13 +1578,11 @@ void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) {

/// Returns a non-negative constant bound on the extent (upper bound - lower
/// bound) of the specified variable if it is found to be a constant; returns
/// std::nullopt if it's not a constant. This methods treats symbolic variables
/// specially, i.e., it looks for constant differences between affine
/// expressions involving only the symbolic variables. See comments at function
/// definition for example. 'lb', if provided, is set to the lower bound
/// associated with the constant difference. Note that 'lb' is purely symbolic
/// and thus will contain the coefficients of the symbolic variables and the
/// constant coefficient.
/// std::nullopt if it's not a constant. This methods looks for constant
/// differences between affine expressions. See comments at function definition
/// for example. 'lb', if provided, is set to the lower bound associated with
/// the constant difference. `lb' will contain the coefficients of the symbolic
/// variables, local variables and the constant coefficient.
// Egs: 0 <= i <= 15, return 16.
// s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
// s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
Expand All @@ -1600,22 +1598,15 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
// of the symbolic variables (+ constant).
int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
if (eqPos != -1) {
auto eq = getEquality(eqPos);
// If the equality involves a local var, punt for now.
// TODO: this can be handled in the future by using the explicit
// representation of the local vars.
if (!std::all_of(eq.begin() + getNumDimAndSymbolVars(), eq.end() - 1,
[](const DynamicAPInt &coeff) { return coeff == 0; }))
return std::nullopt;

// This variable can only take a single value.
if (lb) {
// Set lb to that symbolic value.
lb->resize(getNumSymbolVars() + 1);
lb->resize(getNumSymbolVars() + getNumLocalVars() + 1);
if (ub)
ub->resize(getNumSymbolVars() + 1);
for (unsigned c = 0, f = getNumSymbolVars() + 1; c < f; c++) {
DynamicAPInt v = atEq(eqPos, pos);
ub->resize(getNumSymbolVars() + getNumLocalVars() + 1);
for (unsigned c = 0, f = getNumSymbolVars() + getNumLocalVars() + 1;
c < f; c++) {
MPInt v = atEq(eqPos, pos);
// atEq(eqRow, pos) is either -1 or 1.
assert(v * v == 1);
(*lb)[c] = v < 0 ? atEq(eqPos, getNumDimVars() + c) / -v
Expand Down Expand Up @@ -1687,27 +1678,30 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
}
if (lb && minDiff) {
// Set lb to the symbolic lower bound.
lb->resize(getNumSymbolVars() + 1);
lb->resize(getNumSymbolVars() + getNumLocalVars() + 1);
if (ub)
ub->resize(getNumSymbolVars() + 1);
ub->resize(getNumSymbolVars() + getNumLocalVars() + 1);
// The lower bound is the ceildiv of the lb constraint over the coefficient
// of the variable at 'pos'. We express the ceildiv equivalently as a floor
// for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
// 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
*boundFloorDivisor = atIneq(minLbPosition, pos);
assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++) {
for (unsigned c = 0, e = getNumSymbolVars() + getNumLocalVars() + 1; c < e;
c++) {
(*lb)[c] = -atIneq(minLbPosition, getNumDimVars() + c);
}
if (ub) {
for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++)
for (unsigned c = 0, e = getNumSymbolVars() + getNumLocalVars() + 1;
c < e; c++)
(*ub)[c] = atIneq(minUbPosition, getNumDimVars() + c);
}
// The lower bound leads to a ceildiv while the upper bound is a floordiv
// whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
// d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
// the constant term for the lower bound.
(*lb)[getNumSymbolVars()] += atIneq(minLbPosition, pos) - 1;
(*lb)[getNumSymbolVars() + getNumLocalVars()] +=
atIneq(minLbPosition, pos) - 1;
}
if (minLbPos)
*minLbPos = minLbPosition;
Expand Down Expand Up @@ -2180,8 +2174,6 @@ static void getCommonConstraints(const IntegerRelation &a,
LogicalResult
IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
assert(space.isEqual(otherCst.getSpace()) && "Spaces should match.");
assert(getNumLocalVars() == 0 && "local ids not supported yet here");

// Get the constraints common to both systems; these will be added as is to
// the union.
IntegerRelation commonCst(PresburgerSpace::getRelationSpace());
Expand Down Expand Up @@ -2211,11 +2203,9 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
auto otherExtent = otherCst.getConstantBoundOnDimSize(
d, &otherLb, &otherLbFloorDivisor, &otherUb);
if (!otherExtent.has_value() || lbFloorDivisor != otherLbFloorDivisor)
// TODO: symbolic extents when necessary.
return failure();

assert(lbFloorDivisor > 0 && "divisor always expected to be positive");

auto res = compareBounds(lb, otherLb);
// Identify min.
if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {
Expand Down
14 changes: 14 additions & 0 deletions mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,3 +608,17 @@ TEST(IntegerRelationTest, convertVarKindToLocal) {
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
}

// Test union of two integer relations if they have local variable(s).
TEST(IntegerRelationTest, unionBoundingBox) {
IntegerRelation relA = parseRelationFromSet(
"(x, y, z)[N, M]: (y floordiv 2 - N + x == 0, z floordiv 5 - N - x"
">= 0, x + y + z floordiv 6 == 0)",
1);
IntegerRelation relB = parseRelationFromSet(
"(x, y, z)[N, M]: (y floordiv 2 - N + x == 0, z floordiv 5 - M - x"
">= 0, x + y + z floordiv 7 == 0)",
1);
assert(relA.getNumLocalVars() > 0);
EXPECT_TRUE(relA.unionBoundingBox(relB).succeeded());
}

0 comments on commit daaad70

Please sign in to comment.