diff --git a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h index c8167014b5300..15387201affa8 100644 --- a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h +++ b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h @@ -474,11 +474,10 @@ class FlatLinearValueConstraints : public FlatLinearConstraints { bool areVarsAlignedWithOther(const FlatLinearConstraints &other); /// Updates the constraints to be the smallest bounding (enclosing) box that - /// contains the points of `this` set and that of `other`, with the symbols - /// being treated specially. For each of the dimensions, the min of the lower - /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed - /// to determine such a bounding box. `other` is expected to have the same - /// dimensional variables as this constraint system (in the same order). + /// contains the points of `this` set and that of `other`. For each of the + /// dimensions, the min of the lower bounds and the max of the upper bounds is + /// computed to determine such a bounding box. `other` is expected to have the + /// same dimensional variables as this constraint system (in the same order). /// /// E.g.: /// 1) this = {0 <= d0 <= 127}, diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index ddc18038e869c..ae45743ecc1be 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -489,11 +489,10 @@ class IntegerRelation { void constantFoldVarRange(unsigned pos, unsigned num); /// Updates the constraints to be the smallest bounding (enclosing) box that - /// contains the points of `this` set and that of `other`, with the symbols - /// being treated specially. For each of the dimensions, the min of the lower - /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed - /// to determine such a bounding box. `other` is expected to have the same - /// dimensional variables as this constraint system (in the same order). + /// contains the points of `this` set and that of `other`. For each of the + /// dimensions, the min of the lower bounds and the max of the upper bounds is + /// computed to determine such a bounding box. `other` is expected to have the + /// same dimensional variables as this constraint system (in the same order). /// /// E.g.: /// 1) this = {0 <= d0 <= 127}, @@ -512,14 +511,13 @@ class IntegerRelation { /// than or equal to 'exclusive upper bound' - 'lower bound' of the /// variable. This constant bound is guaranteed to be non-negative. Returns /// std::nullopt if it's not a constant. This method employs trivial (low - /// complexity / cost) checks and detection. Symbolic variables are treated - /// specially, i.e., it looks for constant differences between affine - /// expressions involving only the symbolic variables. `lb` and `ub` (along - /// with the `boundFloorDivisor`) are set to represent the lower and upper - /// bound associated with the constant difference: `lb`, `ub` have the - /// coefficients, and `boundFloorDivisor`, their divisor. `minLbPos` and - /// `minUbPos` if non-null are set to the position of the constant lower bound - /// and upper bound respectively (to the same if they are from an + /// complexity / cost) checks and detection. It looks for constant differences + /// between affine expressions involving symbolic and local variables. `lb` + /// and `ub` (along with the `boundFloorDivisor`) are set to represent the + /// lower and upper bound associated with the constant difference: `lb`, `ub` + /// have the coefficients, and `boundFloorDivisor`, their divisor. `minLbPos` + /// and `minUbPos` if non-null are set to the position of the constant lower + /// bound and upper bound respectively (to the same if they are from an /// equality). Ex: if the lower bound is [(s0 + s2 - 1) floordiv 32] for a /// system with three symbolic variables, *lb = [1, 0, 1], lbDivisor = 32. See /// comments at function definition for examples. diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp index 4653eca9887ce..ae9f9acd89c2e 100644 --- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp @@ -1303,8 +1303,6 @@ LogicalResult FlatLinearValueConstraints::unionBoundingBox( otherMaybeValues.begin(), otherMaybeValues.begin() + getNumDimVars()) && "dim values mismatch"); - assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here"); - assert(getNumLocalVars() == 0 && "local vars not supported yet here"); // Align `other` to this. if (!areVarsAligned(*this, otherCst)) { diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 74cdf567c0e56..89d3a936e8e9e 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1578,13 +1578,11 @@ void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) { /// Returns a non-negative constant bound on the extent (upper bound - lower /// bound) of the specified variable if it is found to be a constant; returns -/// std::nullopt if it's not a constant. This methods treats symbolic variables -/// specially, i.e., it looks for constant differences between affine -/// expressions involving only the symbolic variables. See comments at function -/// definition for example. 'lb', if provided, is set to the lower bound -/// associated with the constant difference. Note that 'lb' is purely symbolic -/// and thus will contain the coefficients of the symbolic variables and the -/// constant coefficient. +/// std::nullopt if it's not a constant. This methods looks for constant +/// differences between affine expressions. See comments at function definition +/// for example. 'lb', if provided, is set to the lower bound associated with +/// the constant difference. `lb' will contain the coefficients of the symbolic +/// variables, local variables and the constant coefficient. // Egs: 0 <= i <= 15, return 16. // s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol) // s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16. @@ -1600,22 +1598,15 @@ std::optional IntegerRelation::getConstantBoundOnDimSize( // of the symbolic variables (+ constant). int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true); if (eqPos != -1) { - auto eq = getEquality(eqPos); - // If the equality involves a local var, punt for now. - // TODO: this can be handled in the future by using the explicit - // representation of the local vars. - if (!std::all_of(eq.begin() + getNumDimAndSymbolVars(), eq.end() - 1, - [](const DynamicAPInt &coeff) { return coeff == 0; })) - return std::nullopt; - // This variable can only take a single value. if (lb) { // Set lb to that symbolic value. - lb->resize(getNumSymbolVars() + 1); + lb->resize(getNumSymbolVars() + getNumLocalVars() + 1); if (ub) - ub->resize(getNumSymbolVars() + 1); - for (unsigned c = 0, f = getNumSymbolVars() + 1; c < f; c++) { - DynamicAPInt v = atEq(eqPos, pos); + ub->resize(getNumSymbolVars() + getNumLocalVars() + 1); + for (unsigned c = 0, f = getNumSymbolVars() + getNumLocalVars() + 1; + c < f; c++) { + MPInt v = atEq(eqPos, pos); // atEq(eqRow, pos) is either -1 or 1. assert(v * v == 1); (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimVars() + c) / -v @@ -1687,27 +1678,30 @@ std::optional IntegerRelation::getConstantBoundOnDimSize( } if (lb && minDiff) { // Set lb to the symbolic lower bound. - lb->resize(getNumSymbolVars() + 1); + lb->resize(getNumSymbolVars() + getNumLocalVars() + 1); if (ub) - ub->resize(getNumSymbolVars() + 1); + ub->resize(getNumSymbolVars() + getNumLocalVars() + 1); // The lower bound is the ceildiv of the lb constraint over the coefficient // of the variable at 'pos'. We express the ceildiv equivalently as a floor // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N + // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32). *boundFloorDivisor = atIneq(minLbPosition, pos); assert(*boundFloorDivisor == -atIneq(minUbPosition, pos)); - for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++) { + for (unsigned c = 0, e = getNumSymbolVars() + getNumLocalVars() + 1; c < e; + c++) { (*lb)[c] = -atIneq(minLbPosition, getNumDimVars() + c); } if (ub) { - for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++) + for (unsigned c = 0, e = getNumSymbolVars() + getNumLocalVars() + 1; + c < e; c++) (*ub)[c] = atIneq(minUbPosition, getNumDimVars() + c); } // The lower bound leads to a ceildiv while the upper bound is a floordiv // whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val + // d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to // the constant term for the lower bound. - (*lb)[getNumSymbolVars()] += atIneq(minLbPosition, pos) - 1; + (*lb)[getNumSymbolVars() + getNumLocalVars()] += + atIneq(minLbPosition, pos) - 1; } if (minLbPos) *minLbPos = minLbPosition; @@ -2180,8 +2174,6 @@ static void getCommonConstraints(const IntegerRelation &a, LogicalResult IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { assert(space.isEqual(otherCst.getSpace()) && "Spaces should match."); - assert(getNumLocalVars() == 0 && "local ids not supported yet here"); - // Get the constraints common to both systems; these will be added as is to // the union. IntegerRelation commonCst(PresburgerSpace::getRelationSpace()); @@ -2211,11 +2203,9 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { auto otherExtent = otherCst.getConstantBoundOnDimSize( d, &otherLb, &otherLbFloorDivisor, &otherUb); if (!otherExtent.has_value() || lbFloorDivisor != otherLbFloorDivisor) - // TODO: symbolic extents when necessary. return failure(); assert(lbFloorDivisor > 0 && "divisor always expected to be positive"); - auto res = compareBounds(lb, otherLb); // Identify min. if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) { diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp index 7df500bc9568a..44c67a301b110 100644 --- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp @@ -608,3 +608,17 @@ TEST(IntegerRelationTest, convertVarKindToLocal) { EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3])); EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4])); } + +// Test union of two integer relations if they have local variable(s). +TEST(IntegerRelationTest, unionBoundingBox) { + IntegerRelation relA = parseRelationFromSet( + "(x, y, z)[N, M]: (y floordiv 2 - N + x == 0, z floordiv 5 - N - x" + ">= 0, x + y + z floordiv 6 == 0)", + 1); + IntegerRelation relB = parseRelationFromSet( + "(x, y, z)[N, M]: (y floordiv 2 - N + x == 0, z floordiv 5 - M - x" + ">= 0, x + y + z floordiv 7 == 0)", + 1); + assert(relA.getNumLocalVars() > 0); + EXPECT_TRUE(relA.unionBoundingBox(relB).succeeded()); +}