From afbf57e713e07424101464a654c64a47e44bbc80 Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Wed, 3 Jul 2024 10:51:25 +0100 Subject: [PATCH] mlir/Presburger: reinstate use of LogicalResult (#97415) Follow up on a desire post-landing d0fee98 (mlir/Presburger: strip dependency on MLIRSupport) to reinstate the use of LogicalResult in Presburger. Since db791b2 (mlir/LogicalResult: move into llvm), LogicalResult is in LLVM, and fulfilling this desire is possible while still maintaining the goal of stripping the Presburger library of mlir dependencies. --- .../Analysis/Presburger/IntegerRelation.h | 12 ++-- .../mlir/Analysis/Presburger/Simplex.h | 12 ++-- .../Analysis/FlatLinearValueConstraints.cpp | 4 +- .../Analysis/Presburger/IntegerRelation.cpp | 26 ++++---- .../Presburger/PresburgerRelation.cpp | 59 ++++++++++--------- mlir/lib/Analysis/Presburger/Simplex.cpp | 57 +++++++++--------- mlir/lib/Analysis/Presburger/Utils.cpp | 30 +++++----- 7 files changed, 106 insertions(+), 94 deletions(-) diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index 5e5cd898b75189..a27fc8c37eeda1 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -21,13 +21,17 @@ #include "mlir/Analysis/Presburger/Utils.h" #include "llvm/ADT/DynamicAPInt.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" #include namespace mlir { namespace presburger { using llvm::DynamicAPInt; +using llvm::failure; using llvm::int64fromDynamicAPInt; +using llvm::LogicalResult; using llvm::SmallVectorImpl; +using llvm::success; class IntegerRelation; class IntegerPolyhedron; @@ -478,7 +482,7 @@ class IntegerRelation { /// equality detection; if successful, the constant is substituted for the /// variable everywhere in the constraint system and then removed from the /// system. - bool constantFoldVar(unsigned pos); + LogicalResult constantFoldVar(unsigned pos); /// This method calls `constantFoldVar` for the specified range of variables, /// `num` variables starting at position `pos`. @@ -501,7 +505,7 @@ class IntegerRelation { /// 3) this = {0 <= d0 <= 5, 1 <= d1 <= 9} /// other = {2 <= d0 <= 6, 5 <= d1 <= 15}, /// output = {0 <= d0 <= 6, 1 <= d1 <= 15} - bool unionBoundingBox(const IntegerRelation &other); + LogicalResult unionBoundingBox(const IntegerRelation &other); /// Returns the smallest known constant bound for the extent of the specified /// variable (pos^th), i.e., the smallest known constant that is greater @@ -774,8 +778,8 @@ class IntegerRelation { /// Eliminates a single variable at `position` from equality and inequality /// constraints. Returns `success` if the variable was eliminated, and /// `failure` otherwise. - inline bool gaussianEliminateVar(unsigned position) { - return gaussianEliminateVars(position, position + 1) == 1; + inline LogicalResult gaussianEliminateVar(unsigned position) { + return success(gaussianEliminateVars(position, position + 1) == 1); } /// Removes local variables using equalities. Each equality is checked if it diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h index f413636e06910e..4c40c4cdcb655a 100644 --- a/mlir/include/mlir/Analysis/Presburger/Simplex.h +++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h @@ -445,7 +445,7 @@ class LexSimplexBase : public SimplexBase { /// lexicopositivity of the basis transform. The row must have a non-positive /// sample value. If this is not possible, return failure. This occurs when /// the constraints have no solution or the sample value is zero. - bool moveRowUnknownToColumn(unsigned row); + LogicalResult moveRowUnknownToColumn(unsigned row); /// Given a row that has a non-integer sample value, add an inequality to cut /// away this fractional sample value from the polytope without removing any @@ -459,7 +459,7 @@ class LexSimplexBase : public SimplexBase { /// /// Return failure if the tableau became empty, and success if it didn't. /// Failure status indicates that the polytope was integer empty. - bool addCut(unsigned row); + LogicalResult addCut(unsigned row); /// Undo the addition of the last constraint. This is only called while /// rolling back. @@ -511,7 +511,7 @@ class LexSimplex : public LexSimplexBase { MaybeOptimum> getRationalSample() const; /// Make the tableau configuration consistent. - bool restoreRationalConsistency(); + LogicalResult restoreRationalConsistency(); /// Return whether the specified row is violated; bool rowIsViolated(unsigned row) const; @@ -626,7 +626,7 @@ class SymbolicLexSimplex : public LexSimplexBase { /// Return failure if the tableau became empty, indicating that the polytope /// is always integer empty in the current symbol domain. /// Return success otherwise. - bool doNonBranchingPivots(); + LogicalResult doNonBranchingPivots(); /// Get a row that is always violated in the current domain, if one exists. std::optional maybeGetAlwaysViolatedRow(); @@ -647,7 +647,7 @@ class SymbolicLexSimplex : public LexSimplexBase { /// at the time of the call. (This function may modify the symbol domain, but /// failure statu indicates that the polytope was empty for all symbol values /// in the initial domain.) - bool addSymbolicCut(unsigned row); + LogicalResult addSymbolicCut(unsigned row); /// Get the numerator of the symbolic sample of the specific row. /// This is an affine expression in the symbols with integer coefficients. @@ -820,7 +820,7 @@ class Simplex : public SimplexBase { /// /// Returns success if the unknown was successfully restored to a non-negative /// sample value, failure otherwise. - bool restoreRow(Unknown &u); + LogicalResult restoreRow(Unknown &u); /// Find a pivot to change the sample value of row in the specified /// direction while preserving tableau consistency, except that if the diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp index 746cff525beb27..e628fb152b52f8 100644 --- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp @@ -1247,10 +1247,10 @@ LogicalResult FlatLinearValueConstraints::unionBoundingBox( if (!areVarsAligned(*this, otherCst)) { FlatLinearValueConstraints otherCopy(otherCst); mergeAndAlignVars(/*offset=*/getNumDimVars(), this, &otherCopy); - return success(IntegerPolyhedron::unionBoundingBox(otherCopy)); + return IntegerPolyhedron::unionBoundingBox(otherCopy); } - return success(IntegerPolyhedron::unionBoundingBox(otherCst)); + return IntegerPolyhedron::unionBoundingBox(otherCst); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 6b438692ff6f91..095a7dcb287f3c 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -26,6 +26,7 @@ #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" #include #include @@ -1552,22 +1553,22 @@ static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos, return -1; } -bool IntegerRelation::constantFoldVar(unsigned pos) { +LogicalResult IntegerRelation::constantFoldVar(unsigned pos) { assert(pos < getNumVars() && "invalid position"); int rowIdx; if ((rowIdx = findEqualityToConstant(*this, pos)) == -1) - return false; + return failure(); // atEq(rowIdx, pos) is either -1 or 1. assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1); DynamicAPInt constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos); setAndEliminate(pos, constVal); - return true; + return success(); } void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) { for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) { - if (!constantFoldVar(t)) + if (constantFoldVar(t).failed()) t++; } } @@ -1944,9 +1945,9 @@ void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow, for (unsigned r = 0, e = getNumEqualities(); r < e; r++) { if (atEq(r, pos) != 0) { // Use Gaussian elimination here (since we have an equality). - bool ret = gaussianEliminateVar(pos); + LogicalResult ret = gaussianEliminateVar(pos); (void)ret; - assert(ret && "Gaussian elimination guaranteed to succeed"); + assert(ret.succeeded() && "Gaussian elimination guaranteed to succeed"); LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n"); LLVM_DEBUG(dump()); return; @@ -2173,7 +2174,8 @@ static void getCommonConstraints(const IntegerRelation &a, // Computes the bounding box with respect to 'other' by finding the min of the // lower bounds and the max of the upper bounds along each of the dimensions. -bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { +LogicalResult +IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { assert(space.isEqual(otherCst.getSpace()) && "Spaces should match."); assert(getNumLocalVars() == 0 && "local ids not supported yet here"); @@ -2201,13 +2203,13 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { if (!extent.has_value()) // TODO: symbolic extents when necessary. // TODO: handle union if a dimension is unbounded. - return false; + return failure(); auto otherExtent = otherCst.getConstantBoundOnDimSize( d, &otherLb, &otherLbFloorDivisor, &otherUb); if (!otherExtent.has_value() || lbFloorDivisor != otherLbFloorDivisor) // TODO: symbolic extents when necessary. - return false; + return failure(); assert(lbFloorDivisor > 0 && "divisor always expected to be positive"); @@ -2227,7 +2229,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { auto constLb = getConstantBound(BoundType::LB, d); auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d); if (!constLb.has_value() || !constOtherLb.has_value()) - return false; + return failure(); std::fill(minLb.begin(), minLb.end(), 0); minLb.back() = std::min(*constLb, *constOtherLb); } @@ -2243,7 +2245,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { auto constUb = getConstantBound(BoundType::UB, d); auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d); if (!constUb.has_value() || !constOtherUb.has_value()) - return false; + return failure(); std::fill(maxUb.begin(), maxUb.end(), 0); maxUb.back() = std::max(*constUb, *constOtherUb); } @@ -2281,7 +2283,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) { // union (since the above are just the union along dimensions); we shouldn't // be discarding any other constraints on the symbols. - return true; + return success(); } bool IntegerRelation::isColZero(unsigned pos) const { diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp index 5c4965c919ac30..e284ca82420bac 100644 --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" #include #include @@ -753,18 +754,18 @@ class presburger::SetCoalescer { /// \___\|/ \_____/ /// /// - bool coalescePairCutCase(unsigned i, unsigned j); + LogicalResult coalescePairCutCase(unsigned i, unsigned j); /// Types the inequality `ineq` according to its `IneqType` for `simp` into /// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate /// inequalities were encountered. Otherwise, returns failure. - bool typeInequality(ArrayRef ineq, Simplex &simp); + LogicalResult typeInequality(ArrayRef ineq, Simplex &simp); /// Types the equality `eq`, i.e. for `eq` == 0, types both `eq` >= 0 and /// -`eq` >= 0 according to their `IneqType` for `simp` into /// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate /// inequalities were encountered. Otherwise, returns failure. - bool typeEquality(ArrayRef eq, Simplex &simp); + LogicalResult typeEquality(ArrayRef eq, Simplex &simp); /// Replaces the element at position `i` with the last element and erases /// the last element for both `disjuncts` and `simplices`. @@ -775,7 +776,7 @@ class presburger::SetCoalescer { /// successfully coalesced. The simplices in `simplices` need to be the ones /// constructed from `disjuncts`. At this point, there are no empty /// disjuncts in `disjuncts` left. - bool coalescePair(unsigned i, unsigned j); + LogicalResult coalescePair(unsigned i, unsigned j); }; /// Constructs a `SetCoalescer` from a `PresburgerRelation`. Only adds non-empty @@ -818,7 +819,7 @@ PresburgerRelation SetCoalescer::coalesce() { cuttingIneqsB.clear(); if (i == j) continue; - if (coalescePair(i, j)) { + if (coalescePair(i, j).succeeded()) { broken = true; break; } @@ -902,7 +903,7 @@ void SetCoalescer::addCoalescedDisjunct(unsigned i, unsigned j, /// \___\|/ \_____/ /// /// -bool SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) { +LogicalResult SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) { /// All inequalities of `b` need to be redundant. We already know that the /// redundant ones are, so only the cutting ones remain to be checked. Simplex &simp = simplices[i]; @@ -910,7 +911,7 @@ bool SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) { if (llvm::any_of(cuttingIneqsA, [this, &simp](ArrayRef curr) { return !isFacetContained(curr, simp); })) - return false; + return failure(); IntegerRelation newSet(disjunct.getSpace()); for (ArrayRef curr : redundantIneqsA) @@ -920,23 +921,25 @@ bool SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) { newSet.addInequality(curr); addCoalescedDisjunct(i, j, newSet); - return true; + return success(); } -bool SetCoalescer::typeInequality(ArrayRef ineq, Simplex &simp) { +LogicalResult SetCoalescer::typeInequality(ArrayRef ineq, + Simplex &simp) { Simplex::IneqType type = simp.findIneqType(ineq); if (type == Simplex::IneqType::Redundant) redundantIneqsB.push_back(ineq); else if (type == Simplex::IneqType::Cut) cuttingIneqsB.push_back(ineq); else - return false; - return true; + return failure(); + return success(); } -bool SetCoalescer::typeEquality(ArrayRef eq, Simplex &simp) { - if (!typeInequality(eq, simp)) - return false; +LogicalResult SetCoalescer::typeEquality(ArrayRef eq, + Simplex &simp) { + if (typeInequality(eq, simp).failed()) + return failure(); negEqs.push_back(getNegatedCoeffs(eq)); ArrayRef inv(negEqs.back()); return typeInequality(inv, simp); @@ -951,7 +954,7 @@ void SetCoalescer::eraseDisjunct(unsigned i) { simplices.pop_back(); } -bool SetCoalescer::coalescePair(unsigned i, unsigned j) { +LogicalResult SetCoalescer::coalescePair(unsigned i, unsigned j) { IntegerRelation &a = disjuncts[i]; IntegerRelation &b = disjuncts[j]; @@ -959,7 +962,7 @@ bool SetCoalescer::coalescePair(unsigned i, unsigned j) { /// skipped. /// TODO: implement local id support. if (a.getNumLocalVars() != 0 || b.getNumLocalVars() != 0) - return false; + return failure(); Simplex &simpA = simplices[i]; Simplex &simpB = simplices[j]; @@ -969,34 +972,34 @@ bool SetCoalescer::coalescePair(unsigned i, unsigned j) { // inequality is encountered during typing, the two IntegerRelations // cannot be coalesced. for (int k = 0, e = a.getNumInequalities(); k < e; ++k) - if (!typeInequality(a.getInequality(k), simpB)) - return false; + if (typeInequality(a.getInequality(k), simpB).failed()) + return failure(); for (int k = 0, e = a.getNumEqualities(); k < e; ++k) - if (!typeEquality(a.getEquality(k), simpB)) - return false; + if (typeEquality(a.getEquality(k), simpB).failed()) + return failure(); std::swap(redundantIneqsA, redundantIneqsB); std::swap(cuttingIneqsA, cuttingIneqsB); for (int k = 0, e = b.getNumInequalities(); k < e; ++k) - if (!typeInequality(b.getInequality(k), simpA)) - return false; + if (typeInequality(b.getInequality(k), simpA).failed()) + return failure(); for (int k = 0, e = b.getNumEqualities(); k < e; ++k) - if (!typeEquality(b.getEquality(k), simpA)) - return false; + if (typeEquality(b.getEquality(k), simpA).failed()) + return failure(); // If there are no cutting inequalities of `a`, `b` is contained // within `a`. if (cuttingIneqsA.empty()) { eraseDisjunct(j); - return true; + return success(); } // Try to apply the cut case - if (coalescePairCutCase(i, j)) - return true; + if (coalescePairCutCase(i, j).succeeded()) + return success(); // Swap the vectors to compare the pair (j,i) instead of (i,j). std::swap(redundantIneqsA, redundantIneqsB); @@ -1006,7 +1009,7 @@ bool SetCoalescer::coalescePair(unsigned i, unsigned j) { // within `a`. if (cuttingIneqsA.empty()) { eraseDisjunct(i); - return true; + return success(); } // Try to apply the cut case diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp index 4efc7a3755014a..bebbf0325f430c 100644 --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" #include #include @@ -229,7 +230,7 @@ Direction flippedDirection(Direction direction) { /// add these to the set of ignored columns and continue to the next row. If we /// run out of rows, then A*y is zero and we are done. MaybeOptimum> LexSimplex::findRationalLexMin() { - if (!restoreRationalConsistency()) { + if (restoreRationalConsistency().failed()) { markEmpty(); return OptimumKind::Empty; } @@ -274,7 +275,7 @@ MaybeOptimum> LexSimplex::findRationalLexMin() { /// /// The constraint is violated when added (it would be useless otherwise) /// so we immediately try to move it to a column. -bool LexSimplexBase::addCut(unsigned row) { +LogicalResult LexSimplexBase::addCut(unsigned row) { DynamicAPInt d = tableau(row, 0); unsigned cutRow = addZeroRow(/*makeRestricted=*/true); tableau(cutRow, 0) = d; @@ -301,7 +302,7 @@ std::optional LexSimplex::maybeGetNonIntegralVarRow() const { MaybeOptimum> LexSimplex::findIntegerLexMin() { // We first try to make the tableau consistent. - if (!restoreRationalConsistency()) + if (restoreRationalConsistency().failed()) return OptimumKind::Empty; // Then, if the sample value is integral, we are done. @@ -316,9 +317,9 @@ MaybeOptimum> LexSimplex::findIntegerLexMin() { // // Failure indicates that the tableau became empty, which occurs when the // polytope is integer empty. - if (!addCut(*maybeRow)) + if (addCut(*maybeRow).failed()) return OptimumKind::Empty; - if (!restoreRationalConsistency()) + if (restoreRationalConsistency().failed()) return OptimumKind::Empty; } @@ -411,7 +412,7 @@ bool SymbolicLexSimplex::isSymbolicSampleIntegral(unsigned row) const { /// (sum_i (b_i%d)y_i - (-c%d) - sum_i (-a_i%d)s_i + q*d)/d >= 0 /// This constraint is violated when added so we immediately try to move it to a /// column. -bool SymbolicLexSimplex::addSymbolicCut(unsigned row) { +LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) { DynamicAPInt d = tableau(row, 0); if (isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), d)) { // The coefficients of symbols in the symbol numerator are divisible @@ -523,11 +524,11 @@ std::optional SymbolicLexSimplex::maybeGetNonIntegralVarRow() { /// The non-branching pivots are just the ones moving the rows /// that are always violated in the symbol domain. -bool SymbolicLexSimplex::doNonBranchingPivots() { +LogicalResult SymbolicLexSimplex::doNonBranchingPivots() { while (std::optional row = maybeGetAlwaysViolatedRow()) - if (!moveRowUnknownToColumn(*row)) - return false; - return true; + if (moveRowUnknownToColumn(*row).failed()) + return failure(); + return success(); } SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() { @@ -567,7 +568,7 @@ SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() { continue; } - if (!doNonBranchingPivots()) { + if (doNonBranchingPivots().failed()) { // Could not find pivots for violated constraints; return. --level; continue; @@ -627,7 +628,7 @@ SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() { // The tableau is rationally consistent for the current domain. // Now we look for non-integral sample values and add cuts for them. if (std::optional row = maybeGetNonIntegralVarRow()) { - if (!addSymbolicCut(*row)) { + if (addSymbolicCut(*row).failed()) { // No integral points; return. --level; continue; @@ -661,7 +662,7 @@ SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() { SmallVector splitIneq = getComplementIneq(getSymbolicSampleIneq(u.pos)); normalizeRange(splitIneq); - if (!moveRowUnknownToColumn(u.pos)) { + if (moveRowUnknownToColumn(u.pos).failed()) { // The unknown can't be made non-negative; return. --level; continue; @@ -699,13 +700,13 @@ std::optional LexSimplex::maybeGetViolatedRow() const { /// We simply look for violated rows and keep trying to move them to column /// orientation, which always succeeds unless the constraints have no solution /// in which case we just give up and return. -bool LexSimplex::restoreRationalConsistency() { +LogicalResult LexSimplex::restoreRationalConsistency() { if (empty) - return false; + return failure(); while (std::optional maybeViolatedRow = maybeGetViolatedRow()) - if (!moveRowUnknownToColumn(*maybeViolatedRow)) - return false; - return true; + if (moveRowUnknownToColumn(*maybeViolatedRow).failed()) + return failure(); + return success(); } // Move the row unknown to column orientation while preserving lexicopositivity @@ -770,7 +771,7 @@ bool LexSimplex::restoreRationalConsistency() { // which is in contradiction to the fact that B.col(j) / B(i,j) must be // lexicographically smaller than B.col(k) / B(i,k), since it lexicographically // minimizes the change in sample value. -bool LexSimplexBase::moveRowUnknownToColumn(unsigned row) { +LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) { std::optional maybeColumn; for (unsigned col = 3 + nSymbol, e = getNumColumns(); col < e; ++col) { if (tableau(row, col) <= 0) @@ -780,10 +781,10 @@ bool LexSimplexBase::moveRowUnknownToColumn(unsigned row) { } if (!maybeColumn) - return false; + return failure(); pivot(row, *maybeColumn); - return true; + return success(); } unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA, @@ -986,7 +987,7 @@ void SimplexBase::pivot(unsigned pivotRow, unsigned pivotCol) { /// Perform pivots until the unknown has a non-negative sample value or until /// no more upward pivots can be performed. Return success if we were able to /// bring the row to a non-negative sample value, and failure otherwise. -bool Simplex::restoreRow(Unknown &u) { +LogicalResult Simplex::restoreRow(Unknown &u) { assert(u.orientation == Orientation::Row && "unknown should be in row position"); @@ -997,9 +998,9 @@ bool Simplex::restoreRow(Unknown &u) { pivot(*maybePivot); if (u.orientation == Orientation::Column) - return true; // the unknown is unbounded above. + return success(); // the unknown is unbounded above. } - return tableau(u.pos, 1) >= 0; + return success(tableau(u.pos, 1) >= 0); } /// Find a row that can be used to pivot the column in the specified direction. @@ -1105,8 +1106,8 @@ void SimplexBase::markEmpty() { /// empty and we mark it as such. void Simplex::addInequality(ArrayRef coeffs) { unsigned conIndex = addRow(coeffs, /*makeRestricted=*/true); - bool result = restoreRow(con[conIndex]); - if (!result) + LogicalResult result = restoreRow(con[conIndex]); + if (result.failed()) markEmpty(); } @@ -1384,7 +1385,7 @@ MaybeOptimum Simplex::computeOptimum(Direction direction, MaybeOptimum optimum = computeRowOptimum(direction, row); if (u.restricted && direction == Direction::Down && (optimum.isUnbounded() || *optimum < Fraction(0, 1))) { - if (!restoreRow(u)) + if (restoreRow(u).failed()) llvm_unreachable("Could not restore row!"); } return optimum; @@ -1453,7 +1454,7 @@ void Simplex::detectRedundant(unsigned offset, unsigned count) { if (minimum.isUnbounded() || *minimum < Fraction(0, 1)) { // Constraint is unbounded below or can attain negative sample values and // hence is not redundant. - if (!restoreRow(u)) + if (restoreRow(u).failed()) llvm_unreachable("Could not restore non-redundant row!"); continue; } diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp index 65190c6f07d4b6..9b32972de2e0a2 100644 --- a/mlir/lib/Analysis/Presburger/Utils.cpp +++ b/mlir/lib/Analysis/Presburger/Utils.cpp @@ -15,6 +15,7 @@ #include "mlir/Analysis/Presburger/PresburgerSpace.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" #include #include @@ -95,10 +96,10 @@ static void normalizeDivisionByGCD(MutableArrayRef dividend, /// If successful, `expr` is set to dividend of the division and `divisor` is /// set to the denominator of the division, which will be positive. /// The final division expression is normalized by GCD. -static bool getDivRepr(const IntegerRelation &cst, unsigned pos, - unsigned ubIneq, unsigned lbIneq, - MutableArrayRef expr, - DynamicAPInt &divisor) { +static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, + unsigned ubIneq, unsigned lbIneq, + MutableArrayRef expr, + DynamicAPInt &divisor) { assert(pos <= cst.getNumVars() && "Invalid variable position"); assert(ubIneq <= cst.getNumInequalities() && @@ -120,7 +121,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos, break; if (i < e) - return false; + return failure(); // Then, check if the constant term is of the proper form. // Due to the form of the upper/lower bound inequalities, the sum of their @@ -132,7 +133,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos, // Check if `c` satisfies the condition `0 <= c <= divisor - 1`. // This also implictly checks that `divisor` is positive. if (!(0 <= c && c <= divisor - 1)) // NOLINT - return false; + return failure(); // The inequality pair can be used to extract the division. // Set `expr` to the dividend of the division except the constant term, which @@ -147,7 +148,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos, expr.back() = cst.atIneq(ubIneq, cst.getNumCols() - 1) + c; normalizeDivisionByGCD(expr, divisor); - return true; + return success(); } /// Check if the pos^th variable can be represented as a division using @@ -161,9 +162,10 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos, /// If successful, `expr` is set to dividend of the division and `divisor` is /// set to the denominator of the division. The final division expression is /// normalized by GCD. -static bool getDivRepr(const IntegerRelation &cst, unsigned pos, unsigned eqInd, - MutableArrayRef expr, - DynamicAPInt &divisor) { +static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos, + unsigned eqInd, + MutableArrayRef expr, + DynamicAPInt &divisor) { assert(pos <= cst.getNumVars() && "Invalid variable position"); assert(eqInd <= cst.getNumEqualities() && "Invalid equality position"); @@ -174,7 +176,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos, unsigned eqInd, // Equality must involve the pos-th variable and hence `tempDiv` != 0. DynamicAPInt tempDiv = cst.atEq(eqInd, pos); if (tempDiv == 0) - return false; + return failure(); int signDiv = tempDiv < 0 ? -1 : 1; // The divisor is always a positive integer. @@ -187,7 +189,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos, unsigned eqInd, expr.back() = -signDiv * cst.atEq(eqInd, cst.getNumCols() - 1); normalizeDivisionByGCD(expr, divisor); - return true; + return success(); } // Returns `false` if the constraints depends on a variable for which an @@ -238,7 +240,7 @@ MaybeLocalRepr presburger::computeSingleVarRepr( for (unsigned ubPos : ubIndices) { for (unsigned lbPos : lbIndices) { // Attempt to get divison representation from ubPos, lbPos. - if (!getDivRepr(cst, pos, ubPos, lbPos, dividend, divisor)) + if (getDivRepr(cst, pos, ubPos, lbPos, dividend, divisor).failed()) continue; if (!checkExplicitRepresentation(cst, foundRepr, dividend, pos)) @@ -251,7 +253,7 @@ MaybeLocalRepr presburger::computeSingleVarRepr( } for (unsigned eqPos : eqIndices) { // Attempt to get divison representation from eqPos. - if (!getDivRepr(cst, pos, eqPos, dividend, divisor)) + if (getDivRepr(cst, pos, eqPos, dividend, divisor).failed()) continue; if (!checkExplicitRepresentation(cst, foundRepr, dividend, pos))