diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 2f7b88dfc508..d7d1617c8d64 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -29,6 +29,7 @@ #include #include +#include #include "../target/datatype/registry.h" #include "const_fold.h" @@ -232,17 +233,17 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) { if (SideEffect(subconstraint) <= CallEffectKind::kPure) { literal_constraints_.push_back(subconstraint); - // We could apply this during TryMatchLiteralConstraint, but - // that would require performing a rewrite of each expression - // being checked. This way, we only apply a rewrite for each - // constraint being applied. PrimExpr negation; if (subconstraint.dtype().is_bool()) { - negation = Not(subconstraint); + // We could apply RewriteBooleanOperators during + // TryMatchLiteralConstraint, but that would require + // performing a rewrite of each expression being checked. + // This way, we only apply a rewrite for each constraint being + // applied. + negation = RewriteBooleanOperators(Not(subconstraint)); } else { negation = subconstraint == make_zero(subconstraint.dtype()); } - negation = operator()(negation); literal_constraints_.push_back(Not(negation)); } } @@ -1301,7 +1302,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { } Optional RewriteSimplifier::Impl::TryMatchLiteralConstraint(const PrimExpr& expr) const { - PrimExpr negation = Not(expr); + PrimExpr negation = RewriteBooleanOperators(Not(expr)); ExprDeepEqual expr_equal; for (const auto& constraint : literal_constraints_) { @@ -1497,105 +1498,26 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - if (auto const_res = TryConstFold(op->a)) return const_res.value(); - if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); - // Pattern var to match any expression - PVar x, y; - PVar lanes; - if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes)); - } + if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); - TVM_TRY_REWRITE(!(!x), x); - TVM_TRY_REWRITE(!(x <= y), y < x); - TVM_TRY_REWRITE(!(x >= y), x < y); - TVM_TRY_REWRITE(!(x < y), y <= x); - TVM_TRY_REWRITE(!(x > y), x <= y); - TVM_TRY_REWRITE(!(x == y), x != y); - TVM_TRY_REWRITE(!(x != y), x == y); - TVM_TRY_RECURSIVE_REWRITE(!(x || y), (!x) && (!y)); - TVM_TRY_RECURSIVE_REWRITE(!(x && y), (!x) || (!y)); - return ret; + return RewriteBooleanOperators(ret); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); - // Pattern var to match any expression - PVar x, y; - // Pattern var match IntImm - PVar c1, c2; - PVar lanes; - - if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes)); - } - - auto cfalse = PConst(make_const(op->dtype, false)); - TVM_TRY_REWRITE(x == y && x != y, cfalse); - TVM_TRY_REWRITE(x != y && x == y, cfalse); - TVM_TRY_REWRITE(x && !x, cfalse); - TVM_TRY_REWRITE(x <= y && y < x, cfalse); - TVM_TRY_REWRITE(y < x && x <= y, cfalse); - - TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, c2.Eval()->value >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, c2.Eval()->value >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, c2.Eval()->value >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, c2.Eval()->value >= c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, c2.Eval()->value > c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, c2.Eval()->value > c1.Eval()->value); - - TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2); - TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2); - return ret; + return RewriteBooleanOperators(ret); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); - // Pattern var to match any expression - PVar x, y; - // Pattern var match IntImm - PVar c1, c2; - PVar lanes; - - if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes)); - } - - auto ctrue = PConst(make_const(op->dtype, true)); - - TVM_TRY_REWRITE(x == y || x != y, ctrue); - TVM_TRY_REWRITE(x != y || x == y, ctrue); - TVM_TRY_REWRITE(x || !x, ctrue); - TVM_TRY_REWRITE(x <= y || y < x, ctrue); - TVM_TRY_REWRITE(y < x || x <= y, ctrue); - - TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, c2.Eval()->value <= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value); - TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, c2.Eval()->value <= c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); - TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); - - TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2); - TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2); - return ret; + return RewriteBooleanOperators(ret); } PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SelectNode* op) { @@ -1746,5 +1668,164 @@ RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) RewriteSimplifier::~RewriteSimplifier() { delete impl_; } +namespace { +/* Utility for rewriting only boolean portions of an expression + * + * Intended for application on an expression that has previously been + * simplified, but has subsequent manipulations performed. + * (e.g. Finding the simplified negation of a conditional without + * performing a full simplification.) + */ +class BooleanRewriter : public ExprMutator { + private: + PrimExpr VisitExpr(const PrimExpr& expr) override { + if (expr.dtype().is_bool()) { + return ExprMutator::VisitExpr(expr); + } else { + return expr; + } + } + + PrimExpr VisitExpr_(const NotNode* op) override { + PrimExpr ret = GetRef(op); + + if (auto const_res = TryConstFold(op->a)) return const_res.value(); + + PVar x, y; + PVar lanes; + if (op->dtype.lanes() != 1) { + TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes)); + } + + TVM_TRY_RECURSIVE_REWRITE(!(!x), x); + TVM_TRY_RECURSIVE_REWRITE(!(x || y), (!x) && (!y)); + TVM_TRY_RECURSIVE_REWRITE(!(x && y), (!x) || (!y)); + TVM_TRY_REWRITE(!(x <= y), y < x); + TVM_TRY_REWRITE(!(x >= y), x < y); + TVM_TRY_REWRITE(!(x < y), y <= x); + TVM_TRY_REWRITE(!(x > y), x <= y); + TVM_TRY_REWRITE(!(x == y), x != y); + TVM_TRY_REWRITE(!(x != y), x == y); + + return ret; + } + + PrimExpr VisitExpr_(const AndNode* op) override { + And ret = GetRef(op); + if (allow_recursion_) { + allow_recursion_ = false; + PrimExpr a = VisitExpr(ret->a); + PrimExpr b = VisitExpr(ret->b); + bool is_same = a.same_as(ret->a) && b.same_as(b); + if (!is_same) { + ret = And(a, b); + } + allow_recursion_ = true; + } + op = ret.get(); + + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); + + // Pattern var to match any expression + PVar x, y; + // Pattern var match IntImm + PVar c1, c2; + PVar lanes; + + if (op->dtype.lanes() != 1) { + TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes)); + } + + auto cfalse = PConst(make_const(op->dtype, false)); + + TVM_TRY_REWRITE(x == y && x != y, cfalse); + TVM_TRY_REWRITE(x != y && x == y, cfalse); + TVM_TRY_REWRITE(x && !x, cfalse); + TVM_TRY_REWRITE(x <= y && y < x, cfalse); + TVM_TRY_REWRITE(y < x && x <= y, cfalse); + + TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, c2.Eval()->value >= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, c2.Eval()->value > c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, c2.Eval()->value > c1.Eval()->value); + + TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2); + TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2); + + return std::move(ret); + } + + PrimExpr VisitExpr_(const OrNode* op) override { + Or ret = GetRef(op); + if (allow_recursion_) { + allow_recursion_ = false; + PrimExpr a = VisitExpr(ret->a); + PrimExpr b = VisitExpr(ret->b); + bool is_same = a.same_as(ret->a) && b.same_as(b); + if (!is_same) { + ret = Or(a, b); + } + allow_recursion_ = true; + } + op = ret.get(); + + if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); + + // Pattern var to match any expression + PVar x, y; + // Pattern var match IntImm + PVar c1, c2; + PVar lanes; + + if (op->dtype.lanes() != 1) { + TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes)); + } + + auto ctrue = PConst(make_const(op->dtype, true)); + + TVM_TRY_REWRITE(x == y || x != y, ctrue); + TVM_TRY_REWRITE(x != y || x == y, ctrue); + TVM_TRY_REWRITE(x || !x, ctrue); + TVM_TRY_REWRITE(x <= y || y < x, ctrue); + TVM_TRY_REWRITE(y < x || x <= y, ctrue); + + TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, c2.Eval()->value <= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); + TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); + + TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2); + TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2); + + return std::move(ret); + } + + PrimExpr RecursiveRewrite(const PrimExpr& expr) { + bool cache = true; + std::swap(cache, allow_recursion_); + auto output = VisitExpr(expr); + std::swap(cache, allow_recursion_); + return output; + } + + private: + bool allow_recursion_{false}; +}; +} // namespace + +PrimExpr RewriteBooleanOperators(const PrimExpr& expr) { return BooleanRewriter()(expr); } + } // namespace arith } // namespace tvm diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 6007b6416742..08cfea646eb4 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -149,6 +149,20 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { } }; +/* Utility for rewriting only boolean portions of an expression + * + * Intended for application on an expression that has previously been + * simplified, but has subsequent manipulations performed. + * (e.g. Finding the simplified negation of a conditional without + * performing a full simplification.) Only a single simplication step + * is performed. + * + * \param expr The boolean expression to be simplified + * + * \returns The simplified boolean expression + */ +PrimExpr RewriteBooleanOperators(const PrimExpr& expr); + } // namespace arith } // namespace tvm #endif // TVM_ARITH_REWRITE_SIMPLIFY_H_