Skip to content

Commit

Permalink
[Arith][Fixup] Require feature flag for tighter inequality bounds (ap…
Browse files Browse the repository at this point in the history
…ache#16735)

This is a follow-up to apache#16588.  Due
to an incorrect rebase, the version that was merged into `main` had
the tighter `ConstIntBounds` enabled by default, rather than having
them implemented in `RewriteSimplifier`, gated behind a feature flag.
  • Loading branch information
Lunderberg authored Mar 19, 2024
1 parent ff6ce9c commit 48cedc7
Show file tree
Hide file tree
Showing 9 changed files with 238 additions and 173 deletions.
29 changes: 29 additions & 0 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,35 @@ class RewriteSimplifier {
* (n < 10) || (n < 5) => (n < 5)
*/
kApplyConstraintsToBooleanBranches = (1 << 2),

/* Special handling for expressions `(A+B)*C < (A*B)*D`
*
* Expressions of the form `(A+B)*C < (A*B)*D` can occur occur
* when comparing the number of operations required for two
* different orderings in which matrix multiplications can be
* performed. Proving or disproving this conditional allows an
* optimal order of execution to be selected, even for dynamic
* argument shapes.
*
* The default behavior of `ConstIntBounds` assumes that each term
* in an expression is independent, and is insufficient to prove
* these inequalities. For example, the maximum value of `(A+B)*C
* - (A*B)*D` is determined by taking the maximum value of
* `(A+B)*C` and subtracting the minimum value of `(A*B)*D`.
* While this algorithm can be applied in all cases, the bound it
* provides is looser than strictly required.
*
* This extension adds a check for this case. When `A`, `B`, `C`,
* and `D` are all positive values, as is the case for tensor
* shapes, the inequality can be written as `1/A + 1/B < D/C`. If
* this inequality holds for the minimum values of `A`, `B`, and
* `D`, along with the maximum value of `C`, then the inequality
* holds for all values.
*
* This extension requires little to no performance overhead, and
* may be enabled by default in future releases.
*/
kComparisonOfProductAndSum = (1 << 3),
};

/*! \brief Enable an optional extension or extensions
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
estimate_region_strict_bound,
estimate_region_upper_bound,
)
from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength
from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength, Extension
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr
from .int_solver import solve_linear_equations, solve_linear_inequalities
Expand Down
38 changes: 36 additions & 2 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Arithmetic data structure and utility"""
from enum import IntEnum
import enum
from typing import Union

import tvm._ffi
Expand All @@ -26,13 +26,26 @@
from . import _ffi_api


class ProofStrength(IntEnum):
class ProofStrength(enum.IntEnum):
"""Proof strength of the analysis"""

DEFAULT = 0
SYMBOLIC_BOUND = 1


class Extension(enum.Flag):
"""Extensions enabled for RewriteSimplifier
Values should match `RewriteSimplifier::Extensions`
"""

NoExtensions = 0
TransitivelyProveInequalities = 1 << 0
ConvertBooleanToAndOfOrs = 1 << 1
ApplyConstraintsToBooleanBranches = 1 << 2
ComparisonOfProductAndSum = 1 << 3


@tvm._ffi.register_object("arith.ModularSet")
class ModularSet(Object):
"""Represent range of (coeff * x + base) for x in Z"""
Expand Down Expand Up @@ -107,6 +120,8 @@ def __init__(self):
self._enter_constraint_context = _mod("enter_constraint_context")
self._can_prove_equal = _mod("can_prove_equal")
self._can_prove = _mod("can_prove")
self._get_enabled_extensions = _mod("get_enabled_extensions")
self._set_enabled_extensions = _mod("set_enabled_extensions")

def const_int_bound(self, expr):
"""Find constant integer bound for expr.
Expand Down Expand Up @@ -311,3 +326,22 @@ def can_prove_equal(self, lhs: "PrimExpr", rhs: "PrimExpr"):
Whether we can prove that lhs == rhs
"""
return self._can_prove_equal(lhs, rhs)

@property
def enabled_extensions(self) -> Extension:
"""Return the currently enabled extensions"""
value = self._get_enabled_extensions()
return Extension(value)

@enabled_extensions.setter
def enabled_extensions(self, flags: Union[int, Extension]):
"""Enable extensions for the analyzer
Parameters
----------
flags: Union[int,Extension]
The extensions to enable.
"""
flags = Extension(flags).value
self._set_enabled_extensions(flags)
10 changes: 10 additions & 0 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,16 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
} else if (name == "can_prove_equal") {
return PackedFunc(
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->CanProveEqual(args[0], args[1]); });
} else if (name == "get_enabled_extensions") {
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
*ret = static_cast<std::int64_t>(self->rewrite_simplify.GetEnabledExtensions());
});
} else if (name == "set_enabled_extensions") {
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
std::int64_t flags = args[0];
self->rewrite_simplify.SetEnabledExtensions(
static_cast<RewriteSimplifier::Extension>(flags));
});
}
return PackedFunc();
};
Expand Down
168 changes: 0 additions & 168 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,6 @@ class ConstIntBoundAnalyzer::Impl
ret.min_value = InfAwareAdd(a.min_value, b.min_value);
ret.max_value = InfAwareAdd(a.max_value, b.max_value);

if (auto bound = BoundUsingReciprocal(GetRef<PrimExpr>(op))) {
ret = Intersect(ret, bound.value());
}

return ret;
}

Expand All @@ -254,12 +250,6 @@ class ConstIntBoundAnalyzer::Impl
ret.min_value = InfAwareAdd(a.min_value, -b.max_value);
ret.max_value = InfAwareAdd(a.max_value, -b.min_value);

if (auto bound = BoundUsingReciprocal(GetRef<Sub>(op))) {
ret = Intersect(ret, bound.value());
}
if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) {
ret = Intersect(ret, Negative(bound.value()));
}
return ret;
}

Expand Down Expand Up @@ -775,164 +765,6 @@ class ConstIntBoundAnalyzer::Impl
std::ceil(std::log2(arg_bounds.max_value)));
}
}

std::optional<Entry> BoundUsingReciprocal(PrimExpr expr) {
// Match expressions of the form `(A+B)*C - (A*B)*D`. Depending on
// previous simplifications, the exact form of the expression may vary.
auto opt_special_case = [&]() -> std::optional<std::tuple<Entry, Entry, Entry, Entry>> {
PVar<PrimExpr> A, B, C, D;

if (PMatchesOneOf{
(A + B) * C - (A * B) * D,
(A + B) * C - (B * A) * D,
}
.Match(expr)) {
return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
VisitExpr(D.Eval())};
} else if (PMatchesOneOf{
(A + B) * C - A * B,
(A + B) * C - B * A,
}
.Match(expr)) {
return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
MakeBound(1, 1)};
} else if (PMatchesOneOf{
(A * B) * D - (A + B) * C,
(B * A) * D - (A + B) * C,
}
.Match(expr)) {
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
Negative(VisitExpr(C.Eval())), Negative(VisitExpr(D.Eval()))};
} else if (PMatchesOneOf{
A * B - (A + B) * C,
B * A - (A + B) * C,
}
.Match(expr)) {
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
Negative(VisitExpr(C.Eval())), MakeBound(-1, -1)};
} else if (PMatchesOneOf{
(A * B) * D + (A + B) * C,
(B * A) * D + (A + B) * C,
(A + B) * C + (A * B) * D,
(A + B) * C + (B * A) * D,
}
.Match(expr)) {
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
VisitExpr(C.Eval()), Negative(VisitExpr(D.Eval()))};
} else if (PMatchesOneOf{
(A * B) + (A + B) * C,
(B * A) + (A + B) * C,
(A + B) * C + (A * B),
(A + B) * C + (B * A),
}
.Match(expr)) {
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
VisitExpr(C.Eval()), MakeBound(-1, -1)};
} else {
return std::nullopt;
}
}();

if (!opt_special_case.has_value()) {
return std::nullopt;
}
// Unpacking the tuple would be cleaner with a structured binding.
// However, until C++20, structured bindings cannot be captured for
// use in a lambda function.
auto A_bound = std::get<0>(*opt_special_case);
auto B_bound = std::get<1>(*opt_special_case);
auto C_bound = std::get<2>(*opt_special_case);
auto D_bound = std::get<3>(*opt_special_case);

// If C and D have different signs, flip the signs of A/B/C so
// that C will match the sign of D.
if ((D_bound.max_value < 0 && C_bound.min_value > 0) ||
(D_bound.min_value > 0 && C_bound.max_value < 0)) {
A_bound = Negative(A_bound);
B_bound = Negative(B_bound);
C_bound = Negative(C_bound);
}

// If all terms are negative, then we'll be providing an upper bound
// rather than a lower bound. To avoid code duplication, flip all the
// signs here, find a lower bound, then flip the sign to produce the
// upper bound of the original expression.
bool all_terms_negative = (A_bound.max_value < 0 && B_bound.max_value < 0 &&
C_bound.max_value < 0 && D_bound.max_value < 0);
if (all_terms_negative) {
A_bound = Negative(A_bound);
B_bound = Negative(B_bound);
C_bound = Negative(C_bound);
D_bound = Negative(D_bound);
}

bool all_terms_positive = (A_bound.min_value > 0 && B_bound.min_value > 0 &&
C_bound.min_value > 0 && D_bound.min_value > 0);
if (!all_terms_positive) {
return std::nullopt;
}

// (A + B) * C - (A * B) * D
// (A*B*C*D) * ( (A+B)/(A*B*D) - 1/C )
// (A*B*C*D) * ( (1/A + 1/B)/D - 1/C )
// (A*B*C*D) * (1/(A*D) + 1/(B*D) - 1/C)
//
// The constant (A*B*C*D) is positive, and its minimum value is the
// product of the minimum values of A, B, C, and D. If the reciprocal
// term (1/(A*D) + 1/(B*D) - 1/C) is positive, then this constant can
// be used to provide a lower bound on the expression.

bool reciprocal_term_is_positive = [&]() {
if (D_bound.max_value == ConstIntBound::kPosInf) {
// If D can grow without bound, the `1/(A*D)` and `1/(B*D)`
// terms will approach zero, at which point the `-1/C` term
// will determine the sign the sign.
return false;
}

if (std::min(A_bound.max_value, B_bound.max_value) * D_bound.max_value <= C_bound.min_value) {
// 1/(A*D) + 1/(B*D) - 1/C is positive if 1/C < 1/(A*D) + 1/(B*D).
// Since each term is positive, this condition can hold if either
// A*D <= C or B*D <= C.
return true;
}
if (A_bound.max_value != ConstIntBound::kPosInf &&
B_bound.max_value != ConstIntBound::kPosInf) {
// Even if neither term is sufficient on its own, if both A and B
// have known upper bounds, the inequality 1/C < 1/(A*D) + 1/(B*D)
// may still be provable.
//
// The maximum value of the LHS is found when C is minimized. The
// minimum value of the RHS is found when A, B, and D are
// maximized. If the condition holds in this case, then it holds
// in all cases.
//
// 1/C_min < 1/(A_max * D_max) + 1/(B_max*D_max)
// A_max*B_max*D_max < C_min*B_max + C_min*A_max
// A_max*B_max*D_max < C_min*(A_max + B_max)
//
if (A_bound.max_value * B_bound.max_value * D_bound.max_value <
C_bound.min_value * (A_bound.max_value + B_bound.max_value)) {
return true;
}
}
return false;
}();

if (!reciprocal_term_is_positive) {
return std::nullopt;
}

auto ret = Everything(expr->dtype);
ret.min_value = A_bound.min_value * B_bound.min_value * C_bound.min_value * D_bound.min_value;

// If we flipped the sign of the original expression, flip the sign of
// the resulting set of possible values.
if (all_terms_negative) {
ret = Negative(ret);
}
return ret;
}
};

ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const {
Expand Down
Loading

0 comments on commit 48cedc7

Please sign in to comment.