Skip to content

Commit

Permalink
[Arith][TIR] Check for constant offsets of known literal constraints (a…
Browse files Browse the repository at this point in the history
…pache#13023)

Previously, the checks for a literal constraint would find exact
matches for an inequality, but any alterations to the conditional
would break this exact matching.  This commit introduces checks for
constant offsets relative to a known value.  These checks are not
always expressible using the existing `ConstIntSetAnalyzer`, which
represents allowed values using a single contiguous
region.  (e.g. `i!=5` is not representable, because it requires a
region for `i<5` and another for `i>5`.)

This implementation reuses the internal representation for
inequalities introduced in apache#12863,
along with much of its implementation.  However, the indirect
comparisons (e.g. using `a < b` and `b < c` to prove that `a < c`)
introduced in that PR still require an explicit flag to be used.
  • Loading branch information
Lunderberg authored Oct 29, 2022
1 parent e971956 commit 25a0d47
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 45 deletions.
11 changes: 10 additions & 1 deletion include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,10 +409,19 @@ class TransitiveComparisonAnalyzer {
*
* \param rhs The right-hand side of the comparison
*
* \param propagate_inequalities If true, attempt to find a sequence
* of transitive inequalities that allow the lhs and rhs to be
* compared. If false, only use the known comparison that have been
* directly provided. Using `propagate_inequalities = false` is
* roughly equivalent to comparing against all known inequality
* expressions using `ExprDeepEqual`, but also allows for constant
* offsets on either side of the inequality.
*
* \return The most specific result that can be proven about the
* comparison. If nothing can be proven, returns kUnknown.
*/
TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs);
TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
bool propagate_inequalities = true);

/*! \brief Bind a variable as being equal to a known expression
*
Expand Down
7 changes: 3 additions & 4 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, const PrimE

if (is_finished()) return output;

if (enabled_extensions_ & kTransitivelyProveInequalities) {
output = CompareResult(output & TryCompareUsingKnownInequalities(x, y));
}
output = CompareResult(output & TryCompareUsingKnownInequalities(x, y));

return output;
}
Expand All @@ -132,7 +130,8 @@ CompareResult RewriteSimplifier::Impl::TryCompareUsingConstIntBounds(const PrimE

CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const PrimExpr& x,
const PrimExpr& y) {
return analyzer_->transitive_comparisons.TryCompare(x, y);
bool propagate_inequalities = enabled_extensions_ & kTransitivelyProveInequalities;
return analyzer_->transitive_comparisons.TryCompare(x, y, propagate_inequalities);
}

// try to prove x equals val
Expand Down
168 changes: 128 additions & 40 deletions src/arith/transitive_comparison_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,19 @@ class TransitiveComparisonAnalyzer::Impl {
*
* \param rhs The right-hand side of the comparison
*
* \param propagate_inequalities If true, attempt to find a sequence
* of transitive inequalities that allow the lhs and rhs to be
* compared. If false, only use the known comparison that have been
* directly provided. Using `propagate_inequalities = false` is
* roughly equivalent to comparing against all known values with
* `ExprDeepEqual`, but also allowing for constant offsets on either
* side of the inequality.
*
* \return The most specific result that can be proven about the
* comparison. If nothing can be proven, returns kUnknown.
*/
CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const;
CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
bool propagate_inequalities = true) const;

/*! \brief Bind a variable as being equal to a known expression
*
Expand Down Expand Up @@ -192,7 +201,37 @@ class TransitiveComparisonAnalyzer::Impl {
*/
void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec);

/*! \brief Attempt to compare the expressions, starting at the lhs.
/*! Collect known comparisons between LHS and RHS, without propagation
*
* Allows the internal representation to handle any constant
* offsets, without searching for a sequence of inequalities.
*
* \param lhs_key The left-hand side of the comparison
*
* \param rhs_key The right-hand side of the comparison
*
* \returns A subset of `knowns_` and `scoped_knowns_`, filtered to
* only include comparisons between `lhs_key` and `rhs_key`,
* normalized such that `lhs_key` is on the left-hand side.
*/
std::vector<Comparison> CollectDirectComparisons(Key lhs_key, Key rhs_key) const;

