Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith][Refactor] Extract And/Or/Not handling from RewriteSimplifier #12942

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 171 additions & 90 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/tir/op.h>

#include <algorithm>
#include <utility>

#include "../target/datatype/registry.h"
#include "const_fold.h"
Expand Down Expand Up @@ -232,17 +233,17 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) {
if (SideEffect(subconstraint) <= CallEffectKind::kPure) {
literal_constraints_.push_back(subconstraint);
// We could apply this during TryMatchLiteralConstraint, but
// that would require performing a rewrite of each expression
// being checked. This way, we only apply a rewrite for each
// constraint being applied.
PrimExpr negation;
if (subconstraint.dtype().is_bool()) {
negation = Not(subconstraint);
// We could apply RewriteBooleanOperators during
// TryMatchLiteralConstraint, but that would require
// performing a rewrite of each expression being checked.
// This way, we only apply a rewrite for each constraint being
// applied.
negation = RewriteBooleanOperators(Not(subconstraint));
} else {
negation = subconstraint == make_zero(subconstraint.dtype());
}
negation = operator()(negation);
literal_constraints_.push_back(Not(negation));
}
}
Expand Down Expand Up @@ -1301,7 +1302,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) {
}

Optional<PrimExpr> RewriteSimplifier::Impl::TryMatchLiteralConstraint(const PrimExpr& expr) const {
PrimExpr negation = Not(expr);
PrimExpr negation = RewriteBooleanOperators(Not(expr));

ExprDeepEqual expr_equal;
for (const auto& constraint : literal_constraints_) {
Expand Down Expand Up @@ -1497,105 +1498,26 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<NotNode>();
if (auto const_res = TryConstFold<Not>(op->a)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

// Pattern var to match any expression
PVar<PrimExpr> x, y;
PVar<int> lanes;
if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes));
}
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

TVM_TRY_REWRITE(!(!x), x);
TVM_TRY_REWRITE(!(x <= y), y < x);
TVM_TRY_REWRITE(!(x >= y), x < y);
TVM_TRY_REWRITE(!(x < y), y <= x);
TVM_TRY_REWRITE(!(x > y), x <= y);
TVM_TRY_REWRITE(!(x == y), x != y);
TVM_TRY_REWRITE(!(x != y), x == y);
TVM_TRY_RECURSIVE_REWRITE(!(x || y), (!x) && (!y));
TVM_TRY_RECURSIVE_REWRITE(!(x && y), (!x) || (!y));
return ret;
return RewriteBooleanOperators(ret);
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<AndNode>();
if (auto const_res = TryConstFold<And>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<int> lanes;

if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes));
}

auto cfalse = PConst<PrimExpr>(make_const(op->dtype, false));
TVM_TRY_REWRITE(x == y && x != y, cfalse);
TVM_TRY_REWRITE(x != y && x == y, cfalse);
TVM_TRY_REWRITE(x && !x, cfalse);
TVM_TRY_REWRITE(x <= y && y < x, cfalse);
TVM_TRY_REWRITE(y < x && x <= y, cfalse);

TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);

TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, c2.Eval()->value >= c1.Eval()->value);

TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, c2.Eval()->value > c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, c2.Eval()->value > c1.Eval()->value);

TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2);
TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2);
return ret;
return RewriteBooleanOperators(ret);
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<OrNode>();
if (auto const_res = TryConstFold<Or>(op->a, op->b)) return const_res.value();
if (auto match = TryMatchLiteralConstraint(ret)) return match.value();

// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<int> lanes;

if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes));
}

auto ctrue = PConst<PrimExpr>(make_const(op->dtype, true));

TVM_TRY_REWRITE(x == y || x != y, ctrue);
TVM_TRY_REWRITE(x != y || x == y, ctrue);
TVM_TRY_REWRITE(x || !x, ctrue);
TVM_TRY_REWRITE(x <= y || y < x, ctrue);
TVM_TRY_REWRITE(y < x || x <= y, ctrue);

TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value);

TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, c2.Eval()->value <= c1.Eval()->value);

TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);
TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);

TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);
return ret;
return RewriteBooleanOperators(ret);
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SelectNode* op) {
Expand Down Expand Up @@ -1746,5 +1668,164 @@ RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent))

RewriteSimplifier::~RewriteSimplifier() { delete impl_; }

