Skip to content

Commit

Permalink
mlir/Presburger: reinstate use of LogicalResult (llvm#97415)
Browse files Browse the repository at this point in the history
Follow up on a desire post-landing d0fee98 (mlir/Presburger: strip
dependency on MLIRSupport) to reinstate the use of LogicalResult in
Presburger. Since db791b2 (mlir/LogicalResult: move into llvm),
LogicalResult is in LLVM, and fulfilling this desire is possible while
still maintaining the goal of stripping the Presburger library of mlir
dependencies.
  • Loading branch information
artagnon authored and lravenclaw committed Jul 3, 2024
1 parent 2024be2 commit afbf57e
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 94 deletions.
12 changes: 8 additions & 4 deletions mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,17 @@
#include "mlir/Analysis/Presburger/Utils.h"
#include "llvm/ADT/DynamicAPInt.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
#include <optional>

namespace mlir {
namespace presburger {
using llvm::DynamicAPInt;
using llvm::failure;
using llvm::int64fromDynamicAPInt;
using llvm::LogicalResult;
using llvm::SmallVectorImpl;
using llvm::success;

class IntegerRelation;
class IntegerPolyhedron;
Expand Down Expand Up @@ -478,7 +482,7 @@ class IntegerRelation {
/// equality detection; if successful, the constant is substituted for the
/// variable everywhere in the constraint system and then removed from the
/// system.
bool constantFoldVar(unsigned pos);
LogicalResult constantFoldVar(unsigned pos);

/// This method calls `constantFoldVar` for the specified range of variables,
/// `num` variables starting at position `pos`.
Expand All @@ -501,7 +505,7 @@ class IntegerRelation {
/// 3) this = {0 <= d0 <= 5, 1 <= d1 <= 9}
/// other = {2 <= d0 <= 6, 5 <= d1 <= 15},
/// output = {0 <= d0 <= 6, 1 <= d1 <= 15}
bool unionBoundingBox(const IntegerRelation &other);
LogicalResult unionBoundingBox(const IntegerRelation &other);

/// Returns the smallest known constant bound for the extent of the specified
/// variable (pos^th), i.e., the smallest known constant that is greater
Expand Down Expand Up @@ -774,8 +778,8 @@ class IntegerRelation {
/// Eliminates a single variable at `position` from equality and inequality
/// constraints. Returns `success` if the variable was eliminated, and
/// `failure` otherwise.
inline bool gaussianEliminateVar(unsigned position) {
return gaussianEliminateVars(position, position + 1) == 1;
inline LogicalResult gaussianEliminateVar(unsigned position) {
return success(gaussianEliminateVars(position, position + 1) == 1);
}

/// Removes local variables using equalities. Each equality is checked if it
Expand Down
12 changes: 6 additions & 6 deletions mlir/include/mlir/Analysis/Presburger/Simplex.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ class LexSimplexBase : public SimplexBase {
/// lexicopositivity of the basis transform. The row must have a non-positive
/// sample value. If this is not possible, return failure. This occurs when
/// the constraints have no solution or the sample value is zero.
bool moveRowUnknownToColumn(unsigned row);
LogicalResult moveRowUnknownToColumn(unsigned row);

/// Given a row that has a non-integer sample value, add an inequality to cut
/// away this fractional sample value from the polytope without removing any
Expand All @@ -459,7 +459,7 @@ class LexSimplexBase : public SimplexBase {
///
/// Return failure if the tableau became empty, and success if it didn't.
/// Failure status indicates that the polytope was integer empty.
bool addCut(unsigned row);
LogicalResult addCut(unsigned row);

/// Undo the addition of the last constraint. This is only called while
/// rolling back.
Expand Down Expand Up @@ -511,7 +511,7 @@ class LexSimplex : public LexSimplexBase {
MaybeOptimum<SmallVector<Fraction, 8>> getRationalSample() const;

/// Make the tableau configuration consistent.
bool restoreRationalConsistency();
LogicalResult restoreRationalConsistency();

/// Return whether the specified row is violated;
bool rowIsViolated(unsigned row) const;
Expand Down Expand Up @@ -626,7 +626,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
/// Return failure if the tableau became empty, indicating that the polytope
/// is always integer empty in the current symbol domain.
/// Return success otherwise.
bool doNonBranchingPivots();
LogicalResult doNonBranchingPivots();

/// Get a row that is always violated in the current domain, if one exists.
std::optional<unsigned> maybeGetAlwaysViolatedRow();
Expand All @@ -647,7 +647,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
/// at the time of the call. (This function may modify the symbol domain, but
/// failure statu indicates that the polytope was empty for all symbol values
/// in the initial domain.)
bool addSymbolicCut(unsigned row);
LogicalResult addSymbolicCut(unsigned row);

/// Get the numerator of the symbolic sample of the specific row.
/// This is an affine expression in the symbols with integer coefficients.
Expand Down Expand Up @@ -820,7 +820,7 @@ class Simplex : public SimplexBase {
///
/// Returns success if the unknown was successfully restored to a non-negative
/// sample value, failure otherwise.
bool restoreRow(Unknown &u);
LogicalResult restoreRow(Unknown &u);

/// Find a pivot to change the sample value of row in the specified
/// direction while preserving tableau consistency, except that if the
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Analysis/FlatLinearValueConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1247,10 +1247,10 @@ LogicalResult FlatLinearValueConstraints::unionBoundingBox(
if (!areVarsAligned(*this, otherCst)) {
FlatLinearValueConstraints otherCopy(otherCst);
mergeAndAlignVars(/*offset=*/getNumDimVars(), this, &otherCopy);
return success(IntegerPolyhedron::unionBoundingBox(otherCopy));
return IntegerPolyhedron::unionBoundingBox(otherCopy);
}

return success(IntegerPolyhedron::unionBoundingBox(otherCst));
return IntegerPolyhedron::unionBoundingBox(otherCst);
}

//===----------------------------------------------------------------------===//
Expand Down
26 changes: 14 additions & 12 deletions mlir/lib/Analysis/Presburger/IntegerRelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
Expand Down Expand Up @@ -1552,22 +1553,22 @@ static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos,
return -1;
}

bool IntegerRelation::constantFoldVar(unsigned pos) {
LogicalResult IntegerRelation::constantFoldVar(unsigned pos) {
assert(pos < getNumVars() && "invalid position");
int rowIdx;
if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
return false;
return failure();

// atEq(rowIdx, pos) is either -1 or 1.
assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
DynamicAPInt constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
setAndEliminate(pos, constVal);
return true;
return success();
}

void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) {
for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
if (!constantFoldVar(t))
if (constantFoldVar(t).failed())
t++;
}
}
Expand Down Expand Up @@ -1944,9 +1945,9 @@ void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
if (atEq(r, pos) != 0) {
// Use Gaussian elimination here (since we have an equality).
bool ret = gaussianEliminateVar(pos);
LogicalResult ret = gaussianEliminateVar(pos);
(void)ret;
assert(ret && "Gaussian elimination guaranteed to succeed");
assert(ret.succeeded() && "Gaussian elimination guaranteed to succeed");
LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
LLVM_DEBUG(dump());
return;
Expand Down Expand Up @@ -2173,7 +2174,8 @@ static void getCommonConstraints(const IntegerRelation &a,

// Computes the bounding box with respect to 'other' by finding the min of the
// lower bounds and the max of the upper bounds along each of the dimensions.
bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
LogicalResult
IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
assert(space.isEqual(otherCst.getSpace()) && "Spaces should match.");
assert(getNumLocalVars() == 0 && "local ids not supported yet here");

Expand Down Expand Up @@ -2201,13 +2203,13 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
if (!extent.has_value())
// TODO: symbolic extents when necessary.
// TODO: handle union if a dimension is unbounded.
return false;
return failure();

auto otherExtent = otherCst.getConstantBoundOnDimSize(
d, &otherLb, &otherLbFloorDivisor, &otherUb);
if (!otherExtent.has_value() || lbFloorDivisor != otherLbFloorDivisor)
// TODO: symbolic extents when necessary.
return false;
return failure();

assert(lbFloorDivisor > 0 && "divisor always expected to be positive");

Expand All @@ -2227,7 +2229,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
auto constLb = getConstantBound(BoundType::LB, d);
auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d);
if (!constLb.has_value() || !constOtherLb.has_value())
return false;
return failure();
std::fill(minLb.begin(), minLb.end(), 0);
minLb.back() = std::min(*constLb, *constOtherLb);
}
Expand All @@ -2243,7 +2245,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
auto constUb = getConstantBound(BoundType::UB, d);
auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d);
if (!constUb.has_value() || !constOtherUb.has_value())
return false;
return failure();
std::fill(maxUb.begin(), maxUb.end(), 0);
maxUb.back() = std::max(*constUb, *constOtherUb);
}
Expand Down Expand Up @@ -2281,7 +2283,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
// union (since the above are just the union along dimensions); we shouldn't
// be discarding any other constraints on the symbols.

return true;
return success();
}

bool IntegerRelation::isColZero(unsigned pos) const {
Expand Down
59 changes: 31 additions & 28 deletions mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <functional>
Expand Down Expand Up @@ -753,18 +754,18 @@ class presburger::SetCoalescer {
/// \___\|/ \_____/
///
///
bool coalescePairCutCase(unsigned i, unsigned j);
LogicalResult coalescePairCutCase(unsigned i, unsigned j);

/// Types the inequality `ineq` according to its `IneqType` for `simp` into
/// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate
/// inequalities were encountered. Otherwise, returns failure.
bool typeInequality(ArrayRef<DynamicAPInt> ineq, Simplex &simp);
LogicalResult typeInequality(ArrayRef<DynamicAPInt> ineq, Simplex &simp);

/// Types the equality `eq`, i.e. for `eq` == 0, types both `eq` >= 0 and
/// -`eq` >= 0 according to their `IneqType` for `simp` into
/// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate
/// inequalities were encountered. Otherwise, returns failure.
bool typeEquality(ArrayRef<DynamicAPInt> eq, Simplex &simp);
LogicalResult typeEquality(ArrayRef<DynamicAPInt> eq, Simplex &simp);

/// Replaces the element at position `i` with the last element and erases
/// the last element for both `disjuncts` and `simplices`.
Expand All @@ -775,7 +776,7 @@ class presburger::SetCoalescer {
/// successfully coalesced. The simplices in `simplices` need to be the ones
/// constructed from `disjuncts`. At this point, there are no empty
/// disjuncts in `disjuncts` left.
bool coalescePair(unsigned i, unsigned j);
LogicalResult coalescePair(unsigned i, unsigned j);
};

/// Constructs a `SetCoalescer` from a `PresburgerRelation`. Only adds non-empty
Expand Down Expand Up @@ -818,7 +819,7 @@ PresburgerRelation SetCoalescer::coalesce() {
cuttingIneqsB.clear();
if (i == j)
continue;
if (coalescePair(i, j)) {
if (coalescePair(i, j).succeeded()) {
broken = true;
break;
}
Expand Down Expand Up @@ -902,15 +903,15 @@ void SetCoalescer::addCoalescedDisjunct(unsigned i, unsigned j,
/// \___\|/ \_____/
///
///
bool SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
LogicalResult SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
/// All inequalities of `b` need to be redundant. We already know that the
/// redundant ones are, so only the cutting ones remain to be checked.
Simplex &simp = simplices[i];
IntegerRelation &disjunct = disjuncts[i];
if (llvm::any_of(cuttingIneqsA, [this, &simp](ArrayRef<DynamicAPInt> curr) {
return !isFacetContained(curr, simp);
}))
return false;
return failure();
IntegerRelation newSet(disjunct.getSpace());

for (ArrayRef<DynamicAPInt> curr : redundantIneqsA)
Expand All @@ -920,23 +921,25 @@ bool SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
newSet.addInequality(curr);

addCoalescedDisjunct(i, j, newSet);
return true;
return success();
}

bool SetCoalescer::typeInequality(ArrayRef<DynamicAPInt> ineq, Simplex &simp) {
LogicalResult SetCoalescer::typeInequality(ArrayRef<DynamicAPInt> ineq,
Simplex &simp) {
Simplex::IneqType type = simp.findIneqType(ineq);
if (type == Simplex::IneqType::Redundant)
redundantIneqsB.push_back(ineq);
else if (type == Simplex::IneqType::Cut)
cuttingIneqsB.push_back(ineq);
else
return false;
return true;
return failure();
return success();
}

bool SetCoalescer::typeEquality(ArrayRef<DynamicAPInt> eq, Simplex &simp) {
if (!typeInequality(eq, simp))
return false;
LogicalResult SetCoalescer::typeEquality(ArrayRef<DynamicAPInt> eq,
Simplex &simp) {
if (typeInequality(eq, simp).failed())
return failure();
negEqs.push_back(getNegatedCoeffs(eq));
ArrayRef<DynamicAPInt> inv(negEqs.back());
return typeInequality(inv, simp);
Expand All @@ -951,15 +954,15 @@ void SetCoalescer::eraseDisjunct(unsigned i) {
simplices.pop_back();
}

bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
LogicalResult SetCoalescer::coalescePair(unsigned i, unsigned j) {

IntegerRelation &a = disjuncts[i];
IntegerRelation &b = disjuncts[j];
/// Handling of local ids is not yet implemented, so these cases are
/// skipped.
/// TODO: implement local id support.
if (a.getNumLocalVars() != 0 || b.getNumLocalVars() != 0)
return false;
return failure();
Simplex &simpA = simplices[i];
Simplex &simpB = simplices[j];

Expand All @@ -969,34 +972,34 @@ bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
// inequality is encountered during typing, the two IntegerRelations
// cannot be coalesced.
for (int k = 0, e = a.getNumInequalities(); k < e; ++k)
if (!typeInequality(a.getInequality(k), simpB))
return false;
if (typeInequality(a.getInequality(k), simpB).failed())
return failure();

for (int k = 0, e = a.getNumEqualities(); k < e; ++k)
if (!typeEquality(a.getEquality(k), simpB))
return false;
if (typeEquality(a.getEquality(k), simpB).failed())
return failure();

std::swap(redundantIneqsA, redundantIneqsB);
std::swap(cuttingIneqsA, cuttingIneqsB);

for (int k = 0, e = b.getNumInequalities(); k < e; ++k)
if (!typeInequality(b.getInequality(k), simpA))
return false;
if (typeInequality(b.getInequality(k), simpA).failed())
return failure();

for (int k = 0, e = b.getNumEqualities(); k < e; ++k)
if (!typeEquality(b.getEquality(k), simpA))
return false;
if (typeEquality(b.getEquality(k), simpA).failed())
return failure();

// If there are no cutting inequalities of `a`, `b` is contained
// within `a`.
if (cuttingIneqsA.empty()) {
eraseDisjunct(j);
return true;
return success();
}

// Try to apply the cut case
if (coalescePairCutCase(i, j))
return true;
if (coalescePairCutCase(i, j).succeeded())
return success();

// Swap the vectors to compare the pair (j,i) instead of (i,j).
std::swap(redundantIneqsA, redundantIneqsB);
Expand All @@ -1006,7 +1009,7 @@ bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
// within `a`.
if (cuttingIneqsA.empty()) {
eraseDisjunct(i);
return true;
return success();
}

// Try to apply the cut case
Expand Down
Loading

0 comments on commit afbf57e

Please sign in to comment.