Skip to content

Commit

Permalink
[TIR][Arith] Prove conditionals by transitively applying knowns
Browse files Browse the repository at this point in the history
This commit adds a new sub-analyzer, `TransitiveComparisonAnalyzer`,
which attempts to apply multiple known comparisons to prove an
unknown.  For example, `a <= b` and `b <= c` imply that `a <= c`.
These simplifications are necessary for simplifying conditionals
resulting from padded layout
transformations (apache#12261).

While some of these conditions may be proven using
`ConstIntBoundAnalyzer` or `IntSetAnalyzer`, each has some
limitations.  `ConstIntBoundAnalyzer` can only compare against a
constant, `IntSetAnalyzer` internally calls `RewriteSimplifier` which
can result in infinite recursion, and neither can handle not-equal
conditions because it would require tracking multiple intervals per
expression.  Therefore, introducing a new sub-analyzer for these
simplifications.
  • Loading branch information
Lunderberg committed Sep 21, 2022
1 parent fdc6894 commit ddcedff
Show file tree
Hide file tree
Showing 7 changed files with 941 additions and 25 deletions.
78 changes: 78 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,82 @@ class CanonicalSimplifier {
Impl* impl_;
};

/*! \brief Structure for representing result of known
*
* Values are assigned to allow these flags to be used in bitwise
* operations.
*/
enum class CompareResult : int {
kInconsistent = 0,
kEQ = 1,
kLT = 2,
kLE = 3,
kGT = 4,
kGE = 5,
kNE = 6,
kUnknown = 7
};

inline constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs) {
return CompareResult(static_cast<int>(lhs) & static_cast<int>(rhs));
}
inline constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs) {
return CompareResult(static_cast<int>(lhs) | static_cast<int>(rhs));
}

/*!
* \brief Using previously specified knowns, compare the expressions provided
*
* Given known expressions [(a OP b), (b OP c), ..., (y OP z)], search
* for a known result for `(a OP z)`.
*/
class TransitiveComparisonAnalyzer {
public:
/* \brief Using previously specified knowns, compare the expressions provided
*
* \param lhs The left-hand side of the comparison
*
* \param rhs The right-hand side of the comparison
*
* \return The most specific result that can be proven about the
* comparison. If nothing can be proven, returns kUnknown.
*/
CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs);

/*! \brief Bind a variable as being equal to a known expression
*
* \param var The variable of interest.
* \param expr The bound expression
* \param allow_override Whether to allow override of existing information.
*/
void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);

/*! \brief Bind a variable as being within a specified range
*
* \param var The variable of interest.
* \param range The known range
* \param allow_override Whether to allow override of existing information.
*/
void Bind(const Var& var, const Range& range, bool allow_override = false);

/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const PrimExpr& constraint);

private:
friend class Analyzer;
friend class ConstraintContext;
TransitiveComparisonAnalyzer();
TVM_DLL ~TransitiveComparisonAnalyzer();
class Impl;
/*! \brief Internal impl */
std::unique_ptr<Impl> impl_{nullptr};
};