namespace {
/* Utility for rewriting only boolean portions of an expression
*
* Intended for application on an expression that has previously been
* simplified, but has subsequent manipulations performed.
* (e.g. Finding the simplified negation of a conditional without
* performing a full simplification.)
*/
class BooleanRewriter : public ExprMutator {
private:
PrimExpr VisitExpr(const PrimExpr& expr) override {
if (expr.dtype().is_bool()) {
return ExprMutator::VisitExpr(expr);
} else {
return expr;
}
}

PrimExpr VisitExpr_(const NotNode* op) override {
PrimExpr ret = GetRef<PrimExpr>(op);

if (auto const_res = TryConstFold<Not>(op->a)) return const_res.value();

PVar<PrimExpr> x, y;
PVar<int> lanes;
if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes));
}

TVM_TRY_RECURSIVE_REWRITE(!(!x), x);
TVM_TRY_RECURSIVE_REWRITE(!(x || y), (!x) && (!y));
TVM_TRY_RECURSIVE_REWRITE(!(x && y), (!x) || (!y));
TVM_TRY_REWRITE(!(x <= y), y < x);
TVM_TRY_REWRITE(!(x >= y), x < y);
TVM_TRY_REWRITE(!(x < y), y <= x);
TVM_TRY_REWRITE(!(x > y), x <= y);
TVM_TRY_REWRITE(!(x == y), x != y);
TVM_TRY_REWRITE(!(x != y), x == y);

return ret;
}

PrimExpr VisitExpr_(const AndNode* op) override {
And ret = GetRef<And>(op);
if (allow_recursion_) {
allow_recursion_ = false;
PrimExpr a = VisitExpr(ret->a);
PrimExpr b = VisitExpr(ret->b);
bool is_same = a.same_as(ret->a) && b.same_as(b);
if (!is_same) {
ret = And(a, b);
}
allow_recursion_ = true;
}
op = ret.get();

if (auto const_res = TryConstFold<And>(op->a, op->b)) return const_res.value();

// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<int> lanes;

if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes));
}

auto cfalse = PConst<PrimExpr>(make_const(op->dtype, false));

TVM_TRY_REWRITE(x == y && x != y, cfalse);
TVM_TRY_REWRITE(x != y && x == y, cfalse);
TVM_TRY_REWRITE(x && !x, cfalse);
TVM_TRY_REWRITE(x <= y && y < x, cfalse);
TVM_TRY_REWRITE(y < x && x <= y, cfalse);

TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);

TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, c2.Eval()->value >= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, c2.Eval()->value >= c1.Eval()->value);

TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, c2.Eval()->value > c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, c2.Eval()->value > c1.Eval()->value);

TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2);
TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2);

return std::move(ret);
}

PrimExpr VisitExpr_(const OrNode* op) override {
Or ret = GetRef<Or>(op);
if (allow_recursion_) {
allow_recursion_ = false;
PrimExpr a = VisitExpr(ret->a);
PrimExpr b = VisitExpr(ret->b);
bool is_same = a.same_as(ret->a) && b.same_as(b);
if (!is_same) {
ret = Or(a, b);
}
allow_recursion_ = true;
}
op = ret.get();

if (auto const_res = TryConstFold<Or>(op->a, op->b)) return const_res.value();

// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
PVar<int> lanes;

if (op->dtype.lanes() != 1) {
TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes));
}

auto ctrue = PConst<PrimExpr>(make_const(op->dtype, true));

TVM_TRY_REWRITE(x == y || x != y, ctrue);
TVM_TRY_REWRITE(x != y || x == y, ctrue);
TVM_TRY_REWRITE(x || !x, ctrue);
TVM_TRY_REWRITE(x <= y || y < x, ctrue);
TVM_TRY_REWRITE(y < x || x <= y, ctrue);

TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value);

TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value);
TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, c2.Eval()->value <= c1.Eval()->value);

TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);
TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1);

TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);

return std::move(ret);
}

PrimExpr RecursiveRewrite(const PrimExpr& expr) {
bool cache = true;
std::swap(cache, allow_recursion_);
auto output = VisitExpr(expr);
std::swap(cache, allow_recursion_);
return output;
}

private:
bool allow_recursion_{false};
};
} // namespace

PrimExpr RewriteBooleanOperators(const PrimExpr& expr) { return BooleanRewriter()(expr); }

} // namespace arith
} // namespace tvm
14 changes: 14 additions & 0 deletions src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,20 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
}
};

/* Utility for rewriting only boolean portions of an expression
*
* Intended for application on an expression that has previously been
* simplified, but has subsequent manipulations performed.
* (e.g. Finding the simplified negation of a conditional without
* performing a full simplification.) Only a single simplication step
* is performed.
*
* \param expr The boolean expression to be simplified
*
* \returns The simplified boolean expression
*/
PrimExpr RewriteBooleanOperators(const PrimExpr& expr);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_REWRITE_SIMPLIFY_H_