/*! Collect known comparisons between LHS and RHS, with propagation
*
* \param lhs_key The left-hand side of the comparison
*
* \param rhs_key The right-hand side of the comparison
*
* \returns All comparisons between `lhs_key` and `rhs_key`,
* including the explicitly-provided comparisons in `knowns_` and
* `scoped_knowns_`, and comparisons provable through a series of
* comparisons through other values. All comparisons returned are
* between `lhs_key` and `rhs_key`, and are normalized such that
* `lhs_key` is on the left-hand side.
*/
std::vector<Comparison> CollectIndirectComparisons(Key lhs_key, Key rhs_key) const;

/*! \brief Internal function used by CollectIndirectComparisons
*
* Perform a depth-first search through the space of known
* expressions, starting at the LHS of a comparison. In this
Expand All @@ -208,14 +247,29 @@ class TransitiveComparisonAnalyzer::Impl {
* expression D, then combine the comparisons that compose the path
* into the expression A<=D-4.
*
* \param lhs The left-hand side of the comparison
* \param lhs_key The left-hand side of the comparison
*
* \param rhs The right-hand side of the comparison
* \param rhs_key The right-hand side of the comparison
*
* \returns A vector of comparisons between the two expressions.
*/
std::vector<Comparison> DFSFromLHS(Key lhs_key, Key rhs_key) const;

/*! \brief Combine a set of comparisons that share a LHS and RHS
*
* \param lhs_to_rhs The comparisons to merge. These should all
* have the same LHS and RHS. This parameter will typically be the
* result from `CollectDirectComparisons` or
* `CollectIndirectComparisons`.
*
* \return The result of the comparison
* \param offset The constant offset in the comparison being proven.
* This is extracted from any additive/subtractive constants in the
* `PrimExpr` arguments to `TryCompare`.
*
* \returns The possible comparisons between LHS and RHS provided
* inequalities.
*/
CompareResult DFSFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs,
const PrimExpr& rhs) const;
CompareResult MergeComparisons(const std::vector<Comparison>& lhs_to_rhs, int64_t offset) const;