/*!
* \brief Constraint context.
*
Expand Down Expand Up @@ -437,6 +513,8 @@ class TVM_DLL Analyzer {
CanonicalSimplifier canonical_simplify;
/*! \brief sub-analyzer: int set */
IntSetAnalyzer int_set;
/*! \brief sub-analyzer transitive comparisons */
TransitiveComparisonAnalyzer transitive_comparisons;
/*! \brief constructor */
Analyzer();
/*!
Expand Down
3 changes: 3 additions & 0 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
this->rewrite_simplify.Update(var, new_expr, allow_override);
this->canonical_simplify.Update(var, new_expr, allow_override);
this->int_set.Update(var, this->int_set(new_expr), allow_override);
this->transitive_comparisons.Bind(var, expr, allow_override);
}

void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
Expand All @@ -54,6 +55,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
} else {
this->const_int_bound.Bind(var, range, allow_override);
this->int_set.Bind(var, range, allow_override);
this->transitive_comparisons.Bind(var, range, allow_override);
}
// skip modular_set
// skip rewrite simplify
Expand All @@ -72,6 +74,7 @@ void ConstraintContext::EnterWithScope() {
recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->transitive_comparisons.EnterConstraint(constraint_));
}

void ConstraintContext::ExitWithScope() {
Expand Down
10 changes: 6 additions & 4 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
lhs.CopyOnWrite()->AddToSelf(pconst->value / cval);
} else {
// if 0 <= extra < cval, it means the extra can be eliminated.
if (TryCompare(temp, cval) != kLT) {
if (TryCompare(temp, cval) != CompareResult::kLT) {
lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1);
}
}
Expand Down Expand Up @@ -945,7 +945,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval));
} else {
// if 0 <= extra < cval, it means the extra can be eliminated.
if (!(TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0))) {
if (!(TryCompare(temp, cval) == CompareResult::kLT &&
analyzer_->CanProveGreaterEqual(temp, 0))) {
lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1);
}
}
Expand Down Expand Up @@ -1052,7 +1053,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
return truncmod(temp, c1.Eval());
} else {
// If temp < cval && temp >=0 then can remove the mod.
if (TryCompare(temp, cval) == kLT) {
if (TryCompare(temp, cval) == CompareResult::kLT) {
return temp;
} else {
// contonue to use logic below.
Expand Down Expand Up @@ -1113,7 +1114,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
return floormod(temp, c1.Eval());
} else {
// If temp < cval && temp >=0 then can remove the mod.
if (TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0)) {
if (TryCompare(temp, cval) == CompareResult::kLT &&
analyzer_->CanProveGreaterEqual(temp, 0)) {
return temp;
} else {
// contonue to use logic below.
Expand Down
63 changes: 45 additions & 18 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,42 +71,67 @@ using namespace tir;
// handled by CanonicalSimplifier.
//

CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, const PrimExpr& y) {
CompareResult output = CompareResult::kUnknown;

auto is_finished = [&output]() {
return output == CompareResult::kEQ || output == CompareResult::kLT ||
output == CompareResult::kGT;
};

output = CompareResult(output & TryCompareUsingKnownInequalities(x, y));

if (is_finished()) return output;
output = CompareResult(output & TryCompareUsingConstIntBounds(x, y));

return output;
}

CompareResult RewriteSimplifier::Impl::TryCompareUsingConstIntBounds(const PrimExpr& x,
const PrimExpr y) {
return TryCompare(x - y, 0);
}

CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const PrimExpr& x,
const PrimExpr& y) {
return analyzer_->transitive_comparisons.TryCompare(x, y);
}

// try to prove x equals val
RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x,
int64_t val) {
CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val) {
PrimExpr diff = this->VisitExpr(x);
if (const auto* ptr = diff.as<IntImmNode>()) {
if (ptr->value == val) {
return kEQ;
return CompareResult::kEQ;
} else if (ptr->value > val) {
return kGT;
return CompareResult::kGT;
} else if (ptr->value < val) {
return kLT;
return CompareResult::kLT;
}
}
ConstIntBound dbound = analyzer_->const_int_bound(diff);
if (dbound->min_value == val && dbound->max_value == val) {
return kEQ;
return CompareResult::kEQ;
}
if (dbound->min_value > val) {
return kGT;
return CompareResult::kGT;
}
if (dbound->max_value < val) {
return kLT;
return CompareResult::kLT;
}
if (dbound->min_value >= val) {
return kGE;
return CompareResult::kGE;
}
if (dbound->max_value <= val) {
return kLE;
return CompareResult::kLE;
}
if (val == 0) {
ModularSet dmod = analyzer_->modular_set(diff);
if (dmod->base != 0) {
return kNE;
return CompareResult::kNE;
}
}
return kUnknown;
return CompareResult::kUnknown;
}

void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) {
Expand Down Expand Up @@ -1333,10 +1358,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) {
}

if (IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a - op->b, 0);
if (result == kEQ) {
CompareResult result = TryCompare(op->a, op->b);
if (result == CompareResult::kEQ) {
return make_const(op->dtype, true);
} else if (result == kNE || result == kGT || result == kLT) {
} else if (result == CompareResult::kNE || result == CompareResult::kGT ||
result == CompareResult::kLT) {
return make_const(op->dtype, false);
}
TVM_TRY_REWRITE(x - c1 == 0, x == c1);
Expand Down Expand Up @@ -1382,11 +1408,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
}

if (IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a - op->b, 0);
if (result == kLT) {
CompareResult result = TryCompare(op->a, op->b);
if (result == CompareResult::kLT) {
return make_const(op->dtype, true);
}
if (result == kEQ || result == kGT || result == kGE) {
if (result == CompareResult::kEQ || result == CompareResult::kGT ||
result == CompareResult::kGE) {
return make_const(op->dtype, false);
}

Expand Down
15 changes: 12 additions & 3 deletions src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
std::function<void()> EnterConstraint(const PrimExpr& constraint);

protected:
/*! \brief internal structure for comparison. */
enum CompareResult { kUnknown, kEQ, kGT, kGE, kLT, kLE, kNE };
// counter to record recursive rewrite depth.
int recur_depth_{0};
// internal variable map
Expand All @@ -98,6 +96,14 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
*/
CompareResult TryCompare(const PrimExpr& x, int64_t val);

/*! Try to compare x against y
*
* \param x The lhs of the comparison
* \param y The rhs of the comparison
* \return comparison result.
*/
CompareResult TryCompare(const PrimExpr& x, const PrimExpr& y);

/*!
* \brief Internal function to check whether or not to inline let.
* \param op The let expr.
Expand All @@ -115,6 +121,9 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
Optional<PrimExpr> TryMatchLiteralConstraint(const PrimExpr& expr) const;

private:
CompareResult TryCompareUsingKnownInequalities(const PrimExpr& x, const PrimExpr& y);
CompareResult TryCompareUsingConstIntBounds(const PrimExpr& x, const PrimExpr y);

// Whether x >= val
bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) {
return analyzer_->CanProveGreaterEqual(x, val);
Expand All @@ -124,7 +133,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
// Whether x == val
bool CanProveEqual(const PrimExpr& x, int64_t val) {
// TODO(tqchen) refer back to super-analyzer.
return TryCompare(x, val) == kEQ;
return TryCompare(x, val) == CompareResult::kEQ;
}

// Recursive rewrite x
Expand Down
Loading

0 comments on commit ddcedff

Please sign in to comment.