/*! \brief Previous Range bindings
*
Expand Down Expand Up @@ -475,8 +529,9 @@ bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {}
TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}

CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) {
return impl_->TryCompare(lhs, rhs);
CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs,
bool propagate_inequalities) {
return impl_->TryCompare(lhs, rhs, propagate_inequalities);
}

void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
Expand Down Expand Up @@ -547,7 +602,8 @@ std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const
}

CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr,
const PrimExpr& rhs_expr) const {
const PrimExpr& rhs_expr,
bool propagate_inequalities) const {
// Currently only supports integer checks
if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
return CompareResult::kUnknown;
Expand Down Expand Up @@ -575,29 +631,59 @@ CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs
return CompareResult::kUnknown;
}

auto from_lhs = DFSFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs);
auto from_rhs = Reverse(DFSFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs));
auto output = from_lhs & from_rhs;
auto lhs_to_rhs = [&]() {
if (propagate_inequalities) {
return CollectIndirectComparisons(lhs_key.value(), rhs_key.value());
} else {
return CollectDirectComparisons(lhs_key.value(), rhs_key.value());
}
}();
return MergeComparisons(lhs_to_rhs, offset);
}

std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
TransitiveComparisonAnalyzer::Impl::CollectDirectComparisons(Key lhs_key, Key rhs_key) const {
std::vector<Comparison> output;

auto append_known = [&](Comparison cmp) {
if (auto normalized = cmp.WithLHS(lhs_key)) {
if (normalized.value().rhs_ == rhs_key) {
output.push_back(normalized.value());
}
}
};

for (const auto& known : knowns_) {
append_known(known);
}
for (const auto& known : scoped_knowns_) {
append_known(known);
}

return output;
}

CompareResult TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input, Key rhs_key_input,
int64_t offset_input,
const PrimExpr& lhs_input,
const PrimExpr& rhs_input) const {
Key lhs_key = lhs_key_input;
Key rhs_key = rhs_key_input;
int64_t offset = offset_input;
std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
TransitiveComparisonAnalyzer::Impl::CollectIndirectComparisons(Key lhs_key, Key rhs_key) const {
auto output = DFSFromLHS(lhs_key, rhs_key);
for (Comparison cmp : DFSFromLHS(rhs_key, lhs_key)) {
auto opt_normalized = cmp.WithLHS(lhs_key);
ICHECK(opt_normalized.has_value());
output.push_back(opt_normalized.value());
}
return output;
}

std::vector<TransitiveComparisonAnalyzer::Impl::Comparison>
TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key, Key rhs_key) const {
// Everything in `to_visit` has lhs as its lhs.
std::unordered_set<Key> seen;
std::unordered_set<Key> to_visit;
std::unordered_map<Key, std::vector<Comparison>> compared_to_x;
std::unordered_map<Key, std::vector<Comparison>> compared_to_lhs;

// Utility function to add a new known statement
auto declare_known = [&](Comparison cmp) {
std::vector<Comparison>& knowns = compared_to_x[cmp.rhs_];
std::vector<Comparison>& knowns = compared_to_lhs[cmp.rhs_];

// The comparison adds no new information, no modification
// required.
Expand Down Expand Up @@ -646,8 +732,8 @@ CompareResult TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input,
Key middle_key = *to_visit.begin();
to_visit.erase(to_visit.begin());

std::vector<Comparison>& prev_knowns_using_middle = compared_to_x.at(middle_key);
ICHECK(compared_to_x.count(middle_key));
std::vector<Comparison>& prev_knowns_using_middle = compared_to_lhs.at(middle_key);
ICHECK(compared_to_lhs.count(middle_key));

std::vector<Comparison> new_knowns_using_lhs;

Expand Down Expand Up @@ -721,51 +807,53 @@ CompareResult TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input,
}
}

// It's possible that we don't have any transitive comparisons that
// can prove something about LHS and RHS.
auto it = compared_to_x.find(rhs_key);
if (it == compared_to_x.end()) {
return CompareResult::kUnknown;
if (auto it = compared_to_lhs.find(rhs_key); it != compared_to_lhs.end()) {
return it->second;
} else {
// There are known comparisons involving the LHS and the RHS, but
// no path that connects the two expressions.
return {};
}
}

const std::vector<Comparison>& known_between_lhs_and_rhs = it->second;

CompareResult TransitiveComparisonAnalyzer::Impl::MergeComparisons(
const std::vector<Comparison>& lhs_to_rhs, int64_t offset) const {
// Just because we found a comparison involving LHS and RHS doesn't
// mean that it's useful. e.g. Knowing that `x < y` doesn't let us
// prove whether `x + 5 < y`.
CompareResult result = CompareResult::kUnknown;
for (const auto& known : known_between_lhs_and_rhs) {
switch (known.result_) {
for (const auto& cmp : lhs_to_rhs) {
switch (cmp.result_) {
case CompareResult::kInconsistent:
result = CompareResult::kInconsistent;
break;

case CompareResult::kEQ:
if (offset == known.offset_) {
if (offset == cmp.offset_) {
result = result & CompareResult::kEQ;
} else {
result = result & CompareResult::kNE;
}
break;

case CompareResult::kLE:
if (known.offset_ < offset) {
if (cmp.offset_ < offset) {
result = result & CompareResult::kLT;
} else if (known.offset_ <= offset) {
} else if (cmp.offset_ <= offset) {
result = result & CompareResult::kLE;
}
break;

case CompareResult::kGE:
if (known.offset_ > offset) {
if (cmp.offset_ > offset) {
result = result & CompareResult::kGT;
} else if (known.offset_ >= offset) {
} else if (cmp.offset_ >= offset) {
result = result & CompareResult::kGE;
}
break;

case CompareResult::kNE:
if (offset == known.offset_) {
if (offset == cmp.offset_) {
result = result & CompareResult::kNE;
}
break;
Expand All @@ -779,7 +867,7 @@ CompareResult TransitiveComparisonAnalyzer::Impl::DFSFromLHS(Key lhs_key_input,
return CompareResult::kInconsistent;

default:
LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(known.result_);
LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(cmp.result_);
return CompareResult::kInconsistent;
}
}
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,5 +989,19 @@ def expected(A: T.Buffer[1, "bool"], n: T.int32, m: T.int32):
A[0] = n < m + 10


class TestProvableConditionWithOffset(BaseBeforeAfter):
"""Use scoped-constraint to prove inequalities"""

transitively_prove_inequalities = False

def before(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32):
if i < j:
A[0] = i < j + 1

def expected(A: T.Buffer[1, "bool"], i: T.int32, j: T.int32):
if i < j:
A[0] = True


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 25a0d47

Please sign in to comment.